mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
sts: DNS TXT record support
This patch adds support for checking the MTA-STS TXT record before fetching the policy via https. The content of the record is unused.
This commit is contained in:
@@ -3,9 +3,7 @@
|
|||||||
//
|
//
|
||||||
// This is an EXPERIMENTAL implementation for now.
|
// This is an EXPERIMENTAL implementation for now.
|
||||||
//
|
//
|
||||||
// It lacks (at least) the following:
|
// Note that "report" mode is not supported.
|
||||||
// - DNS TXT checking.
|
|
||||||
// - Facilities for reporting.
|
|
||||||
//
|
//
|
||||||
package sts
|
package sts
|
||||||
|
|
||||||
@@ -20,6 +18,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"mime"
|
"mime"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -167,6 +166,14 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ok, err := hasSTSRecord(domain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("MTA-STS TXT record missing")
|
||||||
|
}
|
||||||
|
|
||||||
url := urlForDomain(domain)
|
url := urlForDomain(domain)
|
||||||
rawPolicy, err := httpGet(ctx, url)
|
rawPolicy, err := httpGet(ctx, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -295,6 +302,29 @@ func domainToASCII(domain string) (string, error) {
|
|||||||
return idna.ToASCII(domain)
|
return idna.ToASCII(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function that we override for testing purposes.
|
||||||
|
// In the future we will override net.DefaultResolver, but we don't do that
|
||||||
|
// yet for backwards compatibility.
|
||||||
|
var lookupTXT = net.LookupTXT
|
||||||
|
|
||||||
|
// hasSTSRecord checks if there is a valid MTA-STS TXT record for the domain.
|
||||||
|
// We don't do full parsing and don't care about the "id=" field, as it is
|
||||||
|
// unused in this implementation.
|
||||||
|
func hasSTSRecord(domain string) (bool, error) {
|
||||||
|
txts, err := lookupTXT("_mta-sts." + domain)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, txt := range txts {
|
||||||
|
if strings.HasPrefix(txt, "v=STSv1;") {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// PolicyCache is a caching layer for fetching policies.
|
// PolicyCache is a caching layer for fetching policies.
|
||||||
//
|
//
|
||||||
// Policies are cached by domain, and stored in a single directory.
|
// Policies are cached by domain, and stored in a single directory.
|
||||||
|
|||||||
@@ -13,6 +13,27 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Override the lookup function to control its results.
|
||||||
|
var txtResults = map[string][]string{
|
||||||
|
"dom1": nil,
|
||||||
|
"dom2": {},
|
||||||
|
"dom3": {"abc", "def"},
|
||||||
|
"dom4": {"abc", "v=STSv1; id=blah;"},
|
||||||
|
|
||||||
|
// Matching policyForDomain below.
|
||||||
|
"_mta-sts.domain.com": {"v=STSv1; id=blah;"},
|
||||||
|
"_mta-sts.policy404": {"v=STSv1; id=blah;"},
|
||||||
|
"_mta-sts.version99": {"v=STSv1; id=blah;"},
|
||||||
|
}
|
||||||
|
var testError = fmt.Errorf("error for testing purposes")
|
||||||
|
var txtErrors = map[string]error{
|
||||||
|
"_mta-sts.domErr": testError,
|
||||||
|
}
|
||||||
|
|
||||||
|
func testLookupTXT(domain string) ([]string, error) {
|
||||||
|
return txtResults[domain], txtErrors[domain]
|
||||||
|
}
|
||||||
|
|
||||||
// Test policy for each of the requested domains. Will be served by the test
|
// Test policy for each of the requested domains. Will be served by the test
|
||||||
// HTTP server.
|
// HTTP server.
|
||||||
var policyForDomain = map[string]string{
|
var policyForDomain = map[string]string{
|
||||||
@@ -45,6 +66,8 @@ func testHTTPHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
lookupTXT = testLookupTXT
|
||||||
|
|
||||||
// Create a test HTTP server, used by the more end-to-end tests.
|
// Create a test HTTP server, used by the more end-to-end tests.
|
||||||
httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler))
|
httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler))
|
||||||
|
|
||||||
@@ -148,11 +171,11 @@ func TestFetch(t *testing.T) {
|
|||||||
t.Logf("domain.com: %+v", p)
|
t.Logf("domain.com: %+v", p)
|
||||||
|
|
||||||
// Domain without a policy (HTTP get fails).
|
// Domain without a policy (HTTP get fails).
|
||||||
p, err = Fetch(context.Background(), "unknown")
|
p, err = Fetch(context.Background(), "policy404")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("fetched unknown policy: %v", p)
|
t.Errorf("fetched unknown policy: %v", p)
|
||||||
}
|
}
|
||||||
t.Logf("unknown: got error as expected: %v", err)
|
t.Logf("policy404: got error as expected: %v", err)
|
||||||
|
|
||||||
// Domain with an invalid policy (unknown version).
|
// Domain with an invalid policy (unknown version).
|
||||||
p, err = Fetch(context.Background(), "version99")
|
p, err = Fetch(context.Background(), "version99")
|
||||||
@@ -161,6 +184,14 @@ func TestFetch(t *testing.T) {
|
|||||||
ErrUnknownVersion, err, p)
|
ErrUnknownVersion, err, p)
|
||||||
}
|
}
|
||||||
t.Logf("version99: got expected error: %v", err)
|
t.Logf("version99: got expected error: %v", err)
|
||||||
|
|
||||||
|
// Error fetching TXT record for this domain.
|
||||||
|
p, err = Fetch(context.Background(), "domErr")
|
||||||
|
if err != testError {
|
||||||
|
t.Errorf("expected error %v, got %v (and policy: %v)",
|
||||||
|
testError, err, p)
|
||||||
|
}
|
||||||
|
t.Logf("domErr: got expected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicyTooBig(t *testing.T) {
|
func TestPolicyTooBig(t *testing.T) {
|
||||||
@@ -345,6 +376,7 @@ func TestCacheRefresh(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
txtResults["_mta-sts.refresh-test"] = []string{"v=STSv1; id=blah;"}
|
||||||
policyForDomain["refresh-test"] = `
|
policyForDomain["refresh-test"] = `
|
||||||
version: STSv1
|
version: STSv1
|
||||||
mode: enforce
|
mode: enforce
|
||||||
@@ -393,3 +425,30 @@ func TestURLForDomain(t *testing.T) {
|
|||||||
t.Errorf("got %q, expected %q", got, expected)
|
t.Errorf("got %q, expected %q", got, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasSTSRecord(t *testing.T) {
|
||||||
|
txtResults["_mta-sts.dom1"] = nil
|
||||||
|
txtResults["_mta-sts.dom2"] = []string{}
|
||||||
|
txtResults["_mta-sts.dom3"] = []string{"abc", "def"}
|
||||||
|
txtResults["_mta-sts.dom4"] = []string{"abc", "v=STSv1; id=blah;"}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
domain string
|
||||||
|
ok bool
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"", false, nil},
|
||||||
|
{"dom1", false, nil},
|
||||||
|
{"dom2", false, nil},
|
||||||
|
{"dom3", false, nil},
|
||||||
|
{"dom4", true, nil},
|
||||||
|
{"domErr", false, testError},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
ok, err := hasSTSRecord(c.domain)
|
||||||
|
if ok != c.ok || err != c.err {
|
||||||
|
t.Errorf("%s: expected {%v, %v}, got {%v, %v}", c.domain,
|
||||||
|
c.ok, c.err, ok, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user