mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-22 15:27:02 +00:00
sts: Add an on-disk cache implementation
This patch adds an on-disk cache for STS policies. Policies are cached by domain, and stored on files in a single directory. The files will have as mtime the time when the policy expires, this makes the store simpler, as it can avoid keeping additional metadata. There is no in-memory caching. This may be added in the future, but for now disk is good enough for our purposes.
This commit is contained in:
@@ -2,10 +2,38 @@ package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"expvar"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Populate the fake policy contents, used by a few tests.
|
||||
// httpGet will use this data instead of using the network.
|
||||
|
||||
// domain.com -> valid, with reasonable policy.
|
||||
fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = `
|
||||
{
|
||||
"version": "STSv1",
|
||||
"mode": "enforce",
|
||||
"mx": ["*.mail.domain.com"],
|
||||
"max_age": 3600
|
||||
}`
|
||||
|
||||
// version99 -> invalid policy (unknown version).
|
||||
fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = `
|
||||
{
|
||||
"version": "STSv99",
|
||||
"mode": "enforce",
|
||||
"mx": ["*.mail.version99"],
|
||||
"max_age": 999
|
||||
}`
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestParsePolicy(t *testing.T) {
|
||||
const pol1 = `{
|
||||
"version": "STSv1",
|
||||
@@ -84,14 +112,10 @@ func TestMatchDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFetch(t *testing.T) {
|
||||
// Note the data "fetched" for each domain comes from fakeContent, defined
|
||||
// in TestMain above. See httpGet for more details.
|
||||
|
||||
// Normal fetch, all valid.
|
||||
fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = `
|
||||
{
|
||||
"version": "STSv1",
|
||||
"mode": "enforce",
|
||||
"mx": ["*.mail.example.com"],
|
||||
"max_age": 123456
|
||||
}`
|
||||
p, err := Fetch(context.Background(), "domain.com")
|
||||
if err != nil {
|
||||
t.Errorf("failed to fetch policy: %v", err)
|
||||
@@ -106,13 +130,6 @@ func TestFetch(t *testing.T) {
|
||||
t.Logf("unknown: got error as expected: %v", err)
|
||||
|
||||
// Domain with an invalid policy (unknown version).
|
||||
fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = `
|
||||
{
|
||||
"version": "STSv99",
|
||||
"mode": "enforce",
|
||||
"mx": ["*.mail.example.com"],
|
||||
"max_age": 123456
|
||||
}`
|
||||
p, err = Fetch(context.Background(), "version99")
|
||||
if err != ErrUnknownVersion {
|
||||
t.Errorf("expected error %v, got %v (and policy: %v)",
|
||||
@@ -120,3 +137,197 @@ func TestFetch(t *testing.T) {
|
||||
}
|
||||
t.Logf("version99: got expected error: %v", err)
|
||||
}
|
||||
|
||||
// Tests for the policy cache.
|
||||
|
||||
func mustTempDir(t *testing.T) string {
|
||||
dir, err := ioutil.TempDir("", "sts_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = os.Chdir(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("test directory: %q", dir)
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int64) {
|
||||
value := v.Value()
|
||||
if value != expected {
|
||||
t.Errorf("%s is %d, expected %d", name, value, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheBasics(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
c, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Note the data "fetched" for each domain comes from fakeContent, defined
|
||||
// in TestMain above. See httpGet for more details.
|
||||
|
||||
// Reset the expvar counters that we use to validate hits, misses, etc.
|
||||
cacheFetches.Set(0)
|
||||
cacheHits.Set(0)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fetch domain.com, check we get a reasonable policy, and that it's a
|
||||
// cache miss.
|
||||
p, err := c.Fetch(ctx, "domain.com")
|
||||
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
||||
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
||||
}
|
||||
t.Logf("cache fetched domain.com: %v", p)
|
||||
expvarMustEq(t, "cacheFetches", cacheFetches, 1)
|
||||
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
||||
|
||||
// Fetch domain.com again, this time we should see a cache hit.
|
||||
p, err = c.Fetch(ctx, "domain.com")
|
||||
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
||||
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
||||
}
|
||||
t.Logf("cache fetched domain.com: %v", p)
|
||||
expvarMustEq(t, "cacheFetches", cacheFetches, 2)
|
||||
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
||||
|
||||
// Simulate an expired cache entry by changing the mtime of domain.com's
|
||||
// entry to the past.
|
||||
expires := time.Now().Add(-1 * time.Minute)
|
||||
os.Chtimes(c.domainPath("domain.com"), expires, expires)
|
||||
|
||||
// Do a third fetch, check that we don't get a cache hit.
|
||||
p, err = c.Fetch(ctx, "domain.com")
|
||||
if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
|
||||
t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
|
||||
}
|
||||
t.Logf("cache fetched domain.com: %v", p)
|
||||
expvarMustEq(t, "cacheFetches", cacheFetches, 3)
|
||||
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
||||
|
||||
if !t.Failed() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
}
|
||||
|
||||
// Test how the cache behaves when the files are corrupt.
|
||||
func TestCacheBadData(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
c, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
cases := []string{
|
||||
// Case 1: A file with invalid json, which will fail unmarshalling.
|
||||
"this is not valid json",
|
||||
|
||||
// Case 2: A file with a parseable but invalid policy.
|
||||
`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`,
|
||||
}
|
||||
|
||||
for _, badContent := range cases {
|
||||
// Reset the expvar counters that we use to validate hits, misses, etc.
|
||||
cacheFetches.Set(0)
|
||||
cacheHits.Set(0)
|
||||
|
||||
// Fetch domain.com, should result in the file being added to the
|
||||
// cache.
|
||||
p, err := c.Fetch(ctx, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Fetch failed: %v", err)
|
||||
}
|
||||
t.Logf("cache fetched domain.com: %v", p)
|
||||
expvarMustEq(t, "cacheFetches", cacheFetches, 1)
|
||||
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
||||
|
||||
// Edit the file, filling it with the bad content for this case.
|
||||
fname := c.domainPath("domain.com")
|
||||
err = ioutil.WriteFile(fname, []byte(badContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("error writing file: %v", err)
|
||||
}
|
||||
|
||||
// We now expect Fetch to fall back to getting the policy from the
|
||||
// network (in our case, from fakeContent).
|
||||
p, err = c.Fetch(ctx, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Fetch failed: %v", err)
|
||||
}
|
||||
t.Logf("cache fetched domain.com: %v", p)
|
||||
expvarMustEq(t, "cacheFetches", cacheFetches, 2)
|
||||
expvarMustEq(t, "cacheHits", cacheHits, 0)
|
||||
|
||||
// And now the file should be fine, resulting in a cache hit.
|
||||
p, err = c.Fetch(ctx, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Fetch failed: %v", err)
|
||||
}
|
||||
t.Logf("cache fetched domain.com: %v", p)
|
||||
expvarMustEq(t, "cacheFetches", cacheFetches, 3)
|
||||
expvarMustEq(t, "cacheHits", cacheHits, 1)
|
||||
|
||||
// Remove the file, to start with a clean slate for the next case.
|
||||
os.Remove(fname)
|
||||
}
|
||||
|
||||
if !t.Failed() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
}
|
||||
|
||||
func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Policy {
|
||||
p, err := c.Fetch(ctx, d)
|
||||
if err != nil {
|
||||
t.Fatalf("Fetch %q failed: %v", d, err)
|
||||
}
|
||||
t.Logf("Fetch %q: %v", d, p)
|
||||
return p
|
||||
}
|
||||
|
||||
func TestCacheRefresh(t *testing.T) {
|
||||
dir := mustTempDir(t)
|
||||
c, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = `
|
||||
{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}`
|
||||
p := mustFetch(t, c, ctx, "refresh-test")
|
||||
if p.MaxAge != 100*time.Second {
|
||||
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
|
||||
}
|
||||
|
||||
// Change the "published" policy, check that we see the old version at
|
||||
// fetch (should be cached), and a new version after a refresh.
|
||||
fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = `
|
||||
{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}`
|
||||
|
||||
p = mustFetch(t, c, ctx, "refresh-test")
|
||||
if p.MaxAge != 100*time.Second {
|
||||
t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
|
||||
}
|
||||
|
||||
c.refresh(ctx)
|
||||
|
||||
p = mustFetch(t, c, ctx, "refresh-test")
|
||||
if p.MaxAge != 200*time.Second {
|
||||
t.Fatalf("policy.MaxAge is %v, expected 200s", p.MaxAge)
|
||||
}
|
||||
|
||||
if !t.Failed() {
|
||||
os.RemoveAll(dir)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user