mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
domaininfo: New package to track domain (security) information
This patch introduces a new "domaininfo" package, which implements a database with information about domains. In particular, it tracks incoming and outgoing security levels. That information is used in incoming and outgoing SMTP to prevent downgrades.
This commit is contained in:
60
chasquid.go
60
chasquid.go
@@ -24,6 +24,7 @@ import (
|
|||||||
"blitiri.com.ar/go/chasquid/internal/auth"
|
"blitiri.com.ar/go/chasquid/internal/auth"
|
||||||
"blitiri.com.ar/go/chasquid/internal/config"
|
"blitiri.com.ar/go/chasquid/internal/config"
|
||||||
"blitiri.com.ar/go/chasquid/internal/courier"
|
"blitiri.com.ar/go/chasquid/internal/courier"
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/domaininfo"
|
||||||
"blitiri.com.ar/go/chasquid/internal/envelope"
|
"blitiri.com.ar/go/chasquid/internal/envelope"
|
||||||
"blitiri.com.ar/go/chasquid/internal/normalize"
|
"blitiri.com.ar/go/chasquid/internal/normalize"
|
||||||
"blitiri.com.ar/go/chasquid/internal/queue"
|
"blitiri.com.ar/go/chasquid/internal/queue"
|
||||||
@@ -53,6 +54,7 @@ var (
|
|||||||
spfResultCount = expvar.NewMap("chasquid/smtpIn/spfResultCount")
|
spfResultCount = expvar.NewMap("chasquid/smtpIn/spfResultCount")
|
||||||
loopsDetected = expvar.NewInt("chasquid/smtpIn/loopsDetected")
|
loopsDetected = expvar.NewInt("chasquid/smtpIn/loopsDetected")
|
||||||
tlsCount = expvar.NewMap("chasquid/smtpIn/tlsCount")
|
tlsCount = expvar.NewMap("chasquid/smtpIn/tlsCount")
|
||||||
|
slcResults = expvar.NewMap("chasquid/smtpIn/securityLevelChecks")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Global event logs.
|
// Global event logs.
|
||||||
@@ -133,12 +135,14 @@ func main() {
|
|||||||
// as a remote domain (for loops, alias resolutions, etc.).
|
// as a remote domain (for loops, alias resolutions, etc.).
|
||||||
s.AddDomain("localhost")
|
s.AddDomain("localhost")
|
||||||
|
|
||||||
|
s.InitDomainInfo(conf.DataDir + "/domaininfo")
|
||||||
|
|
||||||
localC := &courier.Procmail{
|
localC := &courier.Procmail{
|
||||||
Binary: conf.MailDeliveryAgentBin,
|
Binary: conf.MailDeliveryAgentBin,
|
||||||
Args: conf.MailDeliveryAgentArgs,
|
Args: conf.MailDeliveryAgentArgs,
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
remoteC := &courier.SMTP{}
|
remoteC := &courier.SMTP{Dinfo: s.dinfo}
|
||||||
s.InitQueue(conf.DataDir+"/queue", localC, remoteC)
|
s.InitQueue(conf.DataDir+"/queue", localC, remoteC)
|
||||||
|
|
||||||
go s.periodicallyReload()
|
go s.periodicallyReload()
|
||||||
@@ -266,6 +270,9 @@ type Server struct {
|
|||||||
// Aliases resolver.
|
// Aliases resolver.
|
||||||
aliasesR *aliases.Resolver
|
aliasesR *aliases.Resolver
|
||||||
|
|
||||||
|
// Domain info database.
|
||||||
|
dinfo *domaininfo.DB
|
||||||
|
|
||||||
// Time before we give up on a connection, even if it's sending data.
|
// Time before we give up on a connection, even if it's sending data.
|
||||||
connTimeout time.Duration
|
connTimeout time.Duration
|
||||||
|
|
||||||
@@ -314,6 +321,19 @@ func (s *Server) AddUserDB(domain string, db *userdb.DB) {
|
|||||||
s.userDBs[domain] = db
|
s.userDBs[domain] = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) InitDomainInfo(dir string) {
|
||||||
|
var err error
|
||||||
|
s.dinfo, err = domaininfo.New(dir)
|
||||||
|
if err != nil {
|
||||||
|
glog.Fatalf("Error opening domain info database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.dinfo.Load()
|
||||||
|
if err != nil {
|
||||||
|
glog.Fatalf("Error loading domain info database: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) InitQueue(path string, localC, remoteC courier.Courier) {
|
func (s *Server) InitQueue(path string, localC, remoteC courier.Courier) {
|
||||||
q := queue.New(path, s.localDomains, s.aliasesR, localC, remoteC)
|
q := queue.New(path, s.localDomains, s.aliasesR, localC, remoteC)
|
||||||
err := q.Load()
|
err := q.Load()
|
||||||
@@ -399,6 +419,7 @@ func (s *Server) serve(l net.Listener, mode SocketMode) {
|
|||||||
userDBs: s.userDBs,
|
userDBs: s.userDBs,
|
||||||
aliasesR: s.aliasesR,
|
aliasesR: s.aliasesR,
|
||||||
localDomains: s.localDomains,
|
localDomains: s.localDomains,
|
||||||
|
dinfo: s.dinfo,
|
||||||
deadline: time.Now().Add(s.connTimeout),
|
deadline: time.Now().Add(s.connTimeout),
|
||||||
commandTimeout: s.commandTimeout,
|
commandTimeout: s.commandTimeout,
|
||||||
queue: s.queue,
|
queue: s.queue,
|
||||||
@@ -449,6 +470,7 @@ type Conn struct {
|
|||||||
userDBs map[string]*userdb.DB
|
userDBs map[string]*userdb.DB
|
||||||
localDomains *set.String
|
localDomains *set.String
|
||||||
aliasesR *aliases.Resolver
|
aliasesR *aliases.Resolver
|
||||||
|
dinfo *domaininfo.DB
|
||||||
|
|
||||||
// Have we successfully completed AUTH?
|
// Have we successfully completed AUTH?
|
||||||
completedAuth bool
|
completedAuth bool
|
||||||
@@ -697,6 +719,10 @@ func (c *Conn) MAIL(params string) (code int, msg string) {
|
|||||||
"SPF check failed: %v", c.spfError)
|
"SPF check failed: %v", c.spfError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !c.secLevelCheck(addr) {
|
||||||
|
return 550, "security level check failed"
|
||||||
|
}
|
||||||
|
|
||||||
addr, err = normalize.DomainToUnicode(addr)
|
addr, err = normalize.DomainToUnicode(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 501, "malformed address (IDNA conversion failed)"
|
return 501, "malformed address (IDNA conversion failed)"
|
||||||
@@ -727,6 +753,37 @@ func (c *Conn) checkSPF(addr string) (spf.Result, error) {
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// secLevelCheck checks if the security level is acceptable for the given
|
||||||
|
// address.
|
||||||
|
func (c *Conn) secLevelCheck(addr string) bool {
|
||||||
|
// Only check if SPF passes. This serves two purposes:
|
||||||
|
// - Skip for authenticated connections (we trust them implicitly).
|
||||||
|
// - Don't apply this if we can't be sure the sender is authorized.
|
||||||
|
// Otherwise anyone could raise the level of any domain.
|
||||||
|
if c.spfResult != spf.Pass {
|
||||||
|
slcResults.Add("skip", 1)
|
||||||
|
c.tr.Debugf("SPF did not pass, skipping security level check")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
domain := envelope.DomainOf(addr)
|
||||||
|
level := domaininfo.SecLevel_PLAIN
|
||||||
|
if c.onTLS {
|
||||||
|
level = domaininfo.SecLevel_TLS_CLIENT
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := c.dinfo.IncomingSecLevel(domain, level)
|
||||||
|
if ok {
|
||||||
|
slcResults.Add("pass", 1)
|
||||||
|
c.tr.Debugf("security level check for %s passed (%s)", domain, level)
|
||||||
|
} else {
|
||||||
|
slcResults.Add("fail", 1)
|
||||||
|
c.tr.Errorf("security level check for %s failed (%s)", domain, level)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) RCPT(params string) (code int, msg string) {
|
func (c *Conn) RCPT(params string) (code int, msg string) {
|
||||||
// params should be: "TO:<name@host>", and possibly followed by options
|
// params should be: "TO:<name@host>", and possibly followed by options
|
||||||
// such as "NOTIFY=SUCCESS,DELAY" (which we ignore).
|
// such as "NOTIFY=SUCCESS,DELAY" (which we ignore).
|
||||||
@@ -793,7 +850,6 @@ func (c *Conn) DATA(params string) (code int, msg string) {
|
|||||||
if c.mailFrom == "" {
|
if c.mailFrom == "" {
|
||||||
return 503, "sender not yet given"
|
return 503, "sender not yet given"
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.rcptTo) == 0 {
|
if len(c.rcptTo) == 0 {
|
||||||
return 503, "need an address to send to"
|
return 503, "need an address to send to"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ import (
|
|||||||
|
|
||||||
"blitiri.com.ar/go/chasquid/internal/aliases"
|
"blitiri.com.ar/go/chasquid/internal/aliases"
|
||||||
"blitiri.com.ar/go/chasquid/internal/courier"
|
"blitiri.com.ar/go/chasquid/internal/courier"
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/domaininfo"
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/spf"
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||||
"blitiri.com.ar/go/chasquid/internal/userdb"
|
"blitiri.com.ar/go/chasquid/internal/userdb"
|
||||||
|
|
||||||
"github.com/golang/glog"
|
"github.com/golang/glog"
|
||||||
@@ -169,7 +172,7 @@ func TestWrongMailParsing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Mail("from@from"); err != nil {
|
if err := c.Mail("from@plain"); err != nil {
|
||||||
t.Errorf("Mail: %v", err)
|
t.Errorf("Mail: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -200,11 +203,11 @@ func TestRcptBeforeMail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRcptOption(t *testing.T) {
|
func TestRcptOption(t *testing.T) {
|
||||||
c := mustDial(t, ModeSMTP, false)
|
c := mustDial(t, ModeSMTP, true)
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
if err := c.Mail("from@localhost"); err != nil {
|
if err := c.Mail("from@localhost"); err != nil {
|
||||||
t.Errorf("Mail: %v", err)
|
t.Fatalf("Mail: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
params := []string{
|
params := []string{
|
||||||
@@ -250,7 +253,7 @@ func TestReset(t *testing.T) {
|
|||||||
c := mustDial(t, ModeSMTP, false)
|
c := mustDial(t, ModeSMTP, false)
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
if err := c.Mail("from@from"); err != nil {
|
if err := c.Mail("from@plain"); err != nil {
|
||||||
t.Fatalf("MAIL FROM: %v", err)
|
t.Fatalf("MAIL FROM: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -258,7 +261,7 @@ func TestReset(t *testing.T) {
|
|||||||
t.Errorf("RSET: %v", err)
|
t.Errorf("RSET: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Mail("from@from"); err != nil {
|
if err := c.Mail("from@plain"); err != nil {
|
||||||
t.Errorf("MAIL after RSET: %v", err)
|
t.Errorf("MAIL after RSET: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -278,6 +281,55 @@ func TestRepeatedStartTLS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSecLevel(t *testing.T) {
|
||||||
|
// We can't simulate this externally because of the SPF record
|
||||||
|
// requirement, so do a narrow test on Conn.secLevelCheck.
|
||||||
|
tmpDir, err := ioutil.TempDir("", "chasquid_test:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
dinfo, err := domaininfo.New(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create domain info: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &Conn{
|
||||||
|
tr: trace.New("testconn", "testconn"),
|
||||||
|
dinfo: dinfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
// No SPF, skip security checks.
|
||||||
|
c.spfResult = spf.None
|
||||||
|
c.onTLS = true
|
||||||
|
if !c.secLevelCheck("from@slc") {
|
||||||
|
t.Fatalf("TLS seclevel failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.onTLS = false
|
||||||
|
if !c.secLevelCheck("from@slc") {
|
||||||
|
t.Fatalf("plain seclevel failed, even though SPF does not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now the real checks, once SPF passes.
|
||||||
|
c.spfResult = spf.Pass
|
||||||
|
|
||||||
|
if !c.secLevelCheck("from@slc") {
|
||||||
|
t.Fatalf("plain seclevel failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.onTLS = true
|
||||||
|
if !c.secLevelCheck("from@slc") {
|
||||||
|
t.Fatalf("TLS seclevel failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.onTLS = false
|
||||||
|
if c.secLevelCheck("from@slc") {
|
||||||
|
t.Fatalf("plain seclevel worked, downgrade was allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// === Benchmarks ===
|
// === Benchmarks ===
|
||||||
//
|
//
|
||||||
@@ -438,6 +490,7 @@ func realMain(m *testing.M) int {
|
|||||||
localC := &courier.Procmail{}
|
localC := &courier.Procmail{}
|
||||||
remoteC := &courier.SMTP{}
|
remoteC := &courier.SMTP{}
|
||||||
s.InitQueue(tmpDir+"/queue", localC, remoteC)
|
s.InitQueue(tmpDir+"/queue", localC, remoteC)
|
||||||
|
s.InitDomainInfo(tmpDir + "/domaininfo")
|
||||||
|
|
||||||
udb := userdb.New("/dev/null")
|
udb := userdb.New("/dev/null")
|
||||||
udb.AddUser("testuser", "testpasswd")
|
udb.AddUser("testuser", "testpasswd")
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/golang/glog"
|
"github.com/golang/glog"
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
|
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/domaininfo"
|
||||||
"blitiri.com.ar/go/chasquid/internal/envelope"
|
"blitiri.com.ar/go/chasquid/internal/envelope"
|
||||||
"blitiri.com.ar/go/chasquid/internal/smtp"
|
"blitiri.com.ar/go/chasquid/internal/smtp"
|
||||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||||
@@ -32,11 +33,13 @@ var (
|
|||||||
|
|
||||||
// Exported variables.
|
// Exported variables.
|
||||||
var (
|
var (
|
||||||
tlsCount = expvar.NewMap("chasquid/smtpOut/tlsCount")
|
tlsCount = expvar.NewMap("chasquid/smtpOut/tlsCount")
|
||||||
|
slcResults = expvar.NewMap("chasquid/smtpOut/securityLevelChecks")
|
||||||
)
|
)
|
||||||
|
|
||||||
// SMTP delivers remote mail via outgoing SMTP.
|
// SMTP delivers remote mail via outgoing SMTP.
|
||||||
type SMTP struct {
|
type SMTP struct {
|
||||||
|
Dinfo *domaininfo.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
|
func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
|
||||||
@@ -44,7 +47,8 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
|
|||||||
defer tr.Finish()
|
defer tr.Finish()
|
||||||
tr.Debugf("%s -> %s", from, to)
|
tr.Debugf("%s -> %s", from, to)
|
||||||
|
|
||||||
mx, err := lookupMX(envelope.DomainOf(to))
|
toDomain := envelope.DomainOf(to)
|
||||||
|
mx, err := lookupMX(toDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Note this is considered a permanent error.
|
// Note this is considered a permanent error.
|
||||||
// This is in line with what other servers (Exim) do. However, the
|
// This is in line with what other servers (Exim) do. However, the
|
||||||
@@ -68,6 +72,7 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
|
|||||||
// Do we use insecure TLS?
|
// Do we use insecure TLS?
|
||||||
// Set as fallback when retrying.
|
// Set as fallback when retrying.
|
||||||
insecure := false
|
insecure := false
|
||||||
|
secLevel := domaininfo.SecLevel_PLAIN
|
||||||
|
|
||||||
retry:
|
retry:
|
||||||
conn, err := net.DialTimeout("tcp", mx+":"+*smtpPort, smtpDialTimeout)
|
conn, err := net.DialTimeout("tcp", mx+":"+*smtpPort, smtpDialTimeout)
|
||||||
@@ -110,15 +115,25 @@ retry:
|
|||||||
if config.InsecureSkipVerify {
|
if config.InsecureSkipVerify {
|
||||||
tr.Debugf("Insecure - using TLS, but cert does not match %s", mx)
|
tr.Debugf("Insecure - using TLS, but cert does not match %s", mx)
|
||||||
tlsCount.Add("tls:insecure", 1)
|
tlsCount.Add("tls:insecure", 1)
|
||||||
|
secLevel = domaininfo.SecLevel_TLS_INSECURE
|
||||||
} else {
|
} else {
|
||||||
tlsCount.Add("tls:secure", 1)
|
tlsCount.Add("tls:secure", 1)
|
||||||
tr.Debugf("Secure - using TLS")
|
tr.Debugf("Secure - using TLS")
|
||||||
|
secLevel = domaininfo.SecLevel_TLS_SECURE
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tlsCount.Add("plain", 1)
|
tlsCount.Add("plain", 1)
|
||||||
tr.Debugf("Insecure - NOT using TLS")
|
tr.Debugf("Insecure - NOT using TLS")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if toDomain != "" && !s.Dinfo.OutgoingSecLevel(toDomain, secLevel) {
|
||||||
|
// We consider the failure transient, so transient misconfigurations
|
||||||
|
// do not affect deliveries.
|
||||||
|
slcResults.Add("fail", 1)
|
||||||
|
return tr.Errorf("Security level check failed (level:%s)", secLevel), false
|
||||||
|
}
|
||||||
|
slcResults.Add("pass", 1)
|
||||||
|
|
||||||
// c.Mail will add the <> for us when the address is empty.
|
// c.Mail will add the <> for us when the address is empty.
|
||||||
if from == "<>" {
|
if from == "<>" {
|
||||||
from = ""
|
from = ""
|
||||||
|
|||||||
@@ -2,12 +2,30 @@ package courier
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/domaininfo"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func newSMTP(t *testing.T) (*SMTP, string) {
|
||||||
|
dir, err := ioutil.TempDir("", "smtp_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dinfo, err := domaininfo.New(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SMTP{dinfo}, dir
|
||||||
|
}
|
||||||
|
|
||||||
// Fake server, to test SMTP out.
|
// Fake server, to test SMTP out.
|
||||||
func fakeServer(t *testing.T, responses map[string]string) string {
|
func fakeServer(t *testing.T, responses map[string]string) string {
|
||||||
l, err := net.Listen("tcp", "localhost:0")
|
l, err := net.Listen("tcp", "localhost:0")
|
||||||
@@ -72,7 +90,8 @@ func TestSMTP(t *testing.T) {
|
|||||||
fakeMX["to"] = host
|
fakeMX["to"] = host
|
||||||
*smtpPort = port
|
*smtpPort = port
|
||||||
|
|
||||||
s := &SMTP{}
|
s, tmpDir := newSMTP(t)
|
||||||
|
defer os.Remove(tmpDir)
|
||||||
err, _ := s.Deliver("me@me", "to@to", []byte("data"))
|
err, _ := s.Deliver("me@me", "to@to", []byte("data"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("deliver failed: %v", err)
|
t.Errorf("deliver failed: %v", err)
|
||||||
@@ -132,7 +151,8 @@ func TestSMTPErrors(t *testing.T) {
|
|||||||
fakeMX["to"] = host
|
fakeMX["to"] = host
|
||||||
*smtpPort = port
|
*smtpPort = port
|
||||||
|
|
||||||
s := &SMTP{}
|
s, tmpDir := newSMTP(t)
|
||||||
|
defer os.Remove(tmpDir)
|
||||||
err, _ := s.Deliver("me@me", "to@to", []byte("data"))
|
err, _ := s.Deliver("me@me", "to@to", []byte("data"))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("deliver not failed in case %q: %v", rs["_welcome"], err)
|
t.Errorf("deliver not failed in case %q: %v", rs["_welcome"], err)
|
||||||
|
|||||||
133
internal/domaininfo/domaininfo.go
Normal file
133
internal/domaininfo/domaininfo.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
// Package domaininfo implements a domain information database, to keep track
|
||||||
|
// of things we know about a particular domain.
|
||||||
|
package domaininfo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/protoio"
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Command to generate domaininfo.pb.go.
|
||||||
|
//go:generate protoc --go_out=. domaininfo.proto
|
||||||
|
|
||||||
|
type DB struct {
|
||||||
|
// Persistent store with the list of domains we know.
|
||||||
|
store *protoio.Store
|
||||||
|
|
||||||
|
info map[string]*Domain
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
ev *trace.EventLog
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(dir string) (*DB, error) {
|
||||||
|
st, err := protoio.NewStore(dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
l := &DB{
|
||||||
|
store: st,
|
||||||
|
info: map[string]*Domain{},
|
||||||
|
}
|
||||||
|
l.ev = trace.NewEventLog("DomainInfo", fmt.Sprintf("%p", l))
|
||||||
|
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the database from disk; should be called once at initialization.
|
||||||
|
func (db *DB) Load() error {
|
||||||
|
ids, err := db.store.ListIDs()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
d := &Domain{}
|
||||||
|
_, err := db.store.Get(id, d)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error loading %q: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.info[d.Name] = d
|
||||||
|
}
|
||||||
|
|
||||||
|
db.ev.Debugf("loaded: %s", ids)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) write(d *Domain) {
|
||||||
|
err := db.store.Put(d.Name, d)
|
||||||
|
if err != nil {
|
||||||
|
db.ev.Errorf("%s error saving: %v", d.Name, err)
|
||||||
|
} else {
|
||||||
|
db.ev.Debugf("%s saved", d.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncomingSecLevel checks an incoming security level for the domain.
|
||||||
|
// Returns true if allowed, false otherwise.
|
||||||
|
func (db *DB) IncomingSecLevel(domain string, level SecLevel) bool {
|
||||||
|
db.Lock()
|
||||||
|
defer db.Unlock()
|
||||||
|
|
||||||
|
d, exists := db.info[domain]
|
||||||
|
if !exists {
|
||||||
|
d = &Domain{Name: domain}
|
||||||
|
db.info[domain] = d
|
||||||
|
defer db.write(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
if level < d.IncomingSecLevel {
|
||||||
|
db.ev.Errorf("%s incoming denied: %s < %s",
|
||||||
|
d.Name, level, d.IncomingSecLevel)
|
||||||
|
return false
|
||||||
|
} else if level == d.IncomingSecLevel {
|
||||||
|
db.ev.Debugf("%s incoming allowed: %s == %s",
|
||||||
|
d.Name, level, d.IncomingSecLevel)
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
db.ev.Printf("%s incoming level raised: %s > %s",
|
||||||
|
d.Name, level, d.IncomingSecLevel)
|
||||||
|
d.IncomingSecLevel = level
|
||||||
|
if exists {
|
||||||
|
defer db.write(d)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OutgoingSecLevel checks an incoming security level for the domain.
|
||||||
|
// Returns true if allowed, false otherwise.
|
||||||
|
func (db *DB) OutgoingSecLevel(domain string, level SecLevel) bool {
|
||||||
|
db.Lock()
|
||||||
|
defer db.Unlock()
|
||||||
|
|
||||||
|
d, exists := db.info[domain]
|
||||||
|
if !exists {
|
||||||
|
d = &Domain{Name: domain}
|
||||||
|
db.info[domain] = d
|
||||||
|
defer db.write(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
if level < d.OutgoingSecLevel {
|
||||||
|
db.ev.Errorf("%s outgoing denied: %s < %s",
|
||||||
|
d.Name, level, d.OutgoingSecLevel)
|
||||||
|
return false
|
||||||
|
} else if level == d.OutgoingSecLevel {
|
||||||
|
db.ev.Debugf("%s outgoing allowed: %s == %s",
|
||||||
|
d.Name, level, d.OutgoingSecLevel)
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
db.ev.Printf("%s outgoing level raised: %s > %s",
|
||||||
|
d.Name, level, d.OutgoingSecLevel)
|
||||||
|
d.OutgoingSecLevel = level
|
||||||
|
if exists {
|
||||||
|
defer db.write(d)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
96
internal/domaininfo/domaininfo.pb.go
Normal file
96
internal/domaininfo/domaininfo.pb.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
// Code generated by protoc-gen-go.
|
||||||
|
// source: domaininfo.proto
|
||||||
|
// DO NOT EDIT!
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package domaininfo is a generated protocol buffer package.
|
||||||
|
|
||||||
|
It is generated from these files:
|
||||||
|
domaininfo.proto
|
||||||
|
|
||||||
|
It has these top-level messages:
|
||||||
|
Domain
|
||||||
|
*/
|
||||||
|
package domaininfo
|
||||||
|
|
||||||
|
import proto "github.com/golang/protobuf/proto"
|
||||||
|
import fmt "fmt"
|
||||||
|
import math "math"
|
||||||
|
|
||||||
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
var _ = proto.Marshal
|
||||||
|
var _ = fmt.Errorf
|
||||||
|
var _ = math.Inf
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the proto package it is being compiled against.
|
||||||
|
// A compilation error at this line likely means your copy of the
|
||||||
|
// proto package needs to be updated.
|
||||||
|
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||||
|
|
||||||
|
type SecLevel int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Does not do TLS.
|
||||||
|
SecLevel_PLAIN SecLevel = 0
|
||||||
|
// TLS client connection (no certificate validation).
|
||||||
|
SecLevel_TLS_CLIENT SecLevel = 1
|
||||||
|
// TLS, but with invalid certificates.
|
||||||
|
SecLevel_TLS_INSECURE SecLevel = 2
|
||||||
|
// TLS, with valid certificates.
|
||||||
|
SecLevel_TLS_SECURE SecLevel = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
var SecLevel_name = map[int32]string{
|
||||||
|
0: "PLAIN",
|
||||||
|
1: "TLS_CLIENT",
|
||||||
|
2: "TLS_INSECURE",
|
||||||
|
3: "TLS_SECURE",
|
||||||
|
}
|
||||||
|
var SecLevel_value = map[string]int32{
|
||||||
|
"PLAIN": 0,
|
||||||
|
"TLS_CLIENT": 1,
|
||||||
|
"TLS_INSECURE": 2,
|
||||||
|
"TLS_SECURE": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x SecLevel) String() string {
|
||||||
|
return proto.EnumName(SecLevel_name, int32(x))
|
||||||
|
}
|
||||||
|
func (SecLevel) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||||
|
|
||||||
|
type Domain struct {
|
||||||
|
Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"`
|
||||||
|
// Security level for mail coming from this domain (they send to us).
|
||||||
|
IncomingSecLevel SecLevel `protobuf:"varint,2,opt,name=incoming_sec_level,json=incomingSecLevel,enum=domaininfo.SecLevel" json:"incoming_sec_level,omitempty"`
|
||||||
|
// Security level for mail going to this domain (we send to them).
|
||||||
|
OutgoingSecLevel SecLevel `protobuf:"varint,3,opt,name=outgoing_sec_level,json=outgoingSecLevel,enum=domaininfo.SecLevel" json:"outgoing_sec_level,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Domain) Reset() { *m = Domain{} }
|
||||||
|
func (m *Domain) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*Domain) ProtoMessage() {}
|
||||||
|
func (*Domain) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
proto.RegisterType((*Domain)(nil), "domaininfo.Domain")
|
||||||
|
proto.RegisterEnum("domaininfo.SecLevel", SecLevel_name, SecLevel_value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() { proto.RegisterFile("domaininfo.proto", fileDescriptor0) }
|
||||||
|
|
||||||
|
var fileDescriptor0 = []byte{
|
||||||
|
// 189 bytes of a gzipped FileDescriptorProto
|
||||||
|
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x12, 0x48, 0xc9, 0xcf, 0x4d,
|
||||||
|
0xcc, 0xcc, 0xcb, 0xcc, 0x4b, 0xcb, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x42, 0x88,
|
||||||
|
0x28, 0x2d, 0x61, 0xe4, 0x62, 0x73, 0x01, 0x73, 0x85, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53,
|
||||||
|
0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x21, 0x27, 0x2e, 0xa1, 0xcc, 0xbc, 0xe4,
|
||||||
|
0xfc, 0xdc, 0xcc, 0xbc, 0xf4, 0xf8, 0xe2, 0xd4, 0xe4, 0xf8, 0x9c, 0xd4, 0xb2, 0xd4, 0x1c, 0x09,
|
||||||
|
0x26, 0xa0, 0x0a, 0x3e, 0x23, 0x11, 0x3d, 0x24, 0x93, 0x83, 0x53, 0x93, 0x7d, 0x40, 0x72, 0x41,
|
||||||
|
0x02, 0x30, 0xf5, 0x30, 0x11, 0x90, 0x19, 0xf9, 0xa5, 0x25, 0xe9, 0xf9, 0xa8, 0x66, 0x30, 0xe3,
|
||||||
|
0x33, 0x03, 0xa6, 0x1e, 0x26, 0xa2, 0xe5, 0xce, 0xc5, 0x01, 0x37, 0x8f, 0x93, 0x8b, 0x35, 0xc0,
|
||||||
|
0xc7, 0xd1, 0xd3, 0x4f, 0x80, 0x41, 0x88, 0x8f, 0x8b, 0x2b, 0xc4, 0x27, 0x38, 0xde, 0xd9, 0xc7,
|
||||||
|
0xd3, 0xd5, 0x2f, 0x44, 0x80, 0x51, 0x48, 0x80, 0x8b, 0x07, 0xc4, 0xf7, 0xf4, 0x0b, 0x76, 0x75,
|
||||||
|
0x0e, 0x0d, 0x72, 0x15, 0x60, 0x82, 0xa9, 0x80, 0xf2, 0x99, 0x93, 0xd8, 0xc0, 0x41, 0x60, 0x0c,
|
||||||
|
0x08, 0x00, 0x00, 0xff, 0xff, 0x2c, 0x78, 0x65, 0x5b, 0x16, 0x01, 0x00, 0x00,
|
||||||
|
}
|
||||||
28
internal/domaininfo/domaininfo.proto
Normal file
28
internal/domaininfo/domaininfo.proto
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package domaininfo;
|
||||||
|
|
||||||
|
enum SecLevel {
|
||||||
|
// Does not do TLS.
|
||||||
|
PLAIN = 0;
|
||||||
|
|
||||||
|
// TLS client connection (no certificate validation).
|
||||||
|
TLS_CLIENT = 1;
|
||||||
|
|
||||||
|
// TLS, but with invalid certificates.
|
||||||
|
TLS_INSECURE = 2;
|
||||||
|
|
||||||
|
// TLS, with valid certificates.
|
||||||
|
TLS_SECURE = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Domain {
|
||||||
|
string name = 1;
|
||||||
|
|
||||||
|
// Security level for mail coming from this domain (they send to us).
|
||||||
|
SecLevel incoming_sec_level = 2;
|
||||||
|
|
||||||
|
// Security level for mail going to this domain (we send to them).
|
||||||
|
SecLevel outgoing_sec_level = 3;
|
||||||
|
}
|
||||||
133
internal/domaininfo/domaininfo_test.go
Normal file
133
internal/domaininfo/domaininfo_test.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
package domaininfo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustTempDir(t *testing.T) string {
|
||||||
|
dir, err := ioutil.TempDir("", "greylisting_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("test directory: %q", dir)
|
||||||
|
return dir
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasic(t *testing.T) {
|
||||||
|
dir := mustTempDir(t)
|
||||||
|
db, err := New(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Load(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !db.IncomingSecLevel("d1", SecLevel_PLAIN) {
|
||||||
|
t.Errorf("new domain as plain not allowed")
|
||||||
|
}
|
||||||
|
if !db.IncomingSecLevel("d1", SecLevel_TLS_SECURE) {
|
||||||
|
t.Errorf("increment to tls-secure not allowed")
|
||||||
|
}
|
||||||
|
if db.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
|
||||||
|
t.Errorf("decrement to tls-insecure was allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait until it is written to disk.
|
||||||
|
for dl := time.Now().Add(30 * time.Second); time.Now().Before(dl); {
|
||||||
|
d := &Domain{}
|
||||||
|
ok, _ := db.store.Get("d1", d)
|
||||||
|
if ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that it was added to the store and a new db sees it.
|
||||||
|
db2, err := New(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db2.Load(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if db2.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
|
||||||
|
t.Errorf("decrement to tls-insecure was allowed in new DB")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !t.Failed() {
|
||||||
|
os.RemoveAll(dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDomain(t *testing.T) {
|
||||||
|
dir := mustTempDir(t)
|
||||||
|
db, err := New(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
domain string
|
||||||
|
level SecLevel
|
||||||
|
}{
|
||||||
|
{"plain", SecLevel_PLAIN},
|
||||||
|
{"insecure", SecLevel_TLS_INSECURE},
|
||||||
|
{"secure", SecLevel_TLS_SECURE},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
if !db.IncomingSecLevel(c.domain, c.level) {
|
||||||
|
t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level)
|
||||||
|
}
|
||||||
|
if !db.OutgoingSecLevel(c.domain, c.level) {
|
||||||
|
t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !t.Failed() {
|
||||||
|
os.RemoveAll(dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProgressions(t *testing.T) {
|
||||||
|
dir := mustTempDir(t)
|
||||||
|
db, err := New(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
domain string
|
||||||
|
lvl SecLevel
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{"pisis", SecLevel_PLAIN, true},
|
||||||
|
{"pisis", SecLevel_TLS_INSECURE, true},
|
||||||
|
{"pisis", SecLevel_TLS_SECURE, true},
|
||||||
|
{"pisis", SecLevel_TLS_INSECURE, false},
|
||||||
|
{"pisis", SecLevel_TLS_SECURE, true},
|
||||||
|
|
||||||
|
{"ssip", SecLevel_TLS_SECURE, true},
|
||||||
|
{"ssip", SecLevel_TLS_SECURE, true},
|
||||||
|
{"ssip", SecLevel_TLS_INSECURE, false},
|
||||||
|
{"ssip", SecLevel_PLAIN, false},
|
||||||
|
}
|
||||||
|
for i, c := range cases {
|
||||||
|
if ok := db.IncomingSecLevel(c.domain, c.lvl); ok != c.ok {
|
||||||
|
t.Errorf("%2d %q in attempt for %s failed: got %v, expected %v",
|
||||||
|
i, c.domain, c.lvl, ok, c.ok)
|
||||||
|
}
|
||||||
|
if ok := db.OutgoingSecLevel(c.domain, c.lvl); ok != c.ok {
|
||||||
|
t.Errorf("%2d %q out attempt for %s failed: got %v, expected %v",
|
||||||
|
i, c.domain, c.lvl, ok, c.ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !t.Failed() {
|
||||||
|
os.RemoveAll(dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user