mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
domaininfo: New package to track domain (security) information
This patch introduces a new "domaininfo" package, which implements a database with information about domains. In particular, it tracks incoming and outgoing security levels. That information is used in incoming and outgoing SMTP to prevent downgrades.
This commit is contained in:
133
internal/domaininfo/domaininfo.go
Normal file
133
internal/domaininfo/domaininfo.go
Normal file
@@ -0,0 +1,133 @@
|
||||
// Package domaininfo implements a domain information database, to keep track
|
||||
// of things we know about a particular domain.
|
||||
package domaininfo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/protoio"
|
||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||
)
|
||||
|
||||
// Command to generate domaininfo.pb.go.
|
||||
//go:generate protoc --go_out=. domaininfo.proto
|
||||
|
||||
type DB struct {
|
||||
// Persistent store with the list of domains we know.
|
||||
store *protoio.Store
|
||||
|
||||
info map[string]*Domain
|
||||
sync.Mutex
|
||||
|
||||
ev *trace.EventLog
|
||||
}
|
||||
|
||||
func New(dir string) (*DB, error) {
|
||||
st, err := protoio.NewStore(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := &DB{
|
||||
store: st,
|
||||
info: map[string]*Domain{},
|
||||
}
|
||||
l.ev = trace.NewEventLog("DomainInfo", fmt.Sprintf("%p", l))
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// Load the database from disk; should be called once at initialization.
|
||||
func (db *DB) Load() error {
|
||||
ids, err := db.store.ListIDs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
d := &Domain{}
|
||||
_, err := db.store.Get(id, d)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading %q: %v", id, err)
|
||||
}
|
||||
|
||||
db.info[d.Name] = d
|
||||
}
|
||||
|
||||
db.ev.Debugf("loaded: %s", ids)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) write(d *Domain) {
|
||||
err := db.store.Put(d.Name, d)
|
||||
if err != nil {
|
||||
db.ev.Errorf("%s error saving: %v", d.Name, err)
|
||||
} else {
|
||||
db.ev.Debugf("%s saved", d.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// IncomingSecLevel checks an incoming security level for the domain.
|
||||
// Returns true if allowed, false otherwise.
|
||||
func (db *DB) IncomingSecLevel(domain string, level SecLevel) bool {
|
||||
db.Lock()
|
||||
defer db.Unlock()
|
||||
|
||||
d, exists := db.info[domain]
|
||||
if !exists {
|
||||
d = &Domain{Name: domain}
|
||||
db.info[domain] = d
|
||||
defer db.write(d)
|
||||
}
|
||||
|
||||
if level < d.IncomingSecLevel {
|
||||
db.ev.Errorf("%s incoming denied: %s < %s",
|
||||
d.Name, level, d.IncomingSecLevel)
|
||||
return false
|
||||
} else if level == d.IncomingSecLevel {
|
||||
db.ev.Debugf("%s incoming allowed: %s == %s",
|
||||
d.Name, level, d.IncomingSecLevel)
|
||||
return true
|
||||
} else {
|
||||
db.ev.Printf("%s incoming level raised: %s > %s",
|
||||
d.Name, level, d.IncomingSecLevel)
|
||||
d.IncomingSecLevel = level
|
||||
if exists {
|
||||
defer db.write(d)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// OutgoingSecLevel checks an incoming security level for the domain.
|
||||
// Returns true if allowed, false otherwise.
|
||||
func (db *DB) OutgoingSecLevel(domain string, level SecLevel) bool {
|
||||
db.Lock()
|
||||
defer db.Unlock()
|
||||
|
||||
d, exists := db.info[domain]
|
||||
if !exists {
|
||||
d = &Domain{Name: domain}
|
||||
db.info[domain] = d
|
||||
defer db.write(d)
|
||||
}
|
||||
|
||||
if level < d.OutgoingSecLevel {
|
||||
db.ev.Errorf("%s outgoing denied: %s < %s",
|
||||
d.Name, level, d.OutgoingSecLevel)
|
||||
return false
|
||||
} else if level == d.OutgoingSecLevel {
|
||||
db.ev.Debugf("%s outgoing allowed: %s == %s",
|
||||
d.Name, level, d.OutgoingSecLevel)
|
||||
return true
|
||||
} else {
|
||||
db.ev.Printf("%s outgoing level raised: %s > %s",
|
||||
d.Name, level, d.OutgoingSecLevel)
|
||||
d.OutgoingSecLevel = level
|
||||
if exists {
|
||||
defer db.write(d)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
96
internal/domaininfo/domaininfo.pb.go
Normal file
96
internal/domaininfo/domaininfo.pb.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Code generated by protoc-gen-go.
|
||||
// source: domaininfo.proto
|
||||
// DO NOT EDIT!
|
||||
|
||||
/*
|
||||
Package domaininfo is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
domaininfo.proto
|
||||
|
||||
It has these top-level messages:
|
||||
Domain
|
||||
*/
|
||||
package domaininfo
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type SecLevel int32
|
||||
|
||||
const (
|
||||
// Does not do TLS.
|
||||
SecLevel_PLAIN SecLevel = 0
|
||||
// TLS client connection (no certificate validation).
|
||||
SecLevel_TLS_CLIENT SecLevel = 1
|
||||
// TLS, but with invalid certificates.
|
||||
SecLevel_TLS_INSECURE SecLevel = 2
|
||||
// TLS, with valid certificates.
|
||||
SecLevel_TLS_SECURE SecLevel = 3
|
||||
)
|
||||
|
||||
var SecLevel_name = map[int32]string{
|
||||
0: "PLAIN",
|
||||
1: "TLS_CLIENT",
|
||||
2: "TLS_INSECURE",
|
||||
3: "TLS_SECURE",
|
||||
}
|
||||
var SecLevel_value = map[string]int32{
|
||||
"PLAIN": 0,
|
||||
"TLS_CLIENT": 1,
|
||||
"TLS_INSECURE": 2,
|
||||
"TLS_SECURE": 3,
|
||||
}
|
||||
|
||||
func (x SecLevel) String() string {
|
||||
return proto.EnumName(SecLevel_name, int32(x))
|
||||
}
|
||||
func (SecLevel) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
type Domain struct {
|
||||
Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"`
|
||||
// Security level for mail coming from this domain (they send to us).
|
||||
IncomingSecLevel SecLevel `protobuf:"varint,2,opt,name=incoming_sec_level,json=incomingSecLevel,enum=domaininfo.SecLevel" json:"incoming_sec_level,omitempty"`
|
||||
// Security level for mail going to this domain (we send to them).
|
||||
OutgoingSecLevel SecLevel `protobuf:"varint,3,opt,name=outgoing_sec_level,json=outgoingSecLevel,enum=domaininfo.SecLevel" json:"outgoing_sec_level,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Domain) Reset() { *m = Domain{} }
|
||||
func (m *Domain) String() string { return proto.CompactTextString(m) }
|
||||
func (*Domain) ProtoMessage() {}
|
||||
func (*Domain) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*Domain)(nil), "domaininfo.Domain")
|
||||
proto.RegisterEnum("domaininfo.SecLevel", SecLevel_name, SecLevel_value)
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("domaininfo.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 189 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x12, 0x48, 0xc9, 0xcf, 0x4d,
|
||||
0xcc, 0xcc, 0xcb, 0xcc, 0x4b, 0xcb, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x42, 0x88,
|
||||
0x28, 0x2d, 0x61, 0xe4, 0x62, 0x73, 0x01, 0x73, 0x85, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53,
|
||||
0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x21, 0x27, 0x2e, 0xa1, 0xcc, 0xbc, 0xe4,
|
||||
0xfc, 0xdc, 0xcc, 0xbc, 0xf4, 0xf8, 0xe2, 0xd4, 0xe4, 0xf8, 0x9c, 0xd4, 0xb2, 0xd4, 0x1c, 0x09,
|
||||
0x26, 0xa0, 0x0a, 0x3e, 0x23, 0x11, 0x3d, 0x24, 0x93, 0x83, 0x53, 0x93, 0x7d, 0x40, 0x72, 0x41,
|
||||
0x02, 0x30, 0xf5, 0x30, 0x11, 0x90, 0x19, 0xf9, 0xa5, 0x25, 0xe9, 0xf9, 0xa8, 0x66, 0x30, 0xe3,
|
||||
0x33, 0x03, 0xa6, 0x1e, 0x26, 0xa2, 0xe5, 0xce, 0xc5, 0x01, 0x37, 0x8f, 0x93, 0x8b, 0x35, 0xc0,
|
||||
0xc7, 0xd1, 0xd3, 0x4f, 0x80, 0x41, 0x88, 0x8f, 0x8b, 0x2b, 0xc4, 0x27, 0x38, 0xde, 0xd9, 0xc7,
|
||||
0xd3, 0xd5, 0x2f, 0x44, 0x80, 0x51, 0x48, 0x80, 0x8b, 0x07, 0xc4, 0xf7, 0xf4, 0x0b, 0x76, 0x75,
|
||||
0x0e, 0x0d, 0x72, 0x15, 0x60, 0x82, 0xa9, 0x80, 0xf2, 0x99, 0x93, 0xd8, 0xc0, 0x41, 0x60, 0x0c,
|
||||
0x08, 0x00, 0x00, 0xff, 0xff, 0x2c, 0x78, 0x65, 0x5b, 0x16, 0x01, 0x00, 0x00,
|
||||
}
|
||||
28
internal/domaininfo/domaininfo.proto
Normal file
28
internal/domaininfo/domaininfo.proto
Normal file
@@ -0,0 +1,28 @@
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package domaininfo;
|
||||
|
||||
enum SecLevel {
|
||||
// Does not do TLS.
|
||||
PLAIN = 0;
|
||||
|
||||
// TLS client connection (no certificate validation).
|
||||
TLS_CLIENT = 1;
|
||||
|
||||
// TLS, but with invalid certificates.
|
||||
TLS_INSECURE = 2;
|
||||
|
||||
// TLS, with valid certificates.
|
||||
TLS_SECURE = 3;
|
||||
}
|
||||
|
||||
message Domain {
|
||||
string name = 1;
|
||||
|
||||
// Security level for mail coming from this domain (they send to us).
|
||||
SecLevel incoming_sec_level = 2;
|
||||
|
||||
// Security level for mail going to this domain (we send to them).
|
||||
SecLevel outgoing_sec_level = 3;
|
||||
}
|
||||
133
internal/domaininfo/domaininfo_test.go
Normal file
133
internal/domaininfo/domaininfo_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package domaininfo
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func mustTempDir(t *testing.T) string {
|
||||
dir, err := ioutil.TempDir("", "greylisting_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("test directory: %q", dir)
|
||||
return dir
|
||||
}
|
||||
|
||||
func TestBasic(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
db, err := New(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Load(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !db.IncomingSecLevel("d1", SecLevel_PLAIN) {
|
||||
t.Errorf("new domain as plain not allowed")
|
||||
}
|
||||
if !db.IncomingSecLevel("d1", SecLevel_TLS_SECURE) {
|
||||
t.Errorf("increment to tls-secure not allowed")
|
||||
}
|
||||
if db.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
|
||||
t.Errorf("decrement to tls-insecure was allowed")
|
||||
}
|
||||
|
||||
// Wait until it is written to disk.
|
||||
for dl := time.Now().Add(30 * time.Second); time.Now().Before(dl); {
|
||||
d := &Domain{}
|
||||
ok, _ := db.store.Get("d1", d)
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Check that it was added to the store and a new db sees it.
|
||||
db2, err := New(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db2.Load(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if db2.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
|
||||
t.Errorf("decrement to tls-insecure was allowed in new DB")
|
||||
}
|
||||
|
||||
if !t.Failed() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDomain(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
db, err := New(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
domain string
|
||||
level SecLevel
|
||||
}{
|
||||
{"plain", SecLevel_PLAIN},
|
||||
{"insecure", SecLevel_TLS_INSECURE},
|
||||
{"secure", SecLevel_TLS_SECURE},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if !db.IncomingSecLevel(c.domain, c.level) {
|
||||
t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level)
|
||||
}
|
||||
if !db.OutgoingSecLevel(c.domain, c.level) {
|
||||
t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level)
|
||||
}
|
||||
}
|
||||
if !t.Failed() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressions(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
db, err := New(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
domain string
|
||||
lvl SecLevel
|
||||
ok bool
|
||||
}{
|
||||
{"pisis", SecLevel_PLAIN, true},
|
||||
{"pisis", SecLevel_TLS_INSECURE, true},
|
||||
{"pisis", SecLevel_TLS_SECURE, true},
|
||||
{"pisis", SecLevel_TLS_INSECURE, false},
|
||||
{"pisis", SecLevel_TLS_SECURE, true},
|
||||
|
||||
{"ssip", SecLevel_TLS_SECURE, true},
|
||||
{"ssip", SecLevel_TLS_SECURE, true},
|
||||
{"ssip", SecLevel_TLS_INSECURE, false},
|
||||
{"ssip", SecLevel_PLAIN, false},
|
||||
}
|
||||
for i, c := range cases {
|
||||
if ok := db.IncomingSecLevel(c.domain, c.lvl); ok != c.ok {
|
||||
t.Errorf("%2d %q in attempt for %s failed: got %v, expected %v",
|
||||
i, c.domain, c.lvl, ok, c.ok)
|
||||
}
|
||||
if ok := db.OutgoingSecLevel(c.domain, c.lvl); ok != c.ok {
|
||||
t.Errorf("%2d %q out attempt for %s failed: got %v, expected %v",
|
||||
i, c.domain, c.lvl, ok, c.ok)
|
||||
}
|
||||
}
|
||||
|
||||
if !t.Failed() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user