1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-21 15:17:01 +00:00

courier: Add tests for STS policy checks

This patch adds tests for STS policy checks in combination with TLS
security levels.

This helps ensure we're detecting mis-matches of TLS status
(plain/insecure/secure) and STS policy enforcement.
This commit is contained in:
Alberto Bertogli
2021-10-25 12:39:09 +01:00
parent 14e270b7f5
commit 02322a74e6
2 changed files with 117 additions and 25 deletions

View File

@@ -152,6 +152,7 @@ func (a *attempt) deliver(mx string) (error, bool) {
return nil return nil
}, },
} }
err = c.StartTLS(config) err = c.StartTLS(config)
if err != nil { if err != nil {
tlsCount.Add("tls:failed", 1) tlsCount.Add("tls:failed", 1)
@@ -206,6 +207,9 @@ func (a *attempt) deliver(mx string) (error, bool) {
return nil, false return nil, false
} }
// CA roots to validate against, so we can override it for testing.
var certRoots *x509.CertPool = nil
func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel { func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel {
// Validate certificates, using the same logic Go does, and following the // Validate certificates, using the same logic Go does, and following the
// official example at // official example at
@@ -213,6 +217,7 @@ func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel {
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
DNSName: cs.ServerName, DNSName: cs.ServerName,
Intermediates: x509.NewCertPool(), Intermediates: x509.NewCertPool(),
Roots: certRoots,
} }
for _, cert := range cs.PeerCertificates[1:] { for _, cert := range cs.PeerCertificates[1:] {
opts.Intermediates.AddCert(cert) opts.Intermediates.AddCert(cert)

View File

@@ -3,7 +3,9 @@ package courier
import ( import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/textproto" "net/textproto"
"os" "os"
@@ -13,6 +15,7 @@ import (
"time" "time"
"blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/domaininfo"
"blitiri.com.ar/go/chasquid/internal/sts"
"blitiri.com.ar/go/chasquid/internal/testlib" "blitiri.com.ar/go/chasquid/internal/testlib"
"blitiri.com.ar/go/chasquid/internal/trace" "blitiri.com.ar/go/chasquid/internal/trace"
) )
@@ -44,6 +47,7 @@ func newSMTP(t *testing.T) (*SMTP, string) {
// Fake server, to test SMTP out. // Fake server, to test SMTP out.
type FakeServer struct { type FakeServer struct {
t *testing.T t *testing.T
tmpDir string
responses map[string]string responses map[string]string
wg *sync.WaitGroup wg *sync.WaitGroup
addr string addr string
@@ -53,6 +57,7 @@ type FakeServer struct {
func newFakeServer(t *testing.T, responses map[string]string) *FakeServer { func newFakeServer(t *testing.T, responses map[string]string) *FakeServer {
s := &FakeServer{ s := &FakeServer{
t: t, t: t,
tmpDir: testlib.MustTempDir(t),
responses: responses, responses: responses,
wg: &sync.WaitGroup{}, wg: &sync.WaitGroup{},
} }
@@ -60,24 +65,27 @@ func newFakeServer(t *testing.T, responses map[string]string) *FakeServer {
return s return s
} }
func (s *FakeServer) loadTLS() string { func (s *FakeServer) Cleanup() {
tmpDir := testlib.MustTempDir(s.t) // Remove our temporary data. Be extra paranoid and make sure the
// directory isn't too shallow.
if len(s.tmpDir) > 8 {
os.RemoveAll(s.tmpDir)
}
}
func (s *FakeServer) initTLS() {
var err error var err error
s.tlsConfig, err = testlib.GenerateCert(tmpDir) s.tlsConfig, err = testlib.GenerateCert(s.tmpDir)
if err != nil { if err != nil {
os.RemoveAll(tmpDir)
s.t.Fatalf("error generating cert: %v", err) s.t.Fatalf("error generating cert: %v", err)
} }
cert, err := tls.LoadX509KeyPair(tmpDir+"/cert.pem", tmpDir+"/key.pem") cert, err := tls.LoadX509KeyPair(s.tmpDir+"/cert.pem", s.tmpDir+"/key.pem")
if err != nil { if err != nil {
os.RemoveAll(tmpDir)
s.t.Fatalf("error loading temp cert: %v", err) s.t.Fatalf("error loading temp cert: %v", err)
} }
s.tlsConfig.Certificates = []tls.Certificate{cert} s.tlsConfig.Certificates = []tls.Certificate{cert}
return tmpDir
} }
func (s *FakeServer) start() string { func (s *FakeServer) start() string {
@@ -88,15 +96,14 @@ func (s *FakeServer) start() string {
} }
s.addr = l.Addr().String() s.addr = l.Addr().String()
s.initTLS()
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
defer s.wg.Done() defer s.wg.Done()
defer l.Close() defer l.Close()
tmpDir := s.loadTLS()
defer os.RemoveAll(tmpDir)
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
panic(err) panic(err)
@@ -172,6 +179,7 @@ func TestSMTP(t *testing.T) {
"QUIT": "250 quit ok\n", "QUIT": "250 quit ok\n",
} }
srv := newFakeServer(t, responses) srv := newFakeServer(t, responses)
defer srv.Cleanup()
host, port := srv.HostPort() host, port := srv.HostPort()
// Put a non-existing host first, so we check that if the first host // Put a non-existing host first, so we check that if the first host
@@ -244,6 +252,7 @@ func TestSMTPErrors(t *testing.T) {
for _, rs := range responses { for _, rs := range responses {
srv := newFakeServer(t, rs) srv := newFakeServer(t, rs)
defer srv.Cleanup()
host, port := srv.HostPort() host, port := srv.HostPort()
testMX["to"] = []*net.MX{{Host: host, Pref: 10}} testMX["to"] = []*net.MX{{Host: host, Pref: 10}}
@@ -359,21 +368,24 @@ func TestLookupInvalidDomain(t *testing.T) {
} }
} }
// Server fake responses for a complete TLS delivery.
// We use this in a few tests, so make it common.
var tlsResponses = map[string]string{
"_welcome": "220 welcome\n",
"EHLO hello": "250-ehlo ok\n250 STARTTLS\n",
"STARTTLS": "220 starttls go\n",
"_STARTTLS": "ok",
"MAIL FROM:<me@me>": "250 mail ok\n",
"RCPT TO:<to@to>": "250 rcpt ok\n",
"DATA": "354 send data\n",
"_DATA": "250 data ok\n",
"QUIT": "250 quit ok\n",
}
func TestTLS(t *testing.T) { func TestTLS(t *testing.T) {
smtpTotalTimeout = 5 * time.Second smtpTotalTimeout = 5 * time.Second
srv := newFakeServer(t, tlsResponses)
responses := map[string]string{ defer srv.Cleanup()
"_welcome": "220 welcome\n",
"EHLO hello": "250-ehlo ok\n250 STARTTLS\n",
"STARTTLS": "220 starttls go\n",
"_STARTTLS": "ok",
"MAIL FROM:<me@me>": "250 mail ok\n",
"RCPT TO:<to@to>": "250 rcpt ok\n",
"DATA": "354 send data\n",
"_DATA": "250 data ok\n",
"QUIT": "250 quit ok\n",
}
srv := newFakeServer(t, responses)
_, *smtpPort = srv.HostPort() _, *smtpPort = srv.HostPort()
testMX["to"] = []*net.MX{ testMX["to"] = []*net.MX{
@@ -391,7 +403,7 @@ func TestTLS(t *testing.T) {
// Now do another delivery, but without TLS, to check that the detection // Now do another delivery, but without TLS, to check that the detection
// of connection downgrade is working. // of connection downgrade is working.
responses = map[string]string{ responses := map[string]string{
"_welcome": "220 welcome\n", "_welcome": "220 welcome\n",
"EHLO hello": "250 ehlo ok\n", "EHLO hello": "250 ehlo ok\n",
"MAIL FROM:<me@me>": "250 mail ok\n", "MAIL FROM:<me@me>": "250 mail ok\n",
@@ -401,6 +413,7 @@ func TestTLS(t *testing.T) {
"QUIT": "250 quit ok\n", "QUIT": "250 quit ok\n",
} }
srv = newFakeServer(t, responses) srv = newFakeServer(t, responses)
defer srv.Cleanup()
_, *smtpPort = srv.HostPort() _, *smtpPort = srv.HostPort()
err, permanent := s.Deliver("me@me", "to@to", []byte("data")) err, permanent := s.Deliver("me@me", "to@to", []byte("data"))
@@ -425,6 +438,7 @@ func TestTLSError(t *testing.T) {
"_STARTTLS": "no", "_STARTTLS": "no",
} }
srv := newFakeServer(t, responses) srv := newFakeServer(t, responses)
defer srv.Cleanup()
_, *smtpPort = srv.HostPort() _, *smtpPort = srv.HostPort()
testMX["to"] = []*net.MX{ testMX["to"] = []*net.MX{
@@ -443,3 +457,76 @@ func TestTLSError(t *testing.T) {
srv.Wait() srv.Wait()
} }
func TestSTSPolicyEnforcement(t *testing.T) {
smtpTotalTimeout = 5 * time.Second
srv := newFakeServer(t, tlsResponses)
defer srv.Cleanup()
_, *smtpPort = srv.HostPort()
s, tmpDir := newSMTP(t)
defer testlib.RemoveIfOk(t, tmpDir)
a := &attempt{
courier: s,
from: "me@me",
to: "to@to",
toDomain: "to",
data: []byte("data"),
tr: trace.New("test", "test"),
}
a.stsPolicy = &sts.Policy{
Version: "STSv1",
Mode: sts.Enforce,
MXs: []string{"mx"},
MaxAge: 1 * time.Minute,
}
// At this point the cert is not valid, which is incompatible with STS
// policy, so we expect it to fail.
err, permanent := a.deliver("localhost")
if !strings.Contains(err.Error(),
"invalid security level (TLS_INSECURE) for STS policy") {
t.Errorf("expected invalid sec level error, got %v", err)
}
if permanent != false {
t.Errorf("expected transient error, got permanent")
}
srv.Wait()
// Do another delivery attempt, but this time we trust the server cert.
// This time it should be successful, because the connection level should
// be TLS_SECURE which is required by the STS policy.
srv = newFakeServer(t, tlsResponses)
_, *smtpPort = srv.HostPort()
defer srv.Cleanup()
certRoots = loadCert(t, srv.tmpDir+"/cert.pem")
defer func() {
certRoots = nil
}()
err, permanent = a.deliver("localhost")
if err != nil {
t.Errorf("expected success, got %v (permanent=%v)", err, permanent)
}
srv.Wait()
}
func loadCert(t *testing.T, path string) *x509.CertPool {
t.Helper()
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(path)
if err != nil {
t.Fatalf("error reading cert %q: %v", path, err)
}
ok := pool.AppendCertsFromPEM(data)
if !ok {
t.Fatalf("failed to load cert %q", path)
}
return pool
}