package courier import ( "fmt" "net" "strings" "testing" "time" "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/testlib" "blitiri.com.ar/go/chasquid/internal/trace" ) // This domain will cause idna.ToASCII to fail. var invalidDomain = "test " + strings.Repeat("x", 65536) + "\uff00" // Override the netLookupMX function, to return controlled results for // testing. var testMX = map[string][]*net.MX{} var testMXErr = map[string]error{} func init() { netLookupMX = func(name string) ([]*net.MX, error) { return testMX[name], testMXErr[name] } } func newSMTP(t *testing.T) (*SMTP, string) { dir := testlib.MustTempDir(t) dinfo, err := domaininfo.New(dir) if err != nil { t.Fatal(err) } return &SMTP{"hello", dinfo, nil}, dir } func TestSMTP(t *testing.T) { // Shorten the total timeout, so the test fails quickly if the protocol // gets stuck. smtpTotalTimeout = 5 * time.Second responses := map[string]string{ "_welcome": "220 welcome\n", "EHLO hello": "250 ehlo ok\n", "MAIL FROM:": "250 mail ok\n", "RCPT TO:": "250 rcpt ok\n", "DATA": "354 send data\n", "_DATA": "250 data ok\n", "QUIT": "250 quit ok\n", } srv := newFakeServer(t, responses) defer srv.Cleanup() host, port := srv.HostPort() // Put a non-existing host first, so we check that if the first host // doesn't work, we try with the rest. // The host we use is invalid, to avoid having to do an actual network // lookup whick makes the test more hermetic. This is a hack, ideally we // would be able to override the default resolver, but Go does not // implement that yet. testMX["to"] = []*net.MX{ {Host: ":::", Pref: 10}, {Host: host, Pref: 20}, } *smtpPort = port s, tmpDir := newSMTP(t) defer testlib.RemoveIfOk(t, tmpDir) err, _ := s.Deliver("me@me", "to@to", []byte("data")) if err != nil { t.Errorf("deliver failed: %v", err) } srv.Wait() } func TestSMTPErrors(t *testing.T) { // Shorten the total timeout, so the test fails quickly if the protocol // gets stuck. smtpTotalTimeout = 1 * time.Second responses := []map[string]string{ // First test: hang response, should fail due to timeout. { "_welcome": "220 no newline", }, // MAIL FROM not allowed. { "_welcome": "220 mail from not allowed\n", "EHLO hello": "250 ehlo ok\n", "MAIL FROM:": "501 mail error\n", }, // RCPT TO not allowed. { "_welcome": "220 rcpt to not allowed\n", "EHLO hello": "250 ehlo ok\n", "MAIL FROM:": "250 mail ok\n", "RCPT TO:": "501 rcpt error\n", }, // DATA error. { "_welcome": "220 data error\n", "EHLO hello": "250 ehlo ok\n", "MAIL FROM:": "250 mail ok\n", "RCPT TO:": "250 rcpt ok\n", "DATA": "554 data error\n", }, // DATA response error. { "_welcome": "220 data response error\n", "EHLO hello": "250 ehlo ok\n", "MAIL FROM:": "250 mail ok\n", "RCPT TO:": "250 rcpt ok\n", "DATA": "354 send data\n", "_DATA": "551 data response error\n", }, } for _, rs := range responses { srv := newFakeServer(t, rs) defer srv.Cleanup() host, port := srv.HostPort() testMX["to"] = []*net.MX{{Host: host, Pref: 10}} *smtpPort = port s, tmpDir := newSMTP(t) defer testlib.RemoveIfOk(t, tmpDir) err, _ := s.Deliver("me@me", "to@to", []byte("data")) if err == nil { t.Errorf("deliver not failed in case %q: %v", rs["_welcome"], err) } t.Logf("failed as expected: %v", err) srv.Wait() } } func TestNoMXServer(t *testing.T) { testMX["to"] = []*net.MX{} s, tmpDir := newSMTP(t) defer testlib.RemoveIfOk(t, tmpDir) err, permanent := s.Deliver("me@me", "to@to", []byte("data")) if err == nil { t.Errorf("delivery worked, expected failure") } if !permanent { t.Errorf("expected permanent failure, got transient (%v)", err) } t.Logf("got permanent failure, as expected: %v", err) } func TestTooManyMX(t *testing.T) { tr := trace.New("test", "test") testMX["domain"] = []*net.MX{ {Host: "h1", Pref: 10}, {Host: "h2", Pref: 20}, {Host: "h3", Pref: 30}, {Host: "h4", Pref: 40}, {Host: "h5", Pref: 50}, {Host: "h5", Pref: 60}, } mxs, err, perm := lookupMXs(tr, "domain") if err != nil { t.Fatalf("unexpected error: %v", err) } if perm != true { t.Fatalf("expected perm == true") } if len(mxs) != 5 { t.Errorf("expected len(mxs) == 5, got: %v", mxs) } } func TestFallbackToA(t *testing.T) { tr := trace.New("test", "test") testMX["domain"] = nil testMXErr["domain"] = &net.DNSError{ Err: "no such host (test)", IsTemporary: false, IsNotFound: true, } mxs, err, perm := lookupMXs(tr, "domain") if err != nil { t.Errorf("unexpected error: %v", err) } if perm != true { t.Errorf("expected perm == true") } if !(len(mxs) == 1 && mxs[0] == "domain") { t.Errorf("expected mxs == [domain], got: %v", mxs) } } func TestTemporaryDNSerror(t *testing.T) { tr := trace.New("test", "test") testMX["domain"] = nil testMXErr["domain"] = &net.DNSError{ Err: "temp error (test)", IsTemporary: true, } mxs, err, perm := lookupMXs(tr, "domain") if !(mxs == nil && err == testMXErr["domain"]) { t.Errorf("expected mxs == nil, err == test error, got: %v, %v", mxs, err) } if perm != false { t.Errorf("expected perm == false") } } func TestMXLookupError(t *testing.T) { tr := trace.New("test", "test") testMX["domain"] = nil testMXErr["domain"] = fmt.Errorf("test error") mxs, err, perm := lookupMXs(tr, "domain") if !(mxs == nil && err == testMXErr["domain"]) { t.Errorf("expected mxs == nil, err == test error, got: %v, %v", mxs, err) } if perm != false { t.Errorf("expected perm == false") } } func TestLookupInvalidDomain(t *testing.T) { tr := trace.New("test", "test") mxs, err, perm := lookupMXs(tr, invalidDomain) if !(mxs == nil && err != nil) { t.Errorf("expected err != nil, got: %v, %v", mxs, err) } if perm != true { t.Fatalf("expected perm == true") } } // Server fake responses for a complete TLS delivery. // We use this in a few tests, so make it common. var tlsResponses = map[string]string{ "_welcome": "220 welcome\n", "EHLO hello": "250-ehlo ok\n250 STARTTLS\n", "STARTTLS": "220 starttls go\n", "_STARTTLS": "ok", "MAIL FROM:": "250 mail ok\n", "RCPT TO:": "250 rcpt ok\n", "DATA": "354 send data\n", "_DATA": "250 data ok\n", "QUIT": "250 quit ok\n", } func TestTLS(t *testing.T) { smtpTotalTimeout = 5 * time.Second srv := newFakeServer(t, tlsResponses) defer srv.Cleanup() _, *smtpPort = srv.HostPort() testMX["to"] = []*net.MX{ {Host: "localhost", Pref: 20}, } s, tmpDir := newSMTP(t) defer testlib.RemoveIfOk(t, tmpDir) err, _ := s.Deliver("me@me", "to@to", []byte("data")) if err != nil { t.Errorf("deliver failed: %v", err) } srv.Wait() // Now do another delivery, but without TLS, to check that the detection // of connection downgrade is working. responses := map[string]string{ "_welcome": "220 welcome\n", "EHLO hello": "250 ehlo ok\n", "MAIL FROM:": "250 mail ok\n", "RCPT TO:": "250 rcpt ok\n", "DATA": "354 send data\n", "_DATA": "250 data ok\n", "QUIT": "250 quit ok\n", } srv = newFakeServer(t, responses) defer srv.Cleanup() _, *smtpPort = srv.HostPort() err, permanent := s.Deliver("me@me", "to@to", []byte("data")) if !strings.Contains(err.Error(), "Security level check failed (level:PLAIN)") { t.Errorf("expected sec level check failed, got: %v", err) } if permanent != false { t.Errorf("expected transient failure, got permanent") } srv.Wait() } func TestTLSError(t *testing.T) { smtpTotalTimeout = 5 * time.Second responses := map[string]string{ "_welcome": "220 welcome\n", "EHLO hello": "250-ehlo ok\n250 STARTTLS\n", "STARTTLS": "500 starttls err\n", "_STARTTLS": "no", } srv := newFakeServer(t, responses) defer srv.Cleanup() _, *smtpPort = srv.HostPort() testMX["to"] = []*net.MX{ {Host: "localhost", Pref: 20}, } s, tmpDir := newSMTP(t) defer testlib.RemoveIfOk(t, tmpDir) err, permanent := s.Deliver("me@me", "to@to", []byte("data")) if !strings.Contains(err.Error(), "TLS error:") { t.Errorf("expected TLS error, got: %v", err) } if permanent != false { t.Errorf("expected transient failure, got permanent") } srv.Wait() } func TestSTSPolicyEnforcement(t *testing.T) { smtpTotalTimeout = 5 * time.Second srv := newFakeServer(t, tlsResponses) defer srv.Cleanup() _, *smtpPort = srv.HostPort() s, tmpDir := newSMTP(t) defer testlib.RemoveIfOk(t, tmpDir) a := &attempt{ courier: s, from: "me@me", to: "to@to", toDomain: "to", data: []byte("data"), tr: trace.New("test", "test"), } a.stsPolicy = &sts.Policy{ Version: "STSv1", Mode: sts.Enforce, MXs: []string{"mx"}, MaxAge: 1 * time.Minute, } // At this point the cert is not valid, which is incompatible with STS // policy, so we expect it to fail. err, permanent := a.deliver("localhost") if !strings.Contains(err.Error(), "invalid security level (TLS_INSECURE) for STS policy") { t.Errorf("expected invalid sec level error, got %v", err) } if permanent != false { t.Errorf("expected transient error, got permanent") } srv.Wait() // Do another delivery attempt, but this time we trust the server cert. // This time it should be successful, because the connection level should // be TLS_SECURE which is required by the STS policy. srv = newFakeServer(t, tlsResponses) _, *smtpPort = srv.HostPort() defer srv.Cleanup() certRoots = srv.rootCA() defer func() { certRoots = nil }() err, permanent = a.deliver("localhost") if err != nil { t.Errorf("expected success, got %v (permanent=%v)", err, permanent) } srv.Wait() }