mirror of
https://blitiri.com.ar/repos/chasquid
synced 2026-01-27 20:45:56 +00:00
sts: Add an on-disk cache implementation
This patch adds an on-disk cache for STS policies. Policies are cached by domain, and stored on files in a single directory. The files will have as mtime the time when the policy expires, this makes the store simpler, as it can avoid keeping additional metadata. There is no in-memory caching. This may be added in the future, but for now disk is good enough for our purposes.
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
// This is an EXPERIMENTAL implementation for now.
|
||||
//
|
||||
// It lacks (at least) the following:
|
||||
// - Caching.
|
||||
// - DNS TXT checking.
|
||||
// - Facilities for reporting.
|
||||
//
|
||||
@@ -14,15 +13,40 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/safeio"
|
||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||
|
||||
"golang.org/x/net/context/ctxhttp"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// Exported variables.
|
||||
var (
|
||||
cacheFetches = expvar.NewInt("chasquid/sts/cache/fetches")
|
||||
cacheHits = expvar.NewInt("chasquid/sts/cache/hits")
|
||||
cacheExpired = expvar.NewInt("chasquid/sts/cache/expired")
|
||||
|
||||
cacheIOErrors = expvar.NewInt("chasquid/sts/cache/ioErrors")
|
||||
cacheFailedFetch = expvar.NewInt("chasquid/sts/cache/failedFetch")
|
||||
cacheInvalid = expvar.NewInt("chasquid/sts/cache/invalid")
|
||||
|
||||
cacheMarshalErrors = expvar.NewInt("chasquid/sts/cache/marshalErrors")
|
||||
cacheUnmarshalErrors = expvar.NewInt("chasquid/sts/cache/unmarshalErrors")
|
||||
|
||||
cacheRefreshCycles = expvar.NewInt("chasquid/sts/cache/refreshCycles")
|
||||
cacheRefreshes = expvar.NewInt("chasquid/sts/cache/refreshes")
|
||||
cacheRefreshErrors = expvar.NewInt("chasquid/sts/cache/refreshErrors")
|
||||
)
|
||||
|
||||
// Policy represents a parsed policy.
|
||||
// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
|
||||
type Policy struct {
|
||||
@@ -169,9 +193,12 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
}
|
||||
return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode)
|
||||
}
|
||||
|
||||
var errRejectRedirect = errors.New("redirects not allowed in MTA-STS")
|
||||
@@ -222,3 +249,186 @@ func domainToASCII(domain string) (string, error) {
|
||||
domain = strings.ToLower(domain)
|
||||
return idna.ToASCII(domain)
|
||||
}
|
||||
|
||||
// PolicyCache is a caching layer for fetching policies.
|
||||
//
|
||||
// Policies are cached by domain, and stored in a single directory.
|
||||
// The files will have as mtime the time when the policy expires, this makes
|
||||
// the store simpler, as it can avoid keeping additional metadata.
|
||||
//
|
||||
// There is no in-memory caching. This may be added in the future, but for
|
||||
// now disk is good enough for our purposes.
|
||||
type PolicyCache struct {
|
||||
dir string
|
||||
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewCache(dir string) (*PolicyCache, error) {
|
||||
c := &PolicyCache{
|
||||
dir: dir,
|
||||
}
|
||||
err := os.MkdirAll(dir, 0770)
|
||||
return c, err
|
||||
}
|
||||
|
||||
const pathPrefix = "pol:"
|
||||
|
||||
func (c *PolicyCache) domainPath(domain string) string {
|
||||
// We assume the domain is well formed, sanity check just in case.
|
||||
if strings.Contains(domain, "/") {
|
||||
panic("domain contains slash")
|
||||
}
|
||||
|
||||
return c.dir + "/" + pathPrefix + domain
|
||||
}
|
||||
|
||||
var ErrExpired = errors.New("cache entry expired")
|
||||
|
||||
func (c *PolicyCache) load(domain string) (*Policy, error) {
|
||||
fname := c.domainPath(domain)
|
||||
|
||||
fi, err := os.Stat(fname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if time.Since(fi.ModTime()) > 0 {
|
||||
cacheExpired.Add(1)
|
||||
return nil, ErrExpired
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(fname)
|
||||
if err != nil {
|
||||
cacheIOErrors.Add(1)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &Policy{}
|
||||
err = json.Unmarshal(data, p)
|
||||
if err != nil {
|
||||
cacheUnmarshalErrors.Add(1)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The policy should always be valid, as we marshalled it ourselves;
|
||||
// however, check it just to be safe.
|
||||
if err := p.Check(); err != nil {
|
||||
cacheInvalid.Add(1)
|
||||
return nil, fmt.Errorf(
|
||||
"%s unmarshalled invalid policy %v: %v", domain, p, err)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (c *PolicyCache) store(domain string, p *Policy) error {
|
||||
data, err := json.Marshal(p)
|
||||
if err != nil {
|
||||
cacheMarshalErrors.Add(1)
|
||||
return fmt.Errorf("%s failed to marshal policy %v, error: %v",
|
||||
domain, p, err)
|
||||
}
|
||||
|
||||
// Change the modification time to the future, when the policy expires.
|
||||
// load will check for this to detect expired cache entries, see above for
|
||||
// the details.
|
||||
expires := time.Now().Add(p.MaxAge)
|
||||
chTime := func(fname string) error {
|
||||
return os.Chtimes(fname, expires, expires)
|
||||
}
|
||||
|
||||
fname := c.domainPath(domain)
|
||||
err = safeio.WriteFile(fname, data, 0640, chTime)
|
||||
if err != nil {
|
||||
cacheIOErrors.Add(1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) {
|
||||
cacheFetches.Add(1)
|
||||
tr := trace.New("STSCache.Fetch", domain)
|
||||
defer tr.Finish()
|
||||
|
||||
p, err := c.load(domain)
|
||||
if err == nil {
|
||||
tr.Debugf("cache hit: %v", p)
|
||||
cacheHits.Add(1)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
p, err = Fetch(ctx, domain)
|
||||
if err != nil {
|
||||
tr.Debugf("failed to fetch: %v", err)
|
||||
cacheFailedFetch.Add(1)
|
||||
return nil, err
|
||||
}
|
||||
tr.Debugf("fetched: %v", p)
|
||||
|
||||
// We could do this asynchronously, as we got the policy to give to the
|
||||
// caller. However, to make troubleshooting easier and the cost of storing
|
||||
// entries easier to track down, we store synchronously.
|
||||
// Note that even if the store returns an error, we pass on the policy: at
|
||||
// this point we rather use the policy even if we couldn't store it in the
|
||||
// cache.
|
||||
err = c.store(domain, p)
|
||||
if err != nil {
|
||||
tr.Errorf("failed to store: %v", err)
|
||||
} else {
|
||||
tr.Debugf("stored")
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
|
||||
for ctx.Err() == nil {
|
||||
cacheRefreshCycles.Add(1)
|
||||
|
||||
c.refresh(ctx)
|
||||
|
||||
// Wait 10 minutes between passes; this is a background refresh and
|
||||
// there's no need to poke the servers very often.
|
||||
time.Sleep(10 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PolicyCache) refresh(ctx context.Context) {
|
||||
tr := trace.New("STSCache.Refresh", c.dir)
|
||||
defer tr.Finish()
|
||||
|
||||
entries, err := ioutil.ReadDir(c.dir)
|
||||
if err != nil {
|
||||
tr.Errorf("failed to list directory %q: %v", c.dir, err)
|
||||
return
|
||||
}
|
||||
tr.Debugf("%d entries", len(entries))
|
||||
|
||||
for _, e := range entries {
|
||||
if !strings.HasPrefix(e.Name(), pathPrefix) {
|
||||
continue
|
||||
}
|
||||
domain := e.Name()[len(pathPrefix):]
|
||||
cacheRefreshes.Add(1)
|
||||
tr.Debugf("%v: refreshing", domain)
|
||||
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
p, err := Fetch(fetchCtx, domain)
|
||||
cancel()
|
||||
if err != nil {
|
||||
tr.Debugf("%v: failed to fetch: %v", domain, err)
|
||||
cacheRefreshErrors.Add(1)
|
||||
continue
|
||||
}
|
||||
tr.Debugf("%v: fetched", domain)
|
||||
|
||||
err = c.store(domain, p)
|
||||
if err != nil {
|
||||
tr.Errorf("%v: failed to store: %v", domain, err)
|
||||
} else {
|
||||
tr.Debugf("%v: stored", domain)
|
||||
}
|
||||
}
|
||||
|
||||
tr.Debugf("refresh done")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user