mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-21 15:17:01 +00:00
courier: Add tests for STS policy checks
This patch adds tests for STS policy checks in combination with TLS security levels. This helps ensure we're detecting mis-matches of TLS status (plain/insecure/secure) and STS policy enforcement.
This commit is contained in:
@@ -152,6 +152,7 @@ func (a *attempt) deliver(mx string) (error, bool) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
err = c.StartTLS(config)
|
||||
if err != nil {
|
||||
tlsCount.Add("tls:failed", 1)
|
||||
@@ -206,6 +207,9 @@ func (a *attempt) deliver(mx string) (error, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// CA roots to validate against, so we can override it for testing.
|
||||
var certRoots *x509.CertPool = nil
|
||||
|
||||
func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel {
|
||||
// Validate certificates, using the same logic Go does, and following the
|
||||
// official example at
|
||||
@@ -213,6 +217,7 @@ func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel {
|
||||
opts := x509.VerifyOptions{
|
||||
DNSName: cs.ServerName,
|
||||
Intermediates: x509.NewCertPool(),
|
||||
Roots: certRoots,
|
||||
}
|
||||
for _, cert := range cs.PeerCertificates[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
|
||||
@@ -3,7 +3,9 @@ package courier
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"os"
|
||||
@@ -13,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/domaininfo"
|
||||
"blitiri.com.ar/go/chasquid/internal/sts"
|
||||
"blitiri.com.ar/go/chasquid/internal/testlib"
|
||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||
)
|
||||
@@ -44,6 +47,7 @@ func newSMTP(t *testing.T) (*SMTP, string) {
|
||||
// Fake server, to test SMTP out.
|
||||
type FakeServer struct {
|
||||
t *testing.T
|
||||
tmpDir string
|
||||
responses map[string]string
|
||||
wg *sync.WaitGroup
|
||||
addr string
|
||||
@@ -53,6 +57,7 @@ type FakeServer struct {
|
||||
func newFakeServer(t *testing.T, responses map[string]string) *FakeServer {
|
||||
s := &FakeServer{
|
||||
t: t,
|
||||
tmpDir: testlib.MustTempDir(t),
|
||||
responses: responses,
|
||||
wg: &sync.WaitGroup{},
|
||||
}
|
||||
@@ -60,24 +65,27 @@ func newFakeServer(t *testing.T, responses map[string]string) *FakeServer {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *FakeServer) loadTLS() string {
|
||||
tmpDir := testlib.MustTempDir(s.t)
|
||||
func (s *FakeServer) Cleanup() {
|
||||
// Remove our temporary data. Be extra paranoid and make sure the
|
||||
// directory isn't too shallow.
|
||||
if len(s.tmpDir) > 8 {
|
||||
os.RemoveAll(s.tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FakeServer) initTLS() {
|
||||
var err error
|
||||
s.tlsConfig, err = testlib.GenerateCert(tmpDir)
|
||||
s.tlsConfig, err = testlib.GenerateCert(s.tmpDir)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
s.t.Fatalf("error generating cert: %v", err)
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(tmpDir+"/cert.pem", tmpDir+"/key.pem")
|
||||
cert, err := tls.LoadX509KeyPair(s.tmpDir+"/cert.pem", s.tmpDir+"/key.pem")
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
s.t.Fatalf("error loading temp cert: %v", err)
|
||||
}
|
||||
|
||||
s.tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
|
||||
return tmpDir
|
||||
}
|
||||
|
||||
func (s *FakeServer) start() string {
|
||||
@@ -88,15 +96,14 @@ func (s *FakeServer) start() string {
|
||||
}
|
||||
s.addr = l.Addr().String()
|
||||
|
||||
s.initTLS()
|
||||
|
||||
s.wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer l.Close()
|
||||
|
||||
tmpDir := s.loadTLS()
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -172,6 +179,7 @@ func TestSMTP(t *testing.T) {
|
||||
"QUIT": "250 quit ok\n",
|
||||
}
|
||||
srv := newFakeServer(t, responses)
|
||||
defer srv.Cleanup()
|
||||
host, port := srv.HostPort()
|
||||
|
||||
// Put a non-existing host first, so we check that if the first host
|
||||
@@ -244,6 +252,7 @@ func TestSMTPErrors(t *testing.T) {
|
||||
|
||||
for _, rs := range responses {
|
||||
srv := newFakeServer(t, rs)
|
||||
defer srv.Cleanup()
|
||||
host, port := srv.HostPort()
|
||||
|
||||
testMX["to"] = []*net.MX{{Host: host, Pref: 10}}
|
||||
@@ -359,10 +368,9 @@ func TestLookupInvalidDomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
smtpTotalTimeout = 5 * time.Second
|
||||
|
||||
responses := map[string]string{
|
||||
// Server fake responses for a complete TLS delivery.
|
||||
// We use this in a few tests, so make it common.
|
||||
var tlsResponses = map[string]string{
|
||||
"_welcome": "220 welcome\n",
|
||||
"EHLO hello": "250-ehlo ok\n250 STARTTLS\n",
|
||||
"STARTTLS": "220 starttls go\n",
|
||||
@@ -373,7 +381,11 @@ func TestTLS(t *testing.T) {
|
||||
"_DATA": "250 data ok\n",
|
||||
"QUIT": "250 quit ok\n",
|
||||
}
|
||||
srv := newFakeServer(t, responses)
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
smtpTotalTimeout = 5 * time.Second
|
||||
srv := newFakeServer(t, tlsResponses)
|
||||
defer srv.Cleanup()
|
||||
_, *smtpPort = srv.HostPort()
|
||||
|
||||
testMX["to"] = []*net.MX{
|
||||
@@ -391,7 +403,7 @@ func TestTLS(t *testing.T) {
|
||||
|
||||
// Now do another delivery, but without TLS, to check that the detection
|
||||
// of connection downgrade is working.
|
||||
responses = map[string]string{
|
||||
responses := map[string]string{
|
||||
"_welcome": "220 welcome\n",
|
||||
"EHLO hello": "250 ehlo ok\n",
|
||||
"MAIL FROM:<me@me>": "250 mail ok\n",
|
||||
@@ -401,6 +413,7 @@ func TestTLS(t *testing.T) {
|
||||
"QUIT": "250 quit ok\n",
|
||||
}
|
||||
srv = newFakeServer(t, responses)
|
||||
defer srv.Cleanup()
|
||||
_, *smtpPort = srv.HostPort()
|
||||
|
||||
err, permanent := s.Deliver("me@me", "to@to", []byte("data"))
|
||||
@@ -425,6 +438,7 @@ func TestTLSError(t *testing.T) {
|
||||
"_STARTTLS": "no",
|
||||
}
|
||||
srv := newFakeServer(t, responses)
|
||||
defer srv.Cleanup()
|
||||
_, *smtpPort = srv.HostPort()
|
||||
|
||||
testMX["to"] = []*net.MX{
|
||||
@@ -443,3 +457,76 @@ func TestTLSError(t *testing.T) {
|
||||
|
||||
srv.Wait()
|
||||
}
|
||||
|
||||
func TestSTSPolicyEnforcement(t *testing.T) {
|
||||
smtpTotalTimeout = 5 * time.Second
|
||||
srv := newFakeServer(t, tlsResponses)
|
||||
defer srv.Cleanup()
|
||||
_, *smtpPort = srv.HostPort()
|
||||
|
||||
s, tmpDir := newSMTP(t)
|
||||
defer testlib.RemoveIfOk(t, tmpDir)
|
||||
|
||||
a := &attempt{
|
||||
courier: s,
|
||||
from: "me@me",
|
||||
to: "to@to",
|
||||
toDomain: "to",
|
||||
data: []byte("data"),
|
||||
tr: trace.New("test", "test"),
|
||||
}
|
||||
|
||||
a.stsPolicy = &sts.Policy{
|
||||
Version: "STSv1",
|
||||
Mode: sts.Enforce,
|
||||
MXs: []string{"mx"},
|
||||
MaxAge: 1 * time.Minute,
|
||||
}
|
||||
|
||||
// At this point the cert is not valid, which is incompatible with STS
|
||||
// policy, so we expect it to fail.
|
||||
err, permanent := a.deliver("localhost")
|
||||
if !strings.Contains(err.Error(),
|
||||
"invalid security level (TLS_INSECURE) for STS policy") {
|
||||
t.Errorf("expected invalid sec level error, got %v", err)
|
||||
}
|
||||
if permanent != false {
|
||||
t.Errorf("expected transient error, got permanent")
|
||||
}
|
||||
|
||||
srv.Wait()
|
||||
|
||||
// Do another delivery attempt, but this time we trust the server cert.
|
||||
// This time it should be successful, because the connection level should
|
||||
// be TLS_SECURE which is required by the STS policy.
|
||||
srv = newFakeServer(t, tlsResponses)
|
||||
_, *smtpPort = srv.HostPort()
|
||||
defer srv.Cleanup()
|
||||
|
||||
certRoots = loadCert(t, srv.tmpDir+"/cert.pem")
|
||||
defer func() {
|
||||
certRoots = nil
|
||||
}()
|
||||
|
||||
err, permanent = a.deliver("localhost")
|
||||
if err != nil {
|
||||
t.Errorf("expected success, got %v (permanent=%v)", err, permanent)
|
||||
}
|
||||
|
||||
srv.Wait()
|
||||
}
|
||||
|
||||
func loadCert(t *testing.T, path string) *x509.CertPool {
|
||||
t.Helper()
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("error reading cert %q: %v", path, err)
|
||||
}
|
||||
ok := pool.AppendCertsFromPEM(data)
|
||||
if !ok {
|
||||
t.Fatalf("failed to load cert %q", path)
|
||||
}
|
||||
return pool
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user