1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-18 14:47:03 +00:00

sts: Add miscellaneous tests

This patch adds a few miscellaneous tests to the sts package, covering
various previously-untested code paths.
This commit is contained in:
Alberto Bertogli
2018-05-27 16:20:35 +01:00
parent 79a8cfc21c
commit 46bce576e8
2 changed files with 146 additions and 34 deletions

View File

@@ -225,12 +225,6 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
CheckRedirect: rejectRedirect,
}
// Note that http does not care for the context deadline, so we need to
// construct it here.
if deadline, ok := ctx.Deadline(); ok {
client.Timeout = deadline.Sub(time.Now())
}
resp, err := ctxhttp.Get(ctx, client, url)
if err != nil {
return nil, err
@@ -458,9 +452,8 @@ func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error)
func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
for ctx.Err() == nil {
cacheRefreshCycles.Add(1)
c.refresh(ctx)
cacheRefreshCycles.Add(1)
// Wait 10 minutes between passes; this is a background refresh and
// there's no need to poke the servers very often.

View File

@@ -4,13 +4,15 @@ import (
"context"
"expvar"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"testing"
"time"
"blitiri.com.ar/go/chasquid/internal/testlib"
)
// Override the lookup function to control its results.
@@ -145,10 +147,18 @@ func TestMatchDomain(t *testing.T) {
{"x.ñaca.com", "x.xn--aca-6ma.com", true},
{"x.naca.com", "x.xn--aca-6ma.com", false},
// Triggers errors in domainToASCII.
{strings.Repeat("x", 65536) + "\uff00", "x.com", false},
// Examples from the RFC.
{"mail.example.com", "*.example.com", true},
{"example.com", "*.example.com", false},
{"foo.bar.example.com", "*.example.com", false},
// Missing "*" (invalid, seen in the wild).
{"aa.b.cc.com", ".aa.b.cc.com", false},
{"zz.aa.b.cc.com", ".aa.b.cc.com", false},
{"zz.aa.b.cc.com", "*.aa.b.cc.com", true},
}
for _, c := range cases {
@@ -159,6 +169,26 @@ func TestMatchDomain(t *testing.T) {
}
}
func TestMXIsAllowed(t *testing.T) {
p := Policy{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour,
MXs: []string{"mx1", "mx2"}}
if p.MXIsAllowed("notamx") {
t.Errorf("notamx should not be allowed")
}
if !p.MXIsAllowed("mx1") {
t.Errorf("mx1 should be allowed")
}
if !p.MXIsAllowed("mx2") {
t.Errorf("mx2 should be allowed")
}
p = Policy{Version: "STSv1", Mode: "testing", MaxAge: 1 * time.Hour,
MXs: []string{"mx1"}}
if !p.MXIsAllowed("notamx") {
t.Errorf("notamx should be allowed (policy not enforced)")
}
}
func TestFetch(t *testing.T) {
// Note the data "fetched" for each domain comes from policyForDomain,
// defined in TestMain above. See httpGet for more details.
@@ -212,22 +242,6 @@ func TestPolicyTooBig(t *testing.T) {
// Tests for the policy cache.
func mustTempDir(t *testing.T) string {
dir, err := ioutil.TempDir("", "sts_test")
if err != nil {
t.Fatal(err)
}
err = os.Chdir(dir)
if err != nil {
t.Fatal(err)
}
t.Logf("test directory: %q", dir)
return dir
}
func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
// TODO: Use v.Value once we drop support of Go 1.7.
value, _ := strconv.Atoi(v.String())
@@ -237,7 +251,7 @@ func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
}
func TestCacheBasics(t *testing.T) {
dir := mustTempDir(t)
dir := testlib.MustTempDir(t)
c, err := NewCache(dir)
if err != nil {
t.Fatal(err)
@@ -285,6 +299,16 @@ func TestCacheBasics(t *testing.T) {
expvarMustEq(t, "cacheFetches", cacheFetches, 3)
expvarMustEq(t, "cacheHits", cacheHits, 1)
// Fetch for a domain without policy.
p, err = c.Fetch(ctx, "domErr")
if err == nil || p != nil {
t.Errorf("expected failure, got: policy = %v ; error = %v", p, err)
}
t.Logf("cache fetched domErr: %v", p)
expvarMustEq(t, "cacheFetches", cacheFetches, 4)
expvarMustEq(t, "cacheHits", cacheHits, 1)
expvarMustEq(t, "cacheFailedFetch", cacheFailedFetch, 1)
if !t.Failed() {
os.RemoveAll(dir)
}
@@ -292,7 +316,7 @@ func TestCacheBasics(t *testing.T) {
// Test how the cache behaves when the files are corrupt.
func TestCacheBadData(t *testing.T) {
dir := mustTempDir(t)
dir := testlib.MustTempDir(t)
c, err := NewCache(dir)
if err != nil {
t.Fatal(err)
@@ -300,12 +324,15 @@ func TestCacheBadData(t *testing.T) {
ctx := context.Background()
cacheUnmarshalErrors.Set(0)
cacheInvalid.Set(0)
cases := []string{
// Case 1: A file with invalid json, which will fail unmarshalling.
"this is not valid json",
// Case 2: A file with a parseable but invalid policy.
`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`,
`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], "max_age": 1}`,
}
for _, badContent := range cases {
@@ -325,10 +352,7 @@ func TestCacheBadData(t *testing.T) {
// Edit the file, filling it with the bad content for this case.
fname := c.domainPath("domain.com")
err = ioutil.WriteFile(fname, []byte(badContent), 0644)
if err != nil {
t.Fatalf("error writing file: %v", err)
}
mustRewriteAndChtime(t, fname, badContent)
// We now expect Fetch to fall back to getting the policy from the
// network (in our case, from policyForDomain).
@@ -353,6 +377,9 @@ func TestCacheBadData(t *testing.T) {
os.Remove(fname)
}
expvarMustEq(t, "cacheUnmarshalErrors", cacheUnmarshalErrors, 1)
expvarMustEq(t, "cacheInvalid", cacheInvalid, 1)
if !t.Failed() {
os.RemoveAll(dir)
}
@@ -367,8 +394,20 @@ func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Pol
return p
}
func mustRewriteAndChtime(t *testing.T, fname, content string) {
testlib.Rewrite(t, fname, content)
// Advance the expiration time to the future, so the rewritten policy is
// not considered expired.
expires := time.Now().Add(10 * time.Second)
err := os.Chtimes(fname, expires, expires)
if err != nil {
t.Fatalf("failed to chtime %q to the past: %v", fname, err)
}
}
func TestCacheRefresh(t *testing.T) {
dir := mustTempDir(t)
dir := testlib.MustTempDir(t)
c, err := NewCache(dir)
if err != nil {
t.Fatal(err)
@@ -400,7 +439,16 @@ func TestCacheRefresh(t *testing.T) {
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
}
c.refresh(ctx)
// Launch background refreshes, and wait for one to complete.
// TODO: change to cacheRefreshCycles.Value once we drop support for Go
// 1.7.
cacheRefreshCycles.Set(0)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
go c.PeriodicallyRefresh(ctx)
for cacheRefreshCycles.String() == "0" {
time.Sleep(5 * time.Millisecond)
}
p = mustFetch(t, c, ctx, "refresh-test")
if p.MaxAge != 200*time.Second {
@@ -412,6 +460,24 @@ func TestCacheRefresh(t *testing.T) {
}
}
func TestCacheSlashSafe(t *testing.T) {
dir := testlib.MustTempDir(t)
c, err := NewCache(dir)
if err != nil {
t.Fatal(err)
}
defer func() {
if r := recover(); r != nil {
t.Logf("recovered: %v", r)
} else {
t.Fatalf("check did not panic as expected")
}
}()
c.domainPath("a/b")
}
func TestURLForDomain(t *testing.T) {
// This function will behave differently if fakeURLForTesting is set, so
// temporarily unset it.
@@ -452,3 +518,56 @@ func TestHasSTSRecord(t *testing.T) {
}
}
}
func TestHTTPGet(t *testing.T) {
// Basic test, it should work.
srv1 := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(policyForDomain["domain.com"]))
}))
defer srv1.Close()
ctx := context.Background()
raw, err := httpGet(ctx, srv1.URL)
if err != nil {
t.Errorf("GET failed: got %q, %v", raw, err)
}
// Test that redirects are rejected.
srv2 := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, fakeURLForTesting, http.StatusMovedPermanently)
}))
defer srv2.Close()
raw, err = httpGet(ctx, srv2.URL)
if err == nil {
t.Errorf("redirect allowed, should have failed: got %q, %v", raw, err)
}
// Content type != text/plain should be rejected.
srv3 := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/json")
w.Write([]byte(policyForDomain["domain.com"]))
}))
defer srv3.Close()
raw, err = httpGet(ctx, srv3.URL)
if err != ErrInvalidMediaType {
t.Errorf("content type != text/plain was allowed: got %q, %v", raw, err)
}
// Invalid (unparseable) media type.
srv4 := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "invalid/content/type")
w.Write([]byte(policyForDomain["domain.com"]))
}))
defer srv4.Close()
raw, err = httpGet(ctx, srv4.URL)
if err == nil || err == ErrInvalidMediaType {
t.Errorf("invalid content type was allowed: got %q, %v", raw, err)
}
}