1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-18 14:47:03 +00:00

sts: Add miscellaneous tests

This patch adds a few miscellaneous tests to the sts package, covering
various previously-untested code paths.
This commit is contained in:
Alberto Bertogli
2018-05-27 16:20:35 +01:00
parent 79a8cfc21c
commit 46bce576e8
2 changed files with 146 additions and 34 deletions

View File

@@ -225,12 +225,6 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
CheckRedirect: rejectRedirect, CheckRedirect: rejectRedirect,
} }
// Note that http does not care for the context deadline, so we need to
// construct it here.
if deadline, ok := ctx.Deadline(); ok {
client.Timeout = deadline.Sub(time.Now())
}
resp, err := ctxhttp.Get(ctx, client, url) resp, err := ctxhttp.Get(ctx, client, url)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -458,9 +452,8 @@ func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error)
func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) { func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
for ctx.Err() == nil { for ctx.Err() == nil {
cacheRefreshCycles.Add(1)
c.refresh(ctx) c.refresh(ctx)
cacheRefreshCycles.Add(1)
// Wait 10 minutes between passes; this is a background refresh and // Wait 10 minutes between passes; this is a background refresh and
// there's no need to poke the servers very often. // there's no need to poke the servers very often.

View File

@@ -4,13 +4,15 @@ import (
"context" "context"
"expvar" "expvar"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
"blitiri.com.ar/go/chasquid/internal/testlib"
) )
// Override the lookup function to control its results. // Override the lookup function to control its results.
@@ -145,10 +147,18 @@ func TestMatchDomain(t *testing.T) {
{"x.ñaca.com", "x.xn--aca-6ma.com", true}, {"x.ñaca.com", "x.xn--aca-6ma.com", true},
{"x.naca.com", "x.xn--aca-6ma.com", false}, {"x.naca.com", "x.xn--aca-6ma.com", false},
// Triggers errors in domainToASCII.
{strings.Repeat("x", 65536) + "\uff00", "x.com", false},
// Examples from the RFC. // Examples from the RFC.
{"mail.example.com", "*.example.com", true}, {"mail.example.com", "*.example.com", true},
{"example.com", "*.example.com", false}, {"example.com", "*.example.com", false},
{"foo.bar.example.com", "*.example.com", false}, {"foo.bar.example.com", "*.example.com", false},
// Missing "*" (invalid, seen in the wild).
{"aa.b.cc.com", ".aa.b.cc.com", false},
{"zz.aa.b.cc.com", ".aa.b.cc.com", false},
{"zz.aa.b.cc.com", "*.aa.b.cc.com", true},
} }
for _, c := range cases { for _, c := range cases {
@@ -159,6 +169,26 @@ func TestMatchDomain(t *testing.T) {
} }
} }
func TestMXIsAllowed(t *testing.T) {
p := Policy{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour,
MXs: []string{"mx1", "mx2"}}
if p.MXIsAllowed("notamx") {
t.Errorf("notamx should not be allowed")
}
if !p.MXIsAllowed("mx1") {
t.Errorf("mx1 should be allowed")
}
if !p.MXIsAllowed("mx2") {
t.Errorf("mx2 should be allowed")
}
p = Policy{Version: "STSv1", Mode: "testing", MaxAge: 1 * time.Hour,
MXs: []string{"mx1"}}
if !p.MXIsAllowed("notamx") {
t.Errorf("notamx should be allowed (policy not enforced)")
}
}
func TestFetch(t *testing.T) { func TestFetch(t *testing.T) {
// Note the data "fetched" for each domain comes from policyForDomain, // Note the data "fetched" for each domain comes from policyForDomain,
// defined in TestMain above. See httpGet for more details. // defined in TestMain above. See httpGet for more details.
@@ -212,22 +242,6 @@ func TestPolicyTooBig(t *testing.T) {
// Tests for the policy cache. // Tests for the policy cache.
func mustTempDir(t *testing.T) string {
dir, err := ioutil.TempDir("", "sts_test")
if err != nil {
t.Fatal(err)
}
err = os.Chdir(dir)
if err != nil {
t.Fatal(err)
}
t.Logf("test directory: %q", dir)
return dir
}
func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) { func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
// TODO: Use v.Value once we drop support of Go 1.7. // TODO: Use v.Value once we drop support of Go 1.7.
value, _ := strconv.Atoi(v.String()) value, _ := strconv.Atoi(v.String())
@@ -237,7 +251,7 @@ func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
} }
func TestCacheBasics(t *testing.T) { func TestCacheBasics(t *testing.T) {
dir := mustTempDir(t) dir := testlib.MustTempDir(t)
c, err := NewCache(dir) c, err := NewCache(dir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -285,6 +299,16 @@ func TestCacheBasics(t *testing.T) {
expvarMustEq(t, "cacheFetches", cacheFetches, 3) expvarMustEq(t, "cacheFetches", cacheFetches, 3)
expvarMustEq(t, "cacheHits", cacheHits, 1) expvarMustEq(t, "cacheHits", cacheHits, 1)
// Fetch for a domain without policy.
p, err = c.Fetch(ctx, "domErr")
if err == nil || p != nil {
t.Errorf("expected failure, got: policy = %v ; error = %v", p, err)
}
t.Logf("cache fetched domErr: %v", p)
expvarMustEq(t, "cacheFetches", cacheFetches, 4)
expvarMustEq(t, "cacheHits", cacheHits, 1)
expvarMustEq(t, "cacheFailedFetch", cacheFailedFetch, 1)
if !t.Failed() { if !t.Failed() {
os.RemoveAll(dir) os.RemoveAll(dir)
} }
@@ -292,7 +316,7 @@ func TestCacheBasics(t *testing.T) {
// Test how the cache behaves when the files are corrupt. // Test how the cache behaves when the files are corrupt.
func TestCacheBadData(t *testing.T) { func TestCacheBadData(t *testing.T) {
dir := mustTempDir(t) dir := testlib.MustTempDir(t)
c, err := NewCache(dir) c, err := NewCache(dir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -300,12 +324,15 @@ func TestCacheBadData(t *testing.T) {
ctx := context.Background() ctx := context.Background()
cacheUnmarshalErrors.Set(0)
cacheInvalid.Set(0)
cases := []string{ cases := []string{
// Case 1: A file with invalid json, which will fail unmarshalling. // Case 1: A file with invalid json, which will fail unmarshalling.
"this is not valid json", "this is not valid json",
// Case 2: A file with a parseable but invalid policy. // Case 2: A file with a parseable but invalid policy.
`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`, `{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], "max_age": 1}`,
} }
for _, badContent := range cases { for _, badContent := range cases {
@@ -325,10 +352,7 @@ func TestCacheBadData(t *testing.T) {
// Edit the file, filling it with the bad content for this case. // Edit the file, filling it with the bad content for this case.
fname := c.domainPath("domain.com") fname := c.domainPath("domain.com")
err = ioutil.WriteFile(fname, []byte(badContent), 0644) mustRewriteAndChtime(t, fname, badContent)
if err != nil {
t.Fatalf("error writing file: %v", err)
}
// We now expect Fetch to fall back to getting the policy from the // We now expect Fetch to fall back to getting the policy from the
// network (in our case, from policyForDomain). // network (in our case, from policyForDomain).
@@ -353,6 +377,9 @@ func TestCacheBadData(t *testing.T) {
os.Remove(fname) os.Remove(fname)
} }
expvarMustEq(t, "cacheUnmarshalErrors", cacheUnmarshalErrors, 1)
expvarMustEq(t, "cacheInvalid", cacheInvalid, 1)
if !t.Failed() { if !t.Failed() {
os.RemoveAll(dir) os.RemoveAll(dir)
} }
@@ -367,8 +394,20 @@ func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Pol
return p return p
} }
func mustRewriteAndChtime(t *testing.T, fname, content string) {
testlib.Rewrite(t, fname, content)
// Advance the expiration time to the future, so the rewritten policy is
// not considered expired.
expires := time.Now().Add(10 * time.Second)
err := os.Chtimes(fname, expires, expires)
if err != nil {
t.Fatalf("failed to chtime %q to the past: %v", fname, err)
}
}
func TestCacheRefresh(t *testing.T) { func TestCacheRefresh(t *testing.T) {
dir := mustTempDir(t) dir := testlib.MustTempDir(t)
c, err := NewCache(dir) c, err := NewCache(dir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -400,7 +439,16 @@ func TestCacheRefresh(t *testing.T) {
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge) t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
} }
c.refresh(ctx) // Launch background refreshes, and wait for one to complete.
// TODO: change to cacheRefreshCycles.Value once we drop support for Go
// 1.7.
cacheRefreshCycles.Set(0)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
go c.PeriodicallyRefresh(ctx)
for cacheRefreshCycles.String() == "0" {
time.Sleep(5 * time.Millisecond)
}
p = mustFetch(t, c, ctx, "refresh-test") p = mustFetch(t, c, ctx, "refresh-test")
if p.MaxAge != 200*time.Second { if p.MaxAge != 200*time.Second {
@@ -412,6 +460,24 @@ func TestCacheRefresh(t *testing.T) {
} }
} }
func TestCacheSlashSafe(t *testing.T) {
dir := testlib.MustTempDir(t)
c, err := NewCache(dir)
if err != nil {
t.Fatal(err)
}
defer func() {
if r := recover(); r != nil {
t.Logf("recovered: %v", r)
} else {
t.Fatalf("check did not panic as expected")
}
}()
c.domainPath("a/b")
}
func TestURLForDomain(t *testing.T) { func TestURLForDomain(t *testing.T) {
// This function will behave differently if fakeURLForTesting is set, so // This function will behave differently if fakeURLForTesting is set, so
// temporarily unset it. // temporarily unset it.
@@ -452,3 +518,56 @@ func TestHasSTSRecord(t *testing.T) {
} }
} }
} }
func TestHTTPGet(t *testing.T) {
// Basic test, it should work.
srv1 := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(policyForDomain["domain.com"]))
}))
defer srv1.Close()
ctx := context.Background()
raw, err := httpGet(ctx, srv1.URL)
if err != nil {
t.Errorf("GET failed: got %q, %v", raw, err)
}
// Test that redirects are rejected.
srv2 := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, fakeURLForTesting, http.StatusMovedPermanently)
}))
defer srv2.Close()
raw, err = httpGet(ctx, srv2.URL)
if err == nil {
t.Errorf("redirect allowed, should have failed: got %q, %v", raw, err)
}
// Content type != text/plain should be rejected.
srv3 := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/json")
w.Write([]byte(policyForDomain["domain.com"]))
}))
defer srv3.Close()
raw, err = httpGet(ctx, srv3.URL)
if err != ErrInvalidMediaType {
t.Errorf("content type != text/plain was allowed: got %q, %v", raw, err)
}
// Invalid (unparseable) media type.
srv4 := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "invalid/content/type")
w.Write([]byte(policyForDomain["domain.com"]))
}))
defer srv4.Close()
raw, err = httpGet(ctx, srv4.URL)
if err == nil || err == ErrInvalidMediaType {
t.Errorf("invalid content type was allowed: got %q, %v", raw, err)
}
}