1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-23 15:37:01 +00:00

sts: Make tests more end-to-end, to cover HTTP fetching

The current tests stop short of fetching over HTTP, but that code is
unfortunately not trivial.

This patch changes the testing strategy to use a testing HTTP server,
which we point our URLs to. That way we can cover much more code with the
same tests.
This commit is contained in:
Alberto Bertogli
2017-02-28 23:57:04 +00:00
parent 216cf47ffa
commit e66288e4b4
2 changed files with 65 additions and 33 deletions

View File

@@ -15,6 +15,7 @@ import (
"errors" "errors"
"expvar" "expvar"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
@@ -130,11 +131,7 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
return nil, err return nil, err
} }
// URL composed from the domain, as explained in: url := urlForDomain(domain)
// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
url := "https://mta-sts." + domain + "/.well-known/mta-sts.json"
rawPolicy, err := httpGet(ctx, url) rawPolicy, err := httpGet(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -143,6 +140,21 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
return parsePolicy(rawPolicy) return parsePolicy(rawPolicy)
} }
// Fake URL for testing purposes, so we can do more end-to-end tests,
// including the HTTP fetching code.
var fakeURLForTesting string
func urlForDomain(domain string) string {
if fakeURLForTesting != "" {
return fakeURLForTesting + "/" + domain
}
// URL composed from the domain, as explained in:
// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
return "https://mta-sts." + domain + "/.well-known/mta-sts.json"
}
// Fetch a policy for the given domain. Note this results in various network // Fetch a policy for the given domain. Note this results in various network
// lookups and HTTPS GETs, so it can be slow. // lookups and HTTPS GETs, so it can be slow.
// The returned policy is parsed and sanity-checked (using Policy.Check), so // The returned policy is parsed and sanity-checked (using Policy.Check), so
@@ -161,9 +173,6 @@ func Fetch(ctx context.Context, domain string) (*Policy, error) {
return p, nil return p, nil
} }
// Fake HTTP content for testing purposes only.
var fakeContent = map[string]string{}
// httpGet performs an HTTP GET of the given URL, using the context and // httpGet performs an HTTP GET of the given URL, using the context and
// rejecting redirects, as per the standard. // rejecting redirects, as per the standard.
func httpGet(ctx context.Context, url string) ([]byte, error) { func httpGet(ctx context.Context, url string) ([]byte, error) {
@@ -179,16 +188,6 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
client.Timeout = deadline.Sub(time.Now()) client.Timeout = deadline.Sub(time.Now())
} }
if len(fakeContent) > 0 {
// If we have fake content for testing, then return the content for
// the URL, or an error if it's missing.
// This makes sure we don't make actual requests for testing.
if d, ok := fakeContent[url]; ok {
return []byte(d), nil
}
return nil, errors.New("error for testing")
}
resp, err := ctxhttp.Get(ctx, client, url) resp, err := ctxhttp.Get(ctx, client, url)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -3,34 +3,53 @@ package sts
import ( import (
"context" "context"
"expvar" "expvar"
"fmt"
"io/ioutil" "io/ioutil"
"net/http"
"net/http/httptest"
"os" "os"
"testing" "testing"
"time" "time"
) )
func TestMain(m *testing.M) { // Test policy for each of the requested domains. Will be served by the test
// Populate the fake policy contents, used by a few tests. // HTTP server.
// httpGet will use this data instead of using the network. var policyForDomain = map[string]string{
// domain.com -> valid, with reasonable policy. // domain.com -> valid, with reasonable policy.
fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = ` "domain.com": `
{ {
"version": "STSv1", "version": "STSv1",
"mode": "enforce", "mode": "enforce",
"mx": ["*.mail.domain.com"], "mx": ["*.mail.domain.com"],
"max_age": 3600 "max_age": 3600
}` }`,
// version99 -> invalid policy (unknown version). // version99 -> invalid policy (unknown version).
fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = ` "version99": `
{ {
"version": "STSv99", "version": "STSv99",
"mode": "enforce", "mode": "enforce",
"mx": ["*.mail.version99"], "mx": ["*.mail.version99"],
"max_age": 999 "max_age": 999
}` }`,
}
func testHTTPHandler(w http.ResponseWriter, r *http.Request) {
// For testing, the domain in the path (see urlForDomain).
policy, ok := policyForDomain[r.URL.Path[1:]]
if !ok {
http.Error(w, "not found", 404)
return
}
fmt.Fprintln(w, policy)
return
}
func TestMain(m *testing.M) {
// Create a test HTTP server, used by the more end-to-end tests.
httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler))
fakeURLForTesting = httpServer.URL
os.Exit(m.Run()) os.Exit(m.Run())
} }
@@ -112,8 +131,8 @@ func TestMatchDomain(t *testing.T) {
} }
func TestFetch(t *testing.T) { func TestFetch(t *testing.T) {
// Note the data "fetched" for each domain comes from fakeContent, defined // Note the data "fetched" for each domain comes from policyForDomain,
// in TestMain above. See httpGet for more details. // defined in TestMain above. See httpGet for more details.
// Normal fetch, all valid. // Normal fetch, all valid.
p, err := Fetch(context.Background(), "domain.com") p, err := Fetch(context.Background(), "domain.com")
@@ -170,8 +189,8 @@ func TestCacheBasics(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// Note the data "fetched" for each domain comes from fakeContent, defined // Note the data "fetched" for each domain comes from policyForDomain,
// in TestMain above. See httpGet for more details. // defined in TestMain above. See httpGet for more details.
// Reset the expvar counters that we use to validate hits, misses, etc. // Reset the expvar counters that we use to validate hits, misses, etc.
cacheFetches.Set(0) cacheFetches.Set(0)
@@ -258,7 +277,7 @@ func TestCacheBadData(t *testing.T) {
} }
// We now expect Fetch to fall back to getting the policy from the // We now expect Fetch to fall back to getting the policy from the
// network (in our case, from fakeContent). // network (in our case, from policyForDomain).
p, err = c.Fetch(ctx, "domain.com") p, err = c.Fetch(ctx, "domain.com")
if err != nil { if err != nil {
t.Fatalf("Fetch failed: %v", err) t.Fatalf("Fetch failed: %v", err)
@@ -303,7 +322,7 @@ func TestCacheRefresh(t *testing.T) {
ctx := context.Background() ctx := context.Background()
fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = ` policyForDomain["refresh-test"] = `
{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}` {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}`
p := mustFetch(t, c, ctx, "refresh-test") p := mustFetch(t, c, ctx, "refresh-test")
if p.MaxAge != 100*time.Second { if p.MaxAge != 100*time.Second {
@@ -312,7 +331,7 @@ func TestCacheRefresh(t *testing.T) {
// Change the "published" policy, check that we see the old version at // Change the "published" policy, check that we see the old version at
// fetch (should be cached), and a new version after a refresh. // fetch (should be cached), and a new version after a refresh.
fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = ` policyForDomain["refresh-test"] = `
{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}` {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}`
p = mustFetch(t, c, ctx, "refresh-test") p = mustFetch(t, c, ctx, "refresh-test")
@@ -331,3 +350,17 @@ func TestCacheRefresh(t *testing.T) {
os.RemoveAll(dir) os.RemoveAll(dir)
} }
} }
func TestURLForDomain(t *testing.T) {
// This function will behave differently if fakeURLForTesting is set, so
// temporarily unset it.
oldURL := fakeURLForTesting
fakeURLForTesting = ""
defer func() { fakeURLForTesting = oldURL }()
got := urlForDomain("a-test-domain")
expected := "https://mta-sts.a-test-domain/.well-known/mta-sts.json"
if got != expected {
t.Errorf("got %q, expected %q", got, expected)
}
}