// Package sts implements the MTA-STS (Strict Transport Security), based on // the current draft, https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02. // // This is an EXPERIMENTAL implementation for now. // // It lacks (at least) the following: // - DNS TXT checking. // - Facilities for reporting. // package sts import ( "context" "encoding/json" "errors" "expvar" "fmt" "io" "io/ioutil" "net/http" "os" "strings" "sync" "time" "blitiri.com.ar/go/chasquid/internal/safeio" "blitiri.com.ar/go/chasquid/internal/trace" "golang.org/x/net/context/ctxhttp" "golang.org/x/net/idna" ) // Exported variables. var ( cacheFetches = expvar.NewInt("chasquid/sts/cache/fetches") cacheHits = expvar.NewInt("chasquid/sts/cache/hits") cacheExpired = expvar.NewInt("chasquid/sts/cache/expired") cacheIOErrors = expvar.NewInt("chasquid/sts/cache/ioErrors") cacheFailedFetch = expvar.NewInt("chasquid/sts/cache/failedFetch") cacheInvalid = expvar.NewInt("chasquid/sts/cache/invalid") cacheMarshalErrors = expvar.NewInt("chasquid/sts/cache/marshalErrors") cacheUnmarshalErrors = expvar.NewInt("chasquid/sts/cache/unmarshalErrors") cacheRefreshCycles = expvar.NewInt("chasquid/sts/cache/refreshCycles") cacheRefreshes = expvar.NewInt("chasquid/sts/cache/refreshes") cacheRefreshErrors = expvar.NewInt("chasquid/sts/cache/refreshErrors") ) // Policy represents a parsed policy. // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 type Policy struct { Version string `json:"version"` Mode Mode `json:"mode"` MXs []string `json:"mx"` MaxAge time.Duration `json:"max_age"` } type Mode string // Valid modes. const ( Enforce = Mode("enforce") Report = Mode("report") ) // parsePolicy parses a JSON representation of the policy, and returns the // corresponding Policy structure. func parsePolicy(raw []byte) (*Policy, error) { p := &Policy{} if err := json.Unmarshal(raw, p); err != nil { return nil, err } // MaxAge is in seconds. p.MaxAge = p.MaxAge * time.Second return p, nil } var ( ErrUnknownVersion = errors.New("unknown policy version") ErrInvalidMaxAge = errors.New("invalid max_age") ErrInvalidMode = errors.New("invalid mode") ErrInvalidMX = errors.New("invalid mx") ) // Check that the policy contents are valid. func (p *Policy) Check() error { if p.Version != "STSv1" { return ErrUnknownVersion } if p.MaxAge <= 0 { return ErrInvalidMaxAge } if p.Mode != Enforce && p.Mode != Report { return ErrInvalidMode } // "mx" field is required, and the policy is invalid if it's not present. // https://mailarchive.ietf.org/arch/msg/uta/Omqo1Bw6rJbrTMl2Zo69IJr35Qo if len(p.MXs) == 0 { return ErrInvalidMX } return nil } // MXMatches checks if the given MX is allowed, according to the policy. // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1 func (p *Policy) MXIsAllowed(mx string) bool { for _, pattern := range p.MXs { if matchDomain(mx, pattern) { return true } } return false } // UncheckedFetch fetches and parses the policy, but does NOT check it. // This can be useful for debugging and troubleshooting, but you should always // call Check on the policy before using it. func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) { // Convert the domain to ascii form, as httpGet does not support IDNs in // any other way. domain, err := idna.ToASCII(domain) if err != nil { return nil, err } url := urlForDomain(domain) rawPolicy, err := httpGet(ctx, url) if err != nil { return nil, err } return parsePolicy(rawPolicy) } // Fake URL for testing purposes, so we can do more end-to-end tests, // including the HTTP fetching code. var fakeURLForTesting string func urlForDomain(domain string) string { if fakeURLForTesting != "" { return fakeURLForTesting + "/" + domain } // URL composed from the domain, as explained in: // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3 // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 return "https://mta-sts." + domain + "/.well-known/mta-sts.json" } // Fetch a policy for the given domain. Note this results in various network // lookups and HTTPS GETs, so it can be slow. // The returned policy is parsed and sanity-checked (using Policy.Check), so // it should be safe to use. func Fetch(ctx context.Context, domain string) (*Policy, error) { p, err := UncheckedFetch(ctx, domain) if err != nil { return nil, err } err = p.Check() if err != nil { return nil, err } return p, nil } // httpGet performs an HTTP GET of the given URL, using the context and // rejecting redirects, as per the standard. func httpGet(ctx context.Context, url string) ([]byte, error) { client := &http.Client{ // We MUST NOT follow redirects, see // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3 CheckRedirect: rejectRedirect, } // Note that http does not care for the context deadline, so we need to // construct it here. if deadline, ok := ctx.Deadline(); ok { client.Timeout = deadline.Sub(time.Now()) } resp, err := ctxhttp.Get(ctx, client, url) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { // Read but up to 10k; policies should be way smaller than that, and // having a limit prevents abuse/accidents with very large replies. return ioutil.ReadAll(&io.LimitedReader{resp.Body, 10 * 1024}) } return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode) } var errRejectRedirect = errors.New("redirects not allowed in MTA-STS") func rejectRedirect(req *http.Request, via []*http.Request) error { return errRejectRedirect } // matchDomain checks if the domain matches the given pattern, according to // https://tools.ietf.org/html/rfc6125#section-6.4 // (from https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1). func matchDomain(domain, pattern string) bool { domain, dErr := domainToASCII(domain) pattern, pErr := domainToASCII(pattern) if dErr != nil || pErr != nil { // Domains should already have been checked and normalized by the // caller, exposing this is not worth the API complexity in this case. return false } domainLabels := strings.Split(domain, ".") patternLabels := strings.Split(pattern, ".") if len(domainLabels) != len(patternLabels) { return false } for i, p := range patternLabels { // Wildcards only apply to the first part, see // https://tools.ietf.org/html/rfc6125#section-6.4.3 #1 and #2. // This also allows us to do the lenght comparison above. if p == "*" && i == 0 { continue } if p != domainLabels[i] { return false } } return true } // domainToASCII converts the domain to ASCII form, similar to idna.ToASCII // but with some preprocessing convenient for our use cases. func domainToASCII(domain string) (string, error) { domain = strings.TrimSuffix(domain, ".") domain = strings.ToLower(domain) return idna.ToASCII(domain) } // PolicyCache is a caching layer for fetching policies. // // Policies are cached by domain, and stored in a single directory. // The files will have as mtime the time when the policy expires, this makes // the store simpler, as it can avoid keeping additional metadata. // // There is no in-memory caching. This may be added in the future, but for // now disk is good enough for our purposes. type PolicyCache struct { dir string sync.Mutex } func NewCache(dir string) (*PolicyCache, error) { c := &PolicyCache{ dir: dir, } err := os.MkdirAll(dir, 0770) return c, err } const pathPrefix = "pol:" func (c *PolicyCache) domainPath(domain string) string { // We assume the domain is well formed, sanity check just in case. if strings.Contains(domain, "/") { panic("domain contains slash") } return c.dir + "/" + pathPrefix + domain } var ErrExpired = errors.New("cache entry expired") func (c *PolicyCache) load(domain string) (*Policy, error) { fname := c.domainPath(domain) fi, err := os.Stat(fname) if err != nil { return nil, err } if time.Since(fi.ModTime()) > 0 { cacheExpired.Add(1) return nil, ErrExpired } data, err := ioutil.ReadFile(fname) if err != nil { cacheIOErrors.Add(1) return nil, err } p := &Policy{} err = json.Unmarshal(data, p) if err != nil { cacheUnmarshalErrors.Add(1) return nil, err } // The policy should always be valid, as we marshalled it ourselves; // however, check it just to be safe. if err := p.Check(); err != nil { cacheInvalid.Add(1) return nil, fmt.Errorf( "%s unmarshalled invalid policy %v: %v", domain, p, err) } return p, nil } func (c *PolicyCache) store(domain string, p *Policy) error { data, err := json.Marshal(p) if err != nil { cacheMarshalErrors.Add(1) return fmt.Errorf("%s failed to marshal policy %v, error: %v", domain, p, err) } // Change the modification time to the future, when the policy expires. // load will check for this to detect expired cache entries, see above for // the details. expires := time.Now().Add(p.MaxAge) chTime := func(fname string) error { return os.Chtimes(fname, expires, expires) } fname := c.domainPath(domain) err = safeio.WriteFile(fname, data, 0640, chTime) if err != nil { cacheIOErrors.Add(1) } return err } func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) { cacheFetches.Add(1) tr := trace.New("STSCache.Fetch", domain) defer tr.Finish() p, err := c.load(domain) if err == nil { tr.Debugf("cache hit: %v", p) cacheHits.Add(1) return p, nil } p, err = Fetch(ctx, domain) if err != nil { tr.Debugf("failed to fetch: %v", err) cacheFailedFetch.Add(1) return nil, err } tr.Debugf("fetched: %v", p) // We could do this asynchronously, as we got the policy to give to the // caller. However, to make troubleshooting easier and the cost of storing // entries easier to track down, we store synchronously. // Note that even if the store returns an error, we pass on the policy: at // this point we rather use the policy even if we couldn't store it in the // cache. err = c.store(domain, p) if err != nil { tr.Errorf("failed to store: %v", err) } else { tr.Debugf("stored") } return p, nil } func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) { for ctx.Err() == nil { cacheRefreshCycles.Add(1) c.refresh(ctx) // Wait 10 minutes between passes; this is a background refresh and // there's no need to poke the servers very often. time.Sleep(10 * time.Minute) } } func (c *PolicyCache) refresh(ctx context.Context) { tr := trace.New("STSCache.Refresh", c.dir) defer tr.Finish() entries, err := ioutil.ReadDir(c.dir) if err != nil { tr.Errorf("failed to list directory %q: %v", c.dir, err) return } tr.Debugf("%d entries", len(entries)) for _, e := range entries { if !strings.HasPrefix(e.Name(), pathPrefix) { continue } domain := e.Name()[len(pathPrefix):] cacheRefreshes.Add(1) tr.Debugf("%v: refreshing", domain) fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second) p, err := Fetch(fetchCtx, domain) cancel() if err != nil { tr.Debugf("%v: failed to fetch: %v", domain, err) cacheRefreshErrors.Add(1) continue } tr.Debugf("%v: fetched", domain) err = c.store(domain, p) if err != nil { tr.Errorf("%v: failed to store: %v", domain, err) } else { tr.Debugf("%v: stored", domain) } } tr.Debugf("refresh done") }