1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-17 14:37:02 +00:00

domaininfo: New package to track domain (security) information

This patch introduces a new "domaininfo" package, which implements a
database with information about domains.  In particular, it tracks
incoming and outgoing security levels.

That information is used in incoming and outgoing SMTP to prevent
downgrades.
This commit is contained in:
Alberto Bertogli
2016-10-13 02:28:30 +01:00
parent 1d7a207e00
commit c013c98283
8 changed files with 545 additions and 11 deletions

View File

@@ -0,0 +1,133 @@
// Package domaininfo implements a domain information database, to keep track
// of things we know about a particular domain.
package domaininfo
import (
"fmt"
"sync"
"blitiri.com.ar/go/chasquid/internal/protoio"
"blitiri.com.ar/go/chasquid/internal/trace"
)
// Command to generate domaininfo.pb.go.
//go:generate protoc --go_out=. domaininfo.proto
type DB struct {
// Persistent store with the list of domains we know.
store *protoio.Store
info map[string]*Domain
sync.Mutex
ev *trace.EventLog
}
func New(dir string) (*DB, error) {
st, err := protoio.NewStore(dir)
if err != nil {
return nil, err
}
l := &DB{
store: st,
info: map[string]*Domain{},
}
l.ev = trace.NewEventLog("DomainInfo", fmt.Sprintf("%p", l))
return l, nil
}
// Load the database from disk; should be called once at initialization.
func (db *DB) Load() error {
ids, err := db.store.ListIDs()
if err != nil {
return err
}
for _, id := range ids {
d := &Domain{}
_, err := db.store.Get(id, d)
if err != nil {
return fmt.Errorf("error loading %q: %v", id, err)
}
db.info[d.Name] = d
}
db.ev.Debugf("loaded: %s", ids)
return nil
}
func (db *DB) write(d *Domain) {
err := db.store.Put(d.Name, d)
if err != nil {
db.ev.Errorf("%s error saving: %v", d.Name, err)
} else {
db.ev.Debugf("%s saved", d.Name)
}
}
// IncomingSecLevel checks an incoming security level for the domain.
// Returns true if allowed, false otherwise.
func (db *DB) IncomingSecLevel(domain string, level SecLevel) bool {
db.Lock()
defer db.Unlock()
d, exists := db.info[domain]
if !exists {
d = &Domain{Name: domain}
db.info[domain] = d
defer db.write(d)
}
if level < d.IncomingSecLevel {
db.ev.Errorf("%s incoming denied: %s < %s",
d.Name, level, d.IncomingSecLevel)
return false
} else if level == d.IncomingSecLevel {
db.ev.Debugf("%s incoming allowed: %s == %s",
d.Name, level, d.IncomingSecLevel)
return true
} else {
db.ev.Printf("%s incoming level raised: %s > %s",
d.Name, level, d.IncomingSecLevel)
d.IncomingSecLevel = level
if exists {
defer db.write(d)
}
return true
}
}
// OutgoingSecLevel checks an incoming security level for the domain.
// Returns true if allowed, false otherwise.
func (db *DB) OutgoingSecLevel(domain string, level SecLevel) bool {
db.Lock()
defer db.Unlock()
d, exists := db.info[domain]
if !exists {
d = &Domain{Name: domain}
db.info[domain] = d
defer db.write(d)
}
if level < d.OutgoingSecLevel {
db.ev.Errorf("%s outgoing denied: %s < %s",
d.Name, level, d.OutgoingSecLevel)
return false
} else if level == d.OutgoingSecLevel {
db.ev.Debugf("%s outgoing allowed: %s == %s",
d.Name, level, d.OutgoingSecLevel)
return true
} else {
db.ev.Printf("%s outgoing level raised: %s > %s",
d.Name, level, d.OutgoingSecLevel)
d.OutgoingSecLevel = level
if exists {
defer db.write(d)
}
return true
}
}

View File

@@ -0,0 +1,96 @@
// Code generated by protoc-gen-go.
// source: domaininfo.proto
// DO NOT EDIT!
/*
Package domaininfo is a generated protocol buffer package.
It is generated from these files:
domaininfo.proto
It has these top-level messages:
Domain
*/
package domaininfo
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type SecLevel int32
const (
// Does not do TLS.
SecLevel_PLAIN SecLevel = 0
// TLS client connection (no certificate validation).
SecLevel_TLS_CLIENT SecLevel = 1
// TLS, but with invalid certificates.
SecLevel_TLS_INSECURE SecLevel = 2
// TLS, with valid certificates.
SecLevel_TLS_SECURE SecLevel = 3
)
var SecLevel_name = map[int32]string{
0: "PLAIN",
1: "TLS_CLIENT",
2: "TLS_INSECURE",
3: "TLS_SECURE",
}
var SecLevel_value = map[string]int32{
"PLAIN": 0,
"TLS_CLIENT": 1,
"TLS_INSECURE": 2,
"TLS_SECURE": 3,
}
func (x SecLevel) String() string {
return proto.EnumName(SecLevel_name, int32(x))
}
func (SecLevel) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
type Domain struct {
Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"`
// Security level for mail coming from this domain (they send to us).
IncomingSecLevel SecLevel `protobuf:"varint,2,opt,name=incoming_sec_level,json=incomingSecLevel,enum=domaininfo.SecLevel" json:"incoming_sec_level,omitempty"`
// Security level for mail going to this domain (we send to them).
OutgoingSecLevel SecLevel `protobuf:"varint,3,opt,name=outgoing_sec_level,json=outgoingSecLevel,enum=domaininfo.SecLevel" json:"outgoing_sec_level,omitempty"`
}
func (m *Domain) Reset() { *m = Domain{} }
func (m *Domain) String() string { return proto.CompactTextString(m) }
func (*Domain) ProtoMessage() {}
func (*Domain) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func init() {
proto.RegisterType((*Domain)(nil), "domaininfo.Domain")
proto.RegisterEnum("domaininfo.SecLevel", SecLevel_name, SecLevel_value)
}
func init() { proto.RegisterFile("domaininfo.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 189 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x12, 0x48, 0xc9, 0xcf, 0x4d,
0xcc, 0xcc, 0xcb, 0xcc, 0x4b, 0xcb, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x42, 0x88,
0x28, 0x2d, 0x61, 0xe4, 0x62, 0x73, 0x01, 0x73, 0x85, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53,
0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x21, 0x27, 0x2e, 0xa1, 0xcc, 0xbc, 0xe4,
0xfc, 0xdc, 0xcc, 0xbc, 0xf4, 0xf8, 0xe2, 0xd4, 0xe4, 0xf8, 0x9c, 0xd4, 0xb2, 0xd4, 0x1c, 0x09,
0x26, 0xa0, 0x0a, 0x3e, 0x23, 0x11, 0x3d, 0x24, 0x93, 0x83, 0x53, 0x93, 0x7d, 0x40, 0x72, 0x41,
0x02, 0x30, 0xf5, 0x30, 0x11, 0x90, 0x19, 0xf9, 0xa5, 0x25, 0xe9, 0xf9, 0xa8, 0x66, 0x30, 0xe3,
0x33, 0x03, 0xa6, 0x1e, 0x26, 0xa2, 0xe5, 0xce, 0xc5, 0x01, 0x37, 0x8f, 0x93, 0x8b, 0x35, 0xc0,
0xc7, 0xd1, 0xd3, 0x4f, 0x80, 0x41, 0x88, 0x8f, 0x8b, 0x2b, 0xc4, 0x27, 0x38, 0xde, 0xd9, 0xc7,
0xd3, 0xd5, 0x2f, 0x44, 0x80, 0x51, 0x48, 0x80, 0x8b, 0x07, 0xc4, 0xf7, 0xf4, 0x0b, 0x76, 0x75,
0x0e, 0x0d, 0x72, 0x15, 0x60, 0x82, 0xa9, 0x80, 0xf2, 0x99, 0x93, 0xd8, 0xc0, 0x41, 0x60, 0x0c,
0x08, 0x00, 0x00, 0xff, 0xff, 0x2c, 0x78, 0x65, 0x5b, 0x16, 0x01, 0x00, 0x00,
}

View File

@@ -0,0 +1,28 @@
syntax = "proto3";
package domaininfo;
enum SecLevel {
// Does not do TLS.
PLAIN = 0;
// TLS client connection (no certificate validation).
TLS_CLIENT = 1;
// TLS, but with invalid certificates.
TLS_INSECURE = 2;
// TLS, with valid certificates.
TLS_SECURE = 3;
}
message Domain {
string name = 1;
// Security level for mail coming from this domain (they send to us).
SecLevel incoming_sec_level = 2;
// Security level for mail going to this domain (we send to them).
SecLevel outgoing_sec_level = 3;
}

View File

@@ -0,0 +1,133 @@
package domaininfo
import (
"io/ioutil"
"os"
"testing"
"time"
)
func mustTempDir(t *testing.T) string {
dir, err := ioutil.TempDir("", "greylisting_test")
if err != nil {
t.Fatal(err)
}
t.Logf("test directory: %q", dir)
return dir
}
func TestBasic(t *testing.T) {
dir := mustTempDir(t)
db, err := New(dir)
if err != nil {
t.Fatal(err)
}
if err := db.Load(); err != nil {
t.Fatal(err)
}
if !db.IncomingSecLevel("d1", SecLevel_PLAIN) {
t.Errorf("new domain as plain not allowed")
}
if !db.IncomingSecLevel("d1", SecLevel_TLS_SECURE) {
t.Errorf("increment to tls-secure not allowed")
}
if db.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
t.Errorf("decrement to tls-insecure was allowed")
}
// Wait until it is written to disk.
for dl := time.Now().Add(30 * time.Second); time.Now().Before(dl); {
d := &Domain{}
ok, _ := db.store.Get("d1", d)
if ok {
break
}
time.Sleep(50 * time.Millisecond)
}
// Check that it was added to the store and a new db sees it.
db2, err := New(dir)
if err != nil {
t.Fatal(err)
}
if err := db2.Load(); err != nil {
t.Fatal(err)
}
if db2.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
t.Errorf("decrement to tls-insecure was allowed in new DB")
}
if !t.Failed() {
os.RemoveAll(dir)
}
}
func TestNewDomain(t *testing.T) {
dir := mustTempDir(t)
db, err := New(dir)
if err != nil {
t.Fatal(err)
}
cases := []struct {
domain string
level SecLevel
}{
{"plain", SecLevel_PLAIN},
{"insecure", SecLevel_TLS_INSECURE},
{"secure", SecLevel_TLS_SECURE},
}
for _, c := range cases {
if !db.IncomingSecLevel(c.domain, c.level) {
t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level)
}
if !db.OutgoingSecLevel(c.domain, c.level) {
t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level)
}
}
if !t.Failed() {
os.RemoveAll(dir)
}
}
func TestProgressions(t *testing.T) {
dir := mustTempDir(t)
db, err := New(dir)
if err != nil {
t.Fatal(err)
}
cases := []struct {
domain string
lvl SecLevel
ok bool
}{
{"pisis", SecLevel_PLAIN, true},
{"pisis", SecLevel_TLS_INSECURE, true},
{"pisis", SecLevel_TLS_SECURE, true},
{"pisis", SecLevel_TLS_INSECURE, false},
{"pisis", SecLevel_TLS_SECURE, true},
{"ssip", SecLevel_TLS_SECURE, true},
{"ssip", SecLevel_TLS_SECURE, true},
{"ssip", SecLevel_TLS_INSECURE, false},
{"ssip", SecLevel_PLAIN, false},
}
for i, c := range cases {
if ok := db.IncomingSecLevel(c.domain, c.lvl); ok != c.ok {
t.Errorf("%2d %q in attempt for %s failed: got %v, expected %v",
i, c.domain, c.lvl, ok, c.ok)
}
if ok := db.OutgoingSecLevel(c.domain, c.lvl); ok != c.ok {
t.Errorf("%2d %q out attempt for %s failed: got %v, expected %v",
i, c.domain, c.lvl, ok, c.ok)
}
}
if !t.Failed() {
os.RemoveAll(dir)
}
}