mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
sts: Add an on-disk cache implementation
This patch adds an on-disk cache for STS policies. Policies are cached by domain, and stored on files in a single directory. The files will have as mtime the time when the policy expires, this makes the store simpler, as it can avoid keeping additional metadata. There is no in-memory caching. This may be added in the future, but for now disk is good enough for our purposes.
This commit is contained in:
@@ -4,7 +4,6 @@
|
|||||||
// This is an EXPERIMENTAL implementation for now.
|
// This is an EXPERIMENTAL implementation for now.
|
||||||
//
|
//
|
||||||
// It lacks (at least) the following:
|
// It lacks (at least) the following:
|
||||||
// - Caching.
|
|
||||||
// - DNS TXT checking.
|
// - DNS TXT checking.
|
||||||
// - Facilities for reporting.
|
// - Facilities for reporting.
|
||||||
//
|
//
|
||||||
@@ -14,15 +13,40 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"expvar"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/safeio"
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||||
|
|
||||||
"golang.org/x/net/context/ctxhttp"
|
"golang.org/x/net/context/ctxhttp"
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Exported variables.
|
||||||
|
var (
|
||||||
|
cacheFetches = expvar.NewInt("chasquid/sts/cache/fetches")
|
||||||
|
cacheHits = expvar.NewInt("chasquid/sts/cache/hits")
|
||||||
|
cacheExpired = expvar.NewInt("chasquid/sts/cache/expired")
|
||||||
|
|
||||||
|
cacheIOErrors = expvar.NewInt("chasquid/sts/cache/ioErrors")
|
||||||
|
cacheFailedFetch = expvar.NewInt("chasquid/sts/cache/failedFetch")
|
||||||
|
cacheInvalid = expvar.NewInt("chasquid/sts/cache/invalid")
|
||||||
|
|
||||||
|
cacheMarshalErrors = expvar.NewInt("chasquid/sts/cache/marshalErrors")
|
||||||
|
cacheUnmarshalErrors = expvar.NewInt("chasquid/sts/cache/unmarshalErrors")
|
||||||
|
|
||||||
|
cacheRefreshCycles = expvar.NewInt("chasquid/sts/cache/refreshCycles")
|
||||||
|
cacheRefreshes = expvar.NewInt("chasquid/sts/cache/refreshes")
|
||||||
|
cacheRefreshErrors = expvar.NewInt("chasquid/sts/cache/refreshErrors")
|
||||||
|
)
|
||||||
|
|
||||||
// Policy represents a parsed policy.
|
// Policy represents a parsed policy.
|
||||||
// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
|
// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
|
||||||
type Policy struct {
|
type Policy struct {
|
||||||
@@ -169,9 +193,12 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
return ioutil.ReadAll(resp.Body)
|
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
return ioutil.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errRejectRedirect = errors.New("redirects not allowed in MTA-STS")
|
var errRejectRedirect = errors.New("redirects not allowed in MTA-STS")
|
||||||
@@ -222,3 +249,186 @@ func domainToASCII(domain string) (string, error) {
|
|||||||
domain = strings.ToLower(domain)
|
domain = strings.ToLower(domain)
|
||||||
return idna.ToASCII(domain)
|
return idna.ToASCII(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PolicyCache is a caching layer for fetching policies.
|
||||||
|
//
|
||||||
|
// Policies are cached by domain, and stored in a single directory.
|
||||||
|
// The files will have as mtime the time when the policy expires, this makes
|
||||||
|
// the store simpler, as it can avoid keeping additional metadata.
|
||||||
|
//
|
||||||
|
// There is no in-memory caching. This may be added in the future, but for
|
||||||
|
// now disk is good enough for our purposes.
|
||||||
|
type PolicyCache struct {
|
||||||
|
dir string
|
||||||
|
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCache(dir string) (*PolicyCache, error) {
|
||||||
|
c := &PolicyCache{
|
||||||
|
dir: dir,
|
||||||
|
}
|
||||||
|
err := os.MkdirAll(dir, 0770)
|
||||||
|
return c, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const pathPrefix = "pol:"
|
||||||
|
|
||||||
|
func (c *PolicyCache) domainPath(domain string) string {
|
||||||
|
// We assume the domain is well formed, sanity check just in case.
|
||||||
|
if strings.Contains(domain, "/") {
|
||||||
|
panic("domain contains slash")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.dir + "/" + pathPrefix + domain
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrExpired = errors.New("cache entry expired")
|
||||||
|
|
||||||
|
func (c *PolicyCache) load(domain string) (*Policy, error) {
|
||||||
|
fname := c.domainPath(domain)
|
||||||
|
|
||||||
|
fi, err := os.Stat(fname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if time.Since(fi.ModTime()) > 0 {
|
||||||
|
cacheExpired.Add(1)
|
||||||
|
return nil, ErrExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := ioutil.ReadFile(fname)
|
||||||
|
if err != nil {
|
||||||
|
cacheIOErrors.Add(1)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &Policy{}
|
||||||
|
err = json.Unmarshal(data, p)
|
||||||
|
if err != nil {
|
||||||
|
cacheUnmarshalErrors.Add(1)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The policy should always be valid, as we marshalled it ourselves;
|
||||||
|
// however, check it just to be safe.
|
||||||
|
if err := p.Check(); err != nil {
|
||||||
|
cacheInvalid.Add(1)
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"%s unmarshalled invalid policy %v: %v", domain, p, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PolicyCache) store(domain string, p *Policy) error {
|
||||||
|
data, err := json.Marshal(p)
|
||||||
|
if err != nil {
|
||||||
|
cacheMarshalErrors.Add(1)
|
||||||
|
return fmt.Errorf("%s failed to marshal policy %v, error: %v",
|
||||||
|
domain, p, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the modification time to the future, when the policy expires.
|
||||||
|
// load will check for this to detect expired cache entries, see above for
|
||||||
|
// the details.
|
||||||
|
expires := time.Now().Add(p.MaxAge)
|
||||||
|
chTime := func(fname string) error {
|
||||||
|
return os.Chtimes(fname, expires, expires)
|
||||||
|
}
|
||||||
|
|
||||||
|
fname := c.domainPath(domain)
|
||||||
|
err = safeio.WriteFile(fname, data, 0640, chTime)
|
||||||
|
if err != nil {
|
||||||
|
cacheIOErrors.Add(1)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) {
|
||||||
|
cacheFetches.Add(1)
|
||||||
|
tr := trace.New("STSCache.Fetch", domain)
|
||||||
|
defer tr.Finish()
|
||||||
|
|
||||||
|
p, err := c.load(domain)
|
||||||
|
if err == nil {
|
||||||
|
tr.Debugf("cache hit: %v", p)
|
||||||
|
cacheHits.Add(1)
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err = Fetch(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
tr.Debugf("failed to fetch: %v", err)
|
||||||
|
cacheFailedFetch.Add(1)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tr.Debugf("fetched: %v", p)
|
||||||
|
|
||||||
|
// We could do this asynchronously, as we got the policy to give to the
|
||||||
|
// caller. However, to make troubleshooting easier and the cost of storing
|
||||||
|
// entries easier to track down, we store synchronously.
|
||||||
|
// Note that even if the store returns an error, we pass on the policy: at
|
||||||
|
// this point we rather use the policy even if we couldn't store it in the
|
||||||
|
// cache.
|
||||||
|
err = c.store(domain, p)
|
||||||
|
if err != nil {
|
||||||
|
tr.Errorf("failed to store: %v", err)
|
||||||
|
} else {
|
||||||
|
tr.Debugf("stored")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
cacheRefreshCycles.Add(1)
|
||||||
|
|
||||||
|
c.refresh(ctx)
|
||||||
|
|
||||||
|
// Wait 10 minutes between passes; this is a background refresh and
|
||||||
|
// there's no need to poke the servers very often.
|
||||||
|
time.Sleep(10 * time.Minute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PolicyCache) refresh(ctx context.Context) {
|
||||||
|
tr := trace.New("STSCache.Refresh", c.dir)
|
||||||
|
defer tr.Finish()
|
||||||
|
|
||||||
|
entries, err := ioutil.ReadDir(c.dir)
|
||||||
|
if err != nil {
|
||||||
|
tr.Errorf("failed to list directory %q: %v", c.dir, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tr.Debugf("%d entries", len(entries))
|
||||||
|
|
||||||
|
for _, e := range entries {
|
||||||
|
if !strings.HasPrefix(e.Name(), pathPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domain := e.Name()[len(pathPrefix):]
|
||||||
|
cacheRefreshes.Add(1)
|
||||||
|
tr.Debugf("%v: refreshing", domain)
|
||||||
|
|
||||||
|
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
p, err := Fetch(fetchCtx, domain)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
tr.Debugf("%v: failed to fetch: %v", domain, err)
|
||||||
|
cacheRefreshErrors.Add(1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tr.Debugf("%v: fetched", domain)
|
||||||
|
|
||||||
|
err = c.store(domain, p)
|
||||||
|
if err != nil {
|
||||||
|
tr.Errorf("%v: failed to store: %v", domain, err)
|
||||||
|
} else {
|
||||||
|
tr.Debugf("%v: stored", domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tr.Debugf("refresh done")
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,10 +2,38 @@ package sts
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"expvar"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
// Populate the fake policy contents, used by a few tests.
|
||||||
|
// httpGet will use this data instead of using the network.
|
||||||
|
|
||||||
|
// domain.com -> valid, with reasonable policy.
|
||||||
|
fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = `
|
||||||
|
{
|
||||||
|
"version": "STSv1",
|
||||||
|
"mode": "enforce",
|
||||||
|
"mx": ["*.mail.domain.com"],
|
||||||
|
"max_age": 3600
|
||||||
|
}`
|
||||||
|
|
||||||
|
// version99 -> invalid policy (unknown version).
|
||||||
|
fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = `
|
||||||
|
{
|
||||||
|
"version": "STSv99",
|
||||||
|
"mode": "enforce",
|
||||||
|
"mx": ["*.mail.version99"],
|
||||||
|
"max_age": 999
|
||||||
|
}`
|
||||||
|
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
func TestParsePolicy(t *testing.T) {
|
func TestParsePolicy(t *testing.T) {
|
||||||
const pol1 = `{
|
const pol1 = `{
|
||||||
"version": "STSv1",
|
"version": "STSv1",
|
||||||
@@ -84,14 +112,10 @@ func TestMatchDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFetch(t *testing.T) {
|
func TestFetch(t *testing.T) {
|
||||||
|
// Note the data "fetched" for each domain comes from fakeContent, defined
|
||||||
|
// in TestMain above. See httpGet for more details.
|
||||||
|
|
||||||
// Normal fetch, all valid.
|
// Normal fetch, all valid.
|
||||||
fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = `
|
|
||||||
{
|
|
||||||
"version": "STSv1",
|
|
||||||
"mode": "enforce",
|
|
||||||
"mx": ["*.mail.example.com"],
|
|
||||||
"max_age": 123456
|
|
||||||
}`
|
|
||||||
p, err := Fetch(context.Background(), "domain.com")
|
p, err := Fetch(context.Background(), "domain.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to fetch policy: %v", err)
|
t.Errorf("failed to fetch policy: %v", err)
|
||||||
@@ -106,13 +130,6 @@ func TestFetch(t *testing.T) {
|
|||||||
t.Logf("unknown: got error as expected: %v", err)
|
t.Logf("unknown: got error as expected: %v", err)
|
||||||
|
|
||||||
// Domain with an invalid policy (unknown version).
|
// Domain with an invalid policy (unknown version).
|
||||||
fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = `
|
|
||||||
{
|
|
||||||
"version": "STSv99",
|
|
||||||
"mode": "enforce",
|
|
||||||
"mx": ["*.mail.example.com"],
|
|
||||||
"max_age": 123456
|
|
||||||
}`
|
|
||||||
p, err = Fetch(context.Background(), "version99")
|
p, err = Fetch(context.Background(), "version99")
|
||||||
if err != ErrUnknownVersion {
|
if err != ErrUnknownVersion {
|
||||||
t.Errorf("expected error %v, got %v (and policy: %v)",
|
t.Errorf("expected error %v, got %v (and policy: %v)",
|
||||||
@@ -120,3 +137,197 @@ func TestFetch(t *testing.T) {
|
|||||||
}
|
}
|
||||||
t.Logf("version99: got expected error: %v", err)
|
t.Logf("version99: got expected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests for the policy cache.
|
||||||
|
|
||||||
|
func mustTempDir(t *testing.T) string {
|
||||||
|
dir, err := ioutil.TempDir("", "sts_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Chdir(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("test directory: %q", dir)
|
||||||
|
|
||||||
|
return dir
|
||||||
|
}
|
||||||
|
|
||||||
|
func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int64) {
|
||||||
|
value := v.Value()
|
||||||
|
if value != expected {
|
||||||
|
t.Errorf("%s is %d, expected %d", name, value, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheBasics(t *testing.T) {
|
||||||
|
dir := mustTempDir(t)
|
||||||
|
c, err := NewCache(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note the data "fetched" for each domain comes from fakeContent, defined
|
||||||
|
// in TestMain above. See httpGet for more details.
|
||||||
|
|
||||||
|
// Reset the expvar counters that we use to validate hits, misses, etc.
|
||||||
|
cacheFetches.Set(0)
|
||||||
|
cacheHits.Set(0)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Fetch domain.com, check we get a reasonable policy, and that it's a
|
||||||
|
// cache miss.
|
||||||
|
p, err := c.Fetch(ctx, "domain.com")
|
||||||
|
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
||||||
|
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
||||||
|
}
|
||||||
|
t.Logf("cache fetched domain.com: %v", p)
|
||||||
|
expvarMustEq(t, "cacheFetches", cacheFetches, 1)
|
||||||
|
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
||||||
|
|
||||||
|
// Fetch domain.com again, this time we should see a cache hit.
|
||||||
|
p, err = c.Fetch(ctx, "domain.com")
|
||||||
|
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
||||||
|
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
||||||
|
}
|
||||||
|
t.Logf("cache fetched domain.com: %v", p)
|
||||||
|
expvarMustEq(t, "cacheFetches", cacheFetches, 2)
|
||||||
|
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
||||||
|
|
||||||
|
// Simulate an expired cache entry by changing the mtime of domain.com's
|
||||||
|
// entry to the past.
|
||||||
|
expires := time.Now().Add(-1 * time.Minute)
|
||||||
|
os.Chtimes(c.domainPath("domain.com"), expires, expires)
|
||||||
|
|
||||||
|
// Do a third fetch, check that we don't get a cache hit.
|
||||||
|
p, err = c.Fetch(ctx, "domain.com")
|
||||||
|
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
||||||
|
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
||||||
|
}
|
||||||
|
t.Logf("cache fetched domain.com: %v", p)
|
||||||
|
expvarMustEq(t, "cacheFetches", cacheFetches, 3)
|
||||||
|
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
||||||
|
|
||||||
|
if !t.Failed() {
|
||||||
|
os.RemoveAll(dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test how the cache behaves when the files are corrupt.
|
||||||
|
func TestCacheBadData(t *testing.T) {
|
||||||
|
dir := mustTempDir(t)
|
||||||
|
c, err := NewCache(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
cases := []string{
|
||||||
|
// Case 1: A file with invalid json, which will fail unmarshalling.
|
||||||
|
"this is not valid json",
|
||||||
|
|
||||||
|
// Case 2: A file with a parseable but invalid policy.
|
||||||
|
`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, badContent := range cases {
|
||||||
|
// Reset the expvar counters that we use to validate hits, misses, etc.
|
||||||
|
cacheFetches.Set(0)
|
||||||
|
cacheHits.Set(0)
|
||||||
|
|
||||||
|
// Fetch domain.com, should result in the file being added to the
|
||||||
|
// cache.
|
||||||
|
p, err := c.Fetch(ctx, "domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Fetch failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("cache fetched domain.com: %v", p)
|
||||||
|
expvarMustEq(t, "cacheFetches", cacheFetches, 1)
|
||||||
|
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
||||||
|
|
||||||
|
// Edit the file, filling it with the bad content for this case.
|
||||||
|
fname := c.domainPath("domain.com")
|
||||||
|
err = ioutil.WriteFile(fname, []byte(badContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error writing file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We now expect Fetch to fall back to getting the policy from the
|
||||||
|
// network (in our case, from fakeContent).
|
||||||
|
p, err = c.Fetch(ctx, "domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Fetch failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("cache fetched domain.com: %v", p)
|
||||||
|
expvarMustEq(t, "cacheFetches", cacheFetches, 2)
|
||||||
|
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
||||||
|
|
||||||
|
// And now the file should be fine, resulting in a cache hit.
|
||||||
|
p, err = c.Fetch(ctx, "domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Fetch failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("cache fetched domain.com: %v", p)
|
||||||
|
expvarMustEq(t, "cacheFetches", cacheFetches, 3)
|
||||||
|
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
||||||
|
|
||||||
|
// Remove the file, to start with a clean slate for the next case.
|
||||||
|
os.Remove(fname)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !t.Failed() {
|
||||||
|
os.RemoveAll(dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Policy {
|
||||||
|
p, err := c.Fetch(ctx, d)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Fetch %q failed: %v", d, err)
|
||||||
|
}
|
||||||
|
t.Logf("Fetch %q: %v", d, p)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheRefresh(t *testing.T) {
|
||||||
|
dir := mustTempDir(t)
|
||||||
|
c, err := NewCache(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = `
|
||||||
|
{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}`
|
||||||
|
p := mustFetch(t, c, ctx, "refresh-test")
|
||||||
|
if p.MaxAge != 100*time.Second {
|
||||||
|
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the "published" policy, check that we see the old version at
|
||||||
|
// fetch (should be cached), and a new version after a refresh.
|
||||||
|
fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = `
|
||||||
|
{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}`
|
||||||
|
|
||||||
|
p = mustFetch(t, c, ctx, "refresh-test")
|
||||||
|
if p.MaxAge != 100*time.Second {
|
||||||
|
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.refresh(ctx)
|
||||||
|
|
||||||
|
p = mustFetch(t, c, ctx, "refresh-test")
|
||||||
|
if p.MaxAge != 200*time.Second {
|
||||||
|
t.Fatalf("policy.MaxAge is %v, expected 200s", p.MaxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !t.Failed() {
|
||||||
|
os.RemoveAll(dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user