mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
Some tests did not make use of testlib.RemoveIfOk, which resulted in some duplication; this patch fixes that. While at it, userdb tests have their own simpler variant, so add some safety checks to it.
566 lines
16 KiB
Go
566 lines
16 KiB
Go
package sts
|
|
|
|
import (
|
|
"context"
|
|
"expvar"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"blitiri.com.ar/go/chasquid/internal/testlib"
|
|
)
|
|
|
|
// Canned DNS data used to override the lookup function and control its
// results; testLookupTXT answers from these tables.
var (
	// txtResults maps a queried name to the TXT records returned for it.
	txtResults = map[string][]string{
		"dom1": nil,
		"dom2": {},
		"dom3": {"abc", "def"},
		"dom4": {"abc", "v=STSv1; id=blah;"},

		// Matching policyForDomain below.
		"_mta-sts.domain.com": {"v=STSv1; id=blah;"},
		"_mta-sts.policy404":  {"v=STSv1; id=blah;"},
		"_mta-sts.version99":  {"v=STSv1; id=blah;"},
	}

	// errTest is the error handed back for names listed in txtErrors.
	errTest = fmt.Errorf("error for testing purposes")

	// txtErrors maps a queried name to the error returned for it.
	txtErrors = map[string]error{
		"_mta-sts.domErr": errTest,
	}
)
|
|
|
|
func testLookupTXT(domain string) ([]string, error) {
|
|
return txtResults[domain], txtErrors[domain]
|
|
}
|
|
|
|
// policyForDomain holds the policy body for each domain the tests request.
// The test HTTP server (testHTTPHandler) serves these; a domain without an
// entry gets a 404.
var policyForDomain = map[string]string{
	// A valid, reasonable policy.
	"domain.com": `
version: STSv1
mode: enforce
mx: *.mail.domain.com
max_age: 3600
`,

	// An invalid policy: the version is unknown.
	"version99": `
version: STSv99
mode: enforce
mx: *.mail.version99
max_age: 999
`,
}
|
|
|
|
func testHTTPHandler(w http.ResponseWriter, r *http.Request) {
|
|
// For testing, the domain in the path (see urlForDomain).
|
|
policy, ok := policyForDomain[r.URL.Path[1:]]
|
|
if !ok {
|
|
http.Error(w, "not found", 404)
|
|
return
|
|
}
|
|
fmt.Fprintln(w, policy)
|
|
}
|
|
|
|
func TestMain(m *testing.M) {
|
|
lookupTXT = testLookupTXT
|
|
|
|
// Create a test HTTP server, used by the more end-to-end tests.
|
|
httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler))
|
|
|
|
fakeURLForTesting = httpServer.URL
|
|
os.Exit(m.Run())
|
|
}
|
|
|
|
func TestParsePolicy(t *testing.T) {
|
|
const pol1 = `
|
|
version: STSv1
|
|
mode: enforce
|
|
mx: *.mail.example.com
|
|
max_age: 123456
|
|
`
|
|
p, err := parsePolicy([]byte(pol1))
|
|
if err != nil {
|
|
t.Errorf("failed to parse policy: %v", err)
|
|
}
|
|
|
|
t.Logf("pol1: %+v", p)
|
|
}
|
|
|
|
func TestCheckPolicy(t *testing.T) {
|
|
validPs := []Policy{
|
|
{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour,
|
|
MXs: []string{"mx1", "mx2"}},
|
|
{Version: "STSv1", Mode: "testing", MaxAge: 1 * time.Hour,
|
|
MXs: []string{"mx1"}},
|
|
{Version: "STSv1", Mode: "none", MaxAge: 1 * time.Hour,
|
|
MXs: []string{"mx1"}},
|
|
{Version: "STSv1", Mode: "none", MaxAge: 31557600 * time.Second,
|
|
MXs: []string{"mx1"}},
|
|
}
|
|
for i, p := range validPs {
|
|
if err := p.Check(); err != nil {
|
|
t.Errorf("%d policy %v failed check: %v", i, p, err)
|
|
}
|
|
}
|
|
|
|
invalid := []struct {
|
|
p Policy
|
|
expected error
|
|
}{
|
|
{Policy{Version: "STSv2"}, ErrUnknownVersion},
|
|
{Policy{Version: "STSv1"}, ErrInvalidMaxAge},
|
|
{Policy{Version: "STSv1", MaxAge: 31557601 * time.Second},
|
|
ErrInvalidMaxAge},
|
|
{Policy{Version: "STSv1", MaxAge: 1, Mode: "blah"}, ErrInvalidMode},
|
|
{Policy{Version: "STSv1", MaxAge: 1, Mode: "enforce"}, ErrInvalidMX},
|
|
{Policy{Version: "STSv1", MaxAge: 1, Mode: "enforce", MXs: []string{}},
|
|
ErrInvalidMX},
|
|
}
|
|
for i, c := range invalid {
|
|
if err := c.p.Check(); err != c.expected {
|
|
t.Errorf("%d policy %v check: expected %v, got %v", i, c.p,
|
|
c.expected, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestMatchDomain(t *testing.T) {
|
|
cases := []struct {
|
|
domain, pattern string
|
|
expected bool
|
|
}{
|
|
{"lalala", "lalala", true},
|
|
{"a.b.", "a.b", true},
|
|
{"a.b", "a.b.", true},
|
|
{"abc.com", "*.com", true},
|
|
|
|
{"abc.com", "abc.*.com", false},
|
|
{"abc.com", "x.abc.com", false},
|
|
{"x.abc.com", "*.*.com", false},
|
|
{"abc.def.com", "abc.*.com", false},
|
|
|
|
{"ñaca.com", "ñaca.com", true},
|
|
{"Ñaca.com", "ñaca.com", true},
|
|
{"ñaca.com", "Ñaca.com", true},
|
|
{"x.ñaca.com", "x.xn--aca-6ma.com", true},
|
|
{"x.naca.com", "x.xn--aca-6ma.com", false},
|
|
|
|
// Triggers errors in domainToASCII.
|
|
{strings.Repeat("x", 65536) + "\uff00", "x.com", false},
|
|
|
|
// Examples from the RFC.
|
|
{"mail.example.com", "*.example.com", true},
|
|
{"example.com", "*.example.com", false},
|
|
{"foo.bar.example.com", "*.example.com", false},
|
|
|
|
// Missing "*" (invalid, seen in the wild).
|
|
{"aa.b.cc.com", ".aa.b.cc.com", false},
|
|
{"zz.aa.b.cc.com", ".aa.b.cc.com", false},
|
|
{"zz.aa.b.cc.com", "*.aa.b.cc.com", true},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
if r := matchDomain(c.domain, c.pattern); r != c.expected {
|
|
t.Errorf("matchDomain(%q, %q) = %v, expected %v",
|
|
c.domain, c.pattern, r, c.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestMXIsAllowed(t *testing.T) {
|
|
p := Policy{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour,
|
|
MXs: []string{"mx1", "mx2"}}
|
|
if p.MXIsAllowed("notamx") {
|
|
t.Errorf("notamx should not be allowed")
|
|
}
|
|
if !p.MXIsAllowed("mx1") {
|
|
t.Errorf("mx1 should be allowed")
|
|
}
|
|
if !p.MXIsAllowed("mx2") {
|
|
t.Errorf("mx2 should be allowed")
|
|
}
|
|
|
|
p = Policy{Version: "STSv1", Mode: "testing", MaxAge: 1 * time.Hour,
|
|
MXs: []string{"mx1"}}
|
|
if !p.MXIsAllowed("notamx") {
|
|
t.Errorf("notamx should be allowed (policy not enforced)")
|
|
}
|
|
}
|
|
|
|
func TestFetch(t *testing.T) {
|
|
// Note the data "fetched" for each domain comes from policyForDomain,
|
|
// defined in TestMain above. See httpGet for more details.
|
|
|
|
// Normal fetch, all valid.
|
|
p, err := Fetch(context.Background(), "domain.com")
|
|
if err != nil {
|
|
t.Errorf("failed to fetch policy: %v", err)
|
|
}
|
|
t.Logf("domain.com: %+v", p)
|
|
|
|
// Domain without a policy (HTTP get fails).
|
|
p, err = Fetch(context.Background(), "policy404")
|
|
if err == nil {
|
|
t.Errorf("fetched unknown policy: %v", p)
|
|
}
|
|
t.Logf("policy404: got error as expected: %v", err)
|
|
|
|
// Domain with an invalid policy (unknown version).
|
|
p, err = Fetch(context.Background(), "version99")
|
|
if err != ErrUnknownVersion {
|
|
t.Errorf("expected error %v, got %v (and policy: %v)",
|
|
ErrUnknownVersion, err, p)
|
|
}
|
|
t.Logf("version99: got expected error: %v", err)
|
|
|
|
// Error fetching TXT record for this domain.
|
|
p, err = Fetch(context.Background(), "domErr")
|
|
if err != errTest {
|
|
t.Errorf("expected error %v, got %v (and policy: %v)",
|
|
errTest, err, p)
|
|
}
|
|
t.Logf("domErr: got expected error: %v", err)
|
|
}
|
|
|
|
func TestPolicyTooBig(t *testing.T) {
|
|
// Construct a valid but very large JSON as a policy.
|
|
raw := `{"version": "STSv1", "mode": "enforce", "mx": [`
|
|
for i := 0; i < 2000; i++ {
|
|
raw += fmt.Sprintf("\"mx%d\", ", i)
|
|
}
|
|
raw += `"mxlast"], "max_age": 100}`
|
|
policyForDomain["toobig"] = raw
|
|
|
|
_, err := Fetch(context.Background(), "toobig")
|
|
if err == nil {
|
|
t.Errorf("fetch worked, but should have failed")
|
|
}
|
|
t.Logf("got error as expected: %v", err)
|
|
}
|
|
|
|
// Tests for the policy cache.
|
|
|
|
// expvarMustEq reports a test error if the counter v does not hold the
// expected value. name is only used to identify the counter in the message.
func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int64) {
	// Attribute failures to the calling line, not to this helper.
	t.Helper()

	// Read the counter once, so the comparison and the message agree.
	if got := v.Value(); got != expected {
		t.Errorf("%s is %d, expected %d", name, got, expected)
	}
}
|
|
|
|
func TestCacheBasics(t *testing.T) {
|
|
dir := testlib.MustTempDir(t)
|
|
defer testlib.RemoveIfOk(t, dir)
|
|
|
|
c, err := NewCache(dir)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Note the data "fetched" for each domain comes from policyForDomain,
|
|
// defined in TestMain above. See httpGet for more details.
|
|
|
|
// Reset the expvar counters that we use to validate hits, misses, etc.
|
|
cacheFetches.Set(0)
|
|
cacheHits.Set(0)
|
|
|
|
ctx := context.Background()
|
|
|
|
// Fetch domain.com, check we get a reasonable policy, and that it's a
|
|
// cache miss.
|
|
p, err := c.Fetch(ctx, "domain.com")
|
|
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
|
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
|
}
|
|
t.Logf("cache fetched domain.com: %v", p)
|
|
expvarMustEq(t, "cacheFetches", cacheFetches, 1)
|
|
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
|
|
|
// Fetch domain.com again, this time we should see a cache hit.
|
|
p, err = c.Fetch(ctx, "domain.com")
|
|
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
|
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
|
}
|
|
t.Logf("cache fetched domain.com: %v", p)
|
|
expvarMustEq(t, "cacheFetches", cacheFetches, 2)
|
|
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
|
|
|
// Simulate an expired cache entry by changing the mtime of domain.com's
|
|
// entry to the past.
|
|
expires := time.Now().Add(-1 * time.Minute)
|
|
os.Chtimes(c.domainPath("domain.com"), expires, expires)
|
|
|
|
// Do a third fetch, check that we don't get a cache hit.
|
|
p, err = c.Fetch(ctx, "domain.com")
|
|
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
|
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
|
}
|
|
t.Logf("cache fetched domain.com: %v", p)
|
|
expvarMustEq(t, "cacheFetches", cacheFetches, 3)
|
|
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
|
|
|
// Fetch for a domain without policy.
|
|
p, err = c.Fetch(ctx, "domErr")
|
|
if err == nil || p != nil {
|
|
t.Errorf("expected failure, got: policy = %v ; error = %v", p, err)
|
|
}
|
|
t.Logf("cache fetched domErr: %v", p)
|
|
expvarMustEq(t, "cacheFetches", cacheFetches, 4)
|
|
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
|
expvarMustEq(t, "cacheFailedFetch", cacheFailedFetch, 1)
|
|
}
|
|
|
|
// Test how the cache behaves when the files are corrupt.
|
|
func TestCacheBadData(t *testing.T) {
|
|
dir := testlib.MustTempDir(t)
|
|
defer testlib.RemoveIfOk(t, dir)
|
|
|
|
c, err := NewCache(dir)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
cacheUnmarshalErrors.Set(0)
|
|
cacheInvalid.Set(0)
|
|
|
|
cases := []string{
|
|
// Case 1: A file with invalid json, which will fail unmarshalling.
|
|
"this is not valid json",
|
|
|
|
// Case 2: A file with a parseable but invalid policy.
|
|
`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], "max_age": 1}`,
|
|
}
|
|
|
|
for _, badContent := range cases {
|
|
// Reset the expvar counters that we use to validate hits, misses, etc.
|
|
cacheFetches.Set(0)
|
|
cacheHits.Set(0)
|
|
|
|
// Fetch domain.com, should result in the file being added to the
|
|
// cache.
|
|
p, err := c.Fetch(ctx, "domain.com")
|
|
if err != nil {
|
|
t.Fatalf("Fetch failed: %v", err)
|
|
}
|
|
t.Logf("cache fetched domain.com: %v", p)
|
|
expvarMustEq(t, "cacheFetches", cacheFetches, 1)
|
|
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
|
|
|
// Edit the file, filling it with the bad content for this case.
|
|
fname := c.domainPath("domain.com")
|
|
mustRewriteAndChtime(t, fname, badContent)
|
|
|
|
// We now expect Fetch to fall back to getting the policy from the
|
|
// network (in our case, from policyForDomain).
|
|
p, err = c.Fetch(ctx, "domain.com")
|
|
if err != nil {
|
|
t.Fatalf("Fetch failed: %v", err)
|
|
}
|
|
t.Logf("cache fetched domain.com: %v", p)
|
|
expvarMustEq(t, "cacheFetches", cacheFetches, 2)
|
|
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
|
|
|
// And now the file should be fine, resulting in a cache hit.
|
|
p, err = c.Fetch(ctx, "domain.com")
|
|
if err != nil {
|
|
t.Fatalf("Fetch failed: %v", err)
|
|
}
|
|
t.Logf("cache fetched domain.com: %v", p)
|
|
expvarMustEq(t, "cacheFetches", cacheFetches, 3)
|
|
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
|
|
|
// Remove the file, to start with a clean slate for the next case.
|
|
os.Remove(fname)
|
|
}
|
|
|
|
expvarMustEq(t, "cacheUnmarshalErrors", cacheUnmarshalErrors, 1)
|
|
expvarMustEq(t, "cacheInvalid", cacheInvalid, 1)
|
|
}
|
|
|
|
func (c *PolicyCache) mustFetch(ctx context.Context, t *testing.T, d string) *Policy {
|
|
p, err := c.Fetch(ctx, d)
|
|
if err != nil {
|
|
t.Fatalf("Fetch %q failed: %v", d, err)
|
|
}
|
|
t.Logf("Fetch %q: %v", d, p)
|
|
return p
|
|
}
|
|
|
|
func mustRewriteAndChtime(t *testing.T, fname, content string) {
|
|
testlib.Rewrite(t, fname, content)
|
|
|
|
// Advance the expiration time to the future, so the rewritten policy is
|
|
// not considered expired.
|
|
expires := time.Now().Add(10 * time.Second)
|
|
err := os.Chtimes(fname, expires, expires)
|
|
if err != nil {
|
|
t.Fatalf("failed to chtime %q to the past: %v", fname, err)
|
|
}
|
|
}
|
|
|
|
func TestCacheRefresh(t *testing.T) {
|
|
dir := testlib.MustTempDir(t)
|
|
defer testlib.RemoveIfOk(t, dir)
|
|
|
|
c, err := NewCache(dir)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
txtResults["_mta-sts.refresh-test"] = []string{"v=STSv1; id=blah;"}
|
|
policyForDomain["refresh-test"] = `
|
|
version: STSv1
|
|
mode: enforce
|
|
mx: mx
|
|
max_age: 100`
|
|
p := c.mustFetch(ctx, t, "refresh-test")
|
|
if p.MaxAge != 100*time.Second {
|
|
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
|
|
}
|
|
|
|
// Change the "published" policy, check that we see the old version at
|
|
// fetch (should be cached), and a new version after a refresh.
|
|
policyForDomain["refresh-test"] = `
|
|
version: STSv1
|
|
mode: enforce
|
|
mx: mx
|
|
max_age: 200`
|
|
|
|
p = c.mustFetch(ctx, t, "refresh-test")
|
|
if p.MaxAge != 100*time.Second {
|
|
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
|
|
}
|
|
|
|
// Launch background refreshes, and wait for one to complete.
|
|
cacheRefreshCycles.Set(0)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
defer cancel()
|
|
go c.PeriodicallyRefresh(ctx)
|
|
for cacheRefreshCycles.Value() == 0 {
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
|
|
p = c.mustFetch(ctx, t, "refresh-test")
|
|
if p.MaxAge != 200*time.Second {
|
|
t.Fatalf("policy.MaxAge is %v, expected 200s", p.MaxAge)
|
|
}
|
|
}
|
|
|
|
func TestCacheSlashSafe(t *testing.T) {
|
|
dir := testlib.MustTempDir(t)
|
|
c, err := NewCache(dir)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
t.Logf("recovered: %v", r)
|
|
} else {
|
|
t.Fatalf("check did not panic as expected")
|
|
}
|
|
}()
|
|
|
|
c.domainPath("a/b")
|
|
}
|
|
|
|
func TestURLForDomain(t *testing.T) {
|
|
// This function will behave differently if fakeURLForTesting is set, so
|
|
// temporarily unset it.
|
|
oldURL := fakeURLForTesting
|
|
fakeURLForTesting = ""
|
|
defer func() { fakeURLForTesting = oldURL }()
|
|
|
|
got := urlForDomain("a-test-domain")
|
|
expected := "https://mta-sts.a-test-domain/.well-known/mta-sts.txt"
|
|
if got != expected {
|
|
t.Errorf("got %q, expected %q", got, expected)
|
|
}
|
|
}
|
|
|
|
func TestHasSTSRecord(t *testing.T) {
|
|
txtResults["_mta-sts.dom1"] = nil
|
|
txtResults["_mta-sts.dom2"] = []string{}
|
|
txtResults["_mta-sts.dom3"] = []string{"abc", "def"}
|
|
txtResults["_mta-sts.dom4"] = []string{"abc", "v=STSv1; id=blah;"}
|
|
|
|
cases := []struct {
|
|
domain string
|
|
ok bool
|
|
err error
|
|
}{
|
|
{"", false, nil},
|
|
{"dom1", false, nil},
|
|
{"dom2", false, nil},
|
|
{"dom3", false, nil},
|
|
{"dom4", true, nil},
|
|
{"domErr", false, errTest},
|
|
}
|
|
for _, c := range cases {
|
|
ok, err := hasSTSRecord(c.domain)
|
|
if ok != c.ok || err != c.err {
|
|
t.Errorf("%s: expected {%v, %v}, got {%v, %v}", c.domain,
|
|
c.ok, c.err, ok, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHTTPGet(t *testing.T) {
|
|
// Basic test, it should work.
|
|
srv1 := httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte(policyForDomain["domain.com"]))
|
|
}))
|
|
defer srv1.Close()
|
|
|
|
ctx := context.Background()
|
|
raw, err := httpGet(ctx, srv1.URL)
|
|
if err != nil {
|
|
t.Errorf("GET failed: got %q, %v", raw, err)
|
|
}
|
|
|
|
// Test that redirects are rejected.
|
|
srv2 := httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Redirect(w, r, fakeURLForTesting, http.StatusMovedPermanently)
|
|
}))
|
|
defer srv2.Close()
|
|
|
|
raw, err = httpGet(ctx, srv2.URL)
|
|
if err == nil {
|
|
t.Errorf("redirect allowed, should have failed: got %q, %v", raw, err)
|
|
}
|
|
|
|
// Content type != text/plain should be rejected.
|
|
srv3 := httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/json")
|
|
w.Write([]byte(policyForDomain["domain.com"]))
|
|
}))
|
|
defer srv3.Close()
|
|
|
|
raw, err = httpGet(ctx, srv3.URL)
|
|
if err != ErrInvalidMediaType {
|
|
t.Errorf("content type != text/plain was allowed: got %q, %v", raw, err)
|
|
}
|
|
|
|
// Invalid (unparseable) media type.
|
|
srv4 := httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "invalid/content/type")
|
|
w.Write([]byte(policyForDomain["domain.com"]))
|
|
}))
|
|
defer srv4.Close()
|
|
|
|
raw, err = httpGet(ctx, srv4.URL)
|
|
if err == nil || err == ErrInvalidMediaType {
|
|
t.Errorf("invalid content type was allowed: got %q, %v", raw, err)
|
|
}
|
|
}
|