mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-18 14:47:03 +00:00
sts: Add miscellaneous tests
This patch adds a few miscellaneous tests to the sts package, covering various previously-untested code paths.
This commit is contained in:
@@ -225,12 +225,6 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
|
||||
CheckRedirect: rejectRedirect,
|
||||
}
|
||||
|
||||
// Note that http does not care for the context deadline, so we need to
|
||||
// construct it here.
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
client.Timeout = deadline.Sub(time.Now())
|
||||
}
|
||||
|
||||
resp, err := ctxhttp.Get(ctx, client, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -458,9 +452,8 @@ func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error)
|
||||
|
||||
func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
|
||||
for ctx.Err() == nil {
|
||||
cacheRefreshCycles.Add(1)
|
||||
|
||||
c.refresh(ctx)
|
||||
cacheRefreshCycles.Add(1)
|
||||
|
||||
// Wait 10 minutes between passes; this is a background refresh and
|
||||
// there's no need to poke the servers very often.
|
||||
|
||||
@@ -4,13 +4,15 @@ import (
|
||||
"context"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/testlib"
|
||||
)
|
||||
|
||||
// Override the lookup function to control its results.
|
||||
@@ -145,10 +147,18 @@ func TestMatchDomain(t *testing.T) {
|
||||
{"x.ñaca.com", "x.xn--aca-6ma.com", true},
|
||||
{"x.naca.com", "x.xn--aca-6ma.com", false},
|
||||
|
||||
// Triggers errors in domainToASCII.
|
||||
{strings.Repeat("x", 65536) + "\uff00", "x.com", false},
|
||||
|
||||
// Examples from the RFC.
|
||||
{"mail.example.com", "*.example.com", true},
|
||||
{"example.com", "*.example.com", false},
|
||||
{"foo.bar.example.com", "*.example.com", false},
|
||||
|
||||
// Missing "*" (invalid, seen in the wild).
|
||||
{"aa.b.cc.com", ".aa.b.cc.com", false},
|
||||
{"zz.aa.b.cc.com", ".aa.b.cc.com", false},
|
||||
{"zz.aa.b.cc.com", "*.aa.b.cc.com", true},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
@@ -159,6 +169,26 @@ func TestMatchDomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMXIsAllowed(t *testing.T) {
|
||||
p := Policy{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour,
|
||||
MXs: []string{"mx1", "mx2"}}
|
||||
if p.MXIsAllowed("notamx") {
|
||||
t.Errorf("notamx should not be allowed")
|
||||
}
|
||||
if !p.MXIsAllowed("mx1") {
|
||||
t.Errorf("mx1 should be allowed")
|
||||
}
|
||||
if !p.MXIsAllowed("mx2") {
|
||||
t.Errorf("mx2 should be allowed")
|
||||
}
|
||||
|
||||
p = Policy{Version: "STSv1", Mode: "testing", MaxAge: 1 * time.Hour,
|
||||
MXs: []string{"mx1"}}
|
||||
if !p.MXIsAllowed("notamx") {
|
||||
t.Errorf("notamx should be allowed (policy not enforced)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetch(t *testing.T) {
|
||||
// Note the data "fetched" for each domain comes from policyForDomain,
|
||||
// defined in TestMain above. See httpGet for more details.
|
||||
@@ -212,22 +242,6 @@ func TestPolicyTooBig(t *testing.T) {
|
||||
|
||||
// Tests for the policy cache.
|
||||
|
||||
func mustTempDir(t *testing.T) string {
|
||||
dir, err := ioutil.TempDir("", "sts_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = os.Chdir(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("test directory: %q", dir)
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
|
||||
// TODO: Use v.Value once we drop support of Go 1.7.
|
||||
value, _ := strconv.Atoi(v.String())
|
||||
@@ -237,7 +251,7 @@ func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
|
||||
}
|
||||
|
||||
func TestCacheBasics(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
dir := testlib.MustTempDir(t)
|
||||
c, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -285,6 +299,16 @@ func TestCacheBasics(t *testing.T) {
|
||||
expvarMustEq(t, "cacheFetches", cacheFetches, 3)
|
||||
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
||||
|
||||
// Fetch for a domain without policy.
|
||||
p, err = c.Fetch(ctx, "domErr")
|
||||
if err == nil || p != nil {
|
||||
t.Errorf("expected failure, got: policy = %v ; error = %v", p, err)
|
||||
}
|
||||
t.Logf("cache fetched domErr: %v", p)
|
||||
expvarMustEq(t, "cacheFetches", cacheFetches, 4)
|
||||
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
||||
expvarMustEq(t, "cacheFailedFetch", cacheFailedFetch, 1)
|
||||
|
||||
if !t.Failed() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
@@ -292,7 +316,7 @@ func TestCacheBasics(t *testing.T) {
|
||||
|
||||
// Test how the cache behaves when the files are corrupt.
|
||||
func TestCacheBadData(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
dir := testlib.MustTempDir(t)
|
||||
c, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -300,12 +324,15 @@ func TestCacheBadData(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
cacheUnmarshalErrors.Set(0)
|
||||
cacheInvalid.Set(0)
|
||||
|
||||
cases := []string{
|
||||
// Case 1: A file with invalid json, which will fail unmarshalling.
|
||||
"this is not valid json",
|
||||
|
||||
// Case 2: A file with a parseable but invalid policy.
|
||||
`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`,
|
||||
`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], "max_age": 1}`,
|
||||
}
|
||||
|
||||
for _, badContent := range cases {
|
||||
@@ -325,10 +352,7 @@ func TestCacheBadData(t *testing.T) {
|
||||
|
||||
// Edit the file, filling it with the bad content for this case.
|
||||
fname := c.domainPath("domain.com")
|
||||
err = ioutil.WriteFile(fname, []byte(badContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("error writing file: %v", err)
|
||||
}
|
||||
mustRewriteAndChtime(t, fname, badContent)
|
||||
|
||||
// We now expect Fetch to fall back to getting the policy from the
|
||||
// network (in our case, from policyForDomain).
|
||||
@@ -353,6 +377,9 @@ func TestCacheBadData(t *testing.T) {
|
||||
os.Remove(fname)
|
||||
}
|
||||
|
||||
expvarMustEq(t, "cacheUnmarshalErrors", cacheUnmarshalErrors, 1)
|
||||
expvarMustEq(t, "cacheInvalid", cacheInvalid, 1)
|
||||
|
||||
if !t.Failed() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
@@ -367,8 +394,20 @@ func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Pol
|
||||
return p
|
||||
}
|
||||
|
||||
func mustRewriteAndChtime(t *testing.T, fname, content string) {
|
||||
testlib.Rewrite(t, fname, content)
|
||||
|
||||
// Advance the expiration time to the future, so the rewritten policy is
|
||||
// not considered expired.
|
||||
expires := time.Now().Add(10 * time.Second)
|
||||
err := os.Chtimes(fname, expires, expires)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to chtime %q to the past: %v", fname, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRefresh(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
dir := testlib.MustTempDir(t)
|
||||
c, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -400,7 +439,16 @@ func TestCacheRefresh(t *testing.T) {
|
||||
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
|
||||
}
|
||||
|
||||
c.refresh(ctx)
|
||||
// Launch background refreshes, and wait for one to complete.
|
||||
// TODO: change to cacheRefreshCycles.Value once we drop support for Go
|
||||
// 1.7.
|
||||
cacheRefreshCycles.Set(0)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
go c.PeriodicallyRefresh(ctx)
|
||||
for cacheRefreshCycles.String() == "0" {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
|
||||
p = mustFetch(t, c, ctx, "refresh-test")
|
||||
if p.MaxAge != 200*time.Second {
|
||||
@@ -412,6 +460,24 @@ func TestCacheRefresh(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSlashSafe(t *testing.T) {
|
||||
dir := testlib.MustTempDir(t)
|
||||
c, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("recovered: %v", r)
|
||||
} else {
|
||||
t.Fatalf("check did not panic as expected")
|
||||
}
|
||||
}()
|
||||
|
||||
c.domainPath("a/b")
|
||||
}
|
||||
|
||||
func TestURLForDomain(t *testing.T) {
|
||||
// This function will behave differently if fakeURLForTesting is set, so
|
||||
// temporarily unset it.
|
||||
@@ -452,3 +518,56 @@ func TestHasSTSRecord(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPGet(t *testing.T) {
|
||||
// Basic test, it should work.
|
||||
srv1 := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(policyForDomain["domain.com"]))
|
||||
}))
|
||||
defer srv1.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
raw, err := httpGet(ctx, srv1.URL)
|
||||
if err != nil {
|
||||
t.Errorf("GET failed: got %q, %v", raw, err)
|
||||
}
|
||||
|
||||
// Test that redirects are rejected.
|
||||
srv2 := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, fakeURLForTesting, http.StatusMovedPermanently)
|
||||
}))
|
||||
defer srv2.Close()
|
||||
|
||||
raw, err = httpGet(ctx, srv2.URL)
|
||||
if err == nil {
|
||||
t.Errorf("redirect allowed, should have failed: got %q, %v", raw, err)
|
||||
}
|
||||
|
||||
// Content type != text/plain should be rejected.
|
||||
srv3 := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/json")
|
||||
w.Write([]byte(policyForDomain["domain.com"]))
|
||||
}))
|
||||
defer srv3.Close()
|
||||
|
||||
raw, err = httpGet(ctx, srv3.URL)
|
||||
if err != ErrInvalidMediaType {
|
||||
t.Errorf("content type != text/plain was allowed: got %q, %v", raw, err)
|
||||
}
|
||||
|
||||
// Invalid (unparseable) media type.
|
||||
srv4 := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "invalid/content/type")
|
||||
w.Write([]byte(policyForDomain["domain.com"]))
|
||||
}))
|
||||
defer srv4.Close()
|
||||
|
||||
raw, err = httpGet(ctx, srv4.URL)
|
||||
if err == nil || err == ErrInvalidMediaType {
|
||||
t.Errorf("invalid content type was allowed: got %q, %v", raw, err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user