1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-20 15:07:03 +00:00

sts: Add documentation and fix minor style issues

This patch adds some missing function documentation entries, and fixes
minor style issues caught by the linter.

No functional changes.
This commit is contained in:
Alberto Bertogli
2018-07-01 12:29:46 +01:00
parent 46bce576e8
commit 0ee7cb4cce
2 changed files with 19 additions and 11 deletions

View File

@@ -61,6 +61,8 @@ type Policy struct {
MaxAge time.Duration `json:"max_age"` MaxAge time.Duration `json:"max_age"`
} }
// The Mode of a policy. Valid values (according to the standard) are
// constants below.
type Mode string type Mode string
// Valid modes. // Valid modes.
@@ -93,8 +95,8 @@ func parsePolicy(raw []byte) (*Policy, error) {
p.Mode = Mode(value) p.Mode = Mode(value)
case "max_age": case "max_age":
// On error, p.MaxAge will be 0 which is invalid. // On error, p.MaxAge will be 0 which is invalid.
max_age, _ := strconv.Atoi(value) maxAge, _ := strconv.Atoi(value)
p.MaxAge = time.Duration(max_age) * time.Second p.MaxAge = time.Duration(maxAge) * time.Second
case "mx": case "mx":
p.MXs = append(p.MXs, value) p.MXs = append(p.MXs, value)
} }
@@ -106,14 +108,16 @@ func parsePolicy(raw []byte) (*Policy, error) {
return p, nil return p, nil
} }
var (
// Check errors. // Check errors.
var (
ErrUnknownVersion = errors.New("unknown policy version") ErrUnknownVersion = errors.New("unknown policy version")
ErrInvalidMaxAge = errors.New("invalid max_age") ErrInvalidMaxAge = errors.New("invalid max_age")
ErrInvalidMode = errors.New("invalid mode") ErrInvalidMode = errors.New("invalid mode")
ErrInvalidMX = errors.New("invalid mx") ErrInvalidMX = errors.New("invalid mx")
)
// Fetch errors. // Fetch errors.
var (
ErrInvalidMediaType = errors.New("invalid HTTP media type") ErrInvalidMediaType = errors.New("invalid HTTP media type")
) )
@@ -333,6 +337,8 @@ type PolicyCache struct {
sync.Mutex sync.Mutex
} }
// NewCache creates an instance of PolicyCache using the given directory as
// backing storage. The directory will be created if it does not exist.
func NewCache(dir string) (*PolicyCache, error) { func NewCache(dir string) (*PolicyCache, error) {
c := &PolicyCache{ c := &PolicyCache{
dir: dir, dir: dir,
@@ -352,7 +358,7 @@ func (c *PolicyCache) domainPath(domain string) string {
return c.dir + "/" + pathPrefix + domain return c.dir + "/" + pathPrefix + domain
} }
var ErrExpired = errors.New("cache entry expired") var errExpired = errors.New("cache entry expired")
func (c *PolicyCache) load(domain string) (*Policy, error) { func (c *PolicyCache) load(domain string) (*Policy, error) {
fname := c.domainPath(domain) fname := c.domainPath(domain)
@@ -363,7 +369,7 @@ func (c *PolicyCache) load(domain string) (*Policy, error) {
} }
if time.Since(fi.ModTime()) > 0 { if time.Since(fi.ModTime()) > 0 {
cacheExpired.Add(1) cacheExpired.Add(1)
return nil, ErrExpired return nil, errExpired
} }
data, err := ioutil.ReadFile(fname) data, err := ioutil.ReadFile(fname)
@@ -414,6 +420,7 @@ func (c *PolicyCache) store(domain string, p *Policy) error {
return err return err
} }
// Fetch a policy for the given domain, using the cache.
func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) { func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) {
cacheFetches.Add(1) cacheFetches.Add(1)
tr := trace.New("STSCache.Fetch", domain) tr := trace.New("STSCache.Fetch", domain)
@@ -450,6 +457,7 @@ func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error)
return p, nil return p, nil
} }
// PeriodicallyRefresh the cache, by re-fetching all entries.
func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) { func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
for ctx.Err() == nil { for ctx.Err() == nil {
c.refresh(ctx) c.refresh(ctx)

View File

@@ -27,9 +27,9 @@ var txtResults = map[string][]string{
"_mta-sts.policy404": {"v=STSv1; id=blah;"}, "_mta-sts.policy404": {"v=STSv1; id=blah;"},
"_mta-sts.version99": {"v=STSv1; id=blah;"}, "_mta-sts.version99": {"v=STSv1; id=blah;"},
} }
var testError = fmt.Errorf("error for testing purposes") var errTest = fmt.Errorf("error for testing purposes")
var txtErrors = map[string]error{ var txtErrors = map[string]error{
"_mta-sts.domErr": testError, "_mta-sts.domErr": errTest,
} }
func testLookupTXT(domain string) ([]string, error) { func testLookupTXT(domain string) ([]string, error) {
@@ -217,9 +217,9 @@ func TestFetch(t *testing.T) {
// Error fetching TXT record for this domain. // Error fetching TXT record for this domain.
p, err = Fetch(context.Background(), "domErr") p, err = Fetch(context.Background(), "domErr")
if err != testError { if err != errTest {
t.Errorf("expected error %v, got %v (and policy: %v)", t.Errorf("expected error %v, got %v (and policy: %v)",
testError, err, p) errTest, err, p)
} }
t.Logf("domErr: got expected error: %v", err) t.Logf("domErr: got expected error: %v", err)
} }
@@ -508,7 +508,7 @@ func TestHasSTSRecord(t *testing.T) {
{"dom2", false, nil}, {"dom2", false, nil},
{"dom3", false, nil}, {"dom3", false, nil},
{"dom4", true, nil}, {"dom4", true, nil},
{"domErr", false, testError}, {"domErr", false, errTest},
} }
for _, c := range cases { for _, c := range cases {
ok, err := hasSTSRecord(c.domain) ok, err := hasSTSRecord(c.domain)