mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
dkim: Implement internal dkim signing and verification
This patch implements internal DKIM signing and verification.
This commit is contained in:
235
internal/dkim/file_test.go
Normal file
235
internal/dkim/file_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package dkim
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestFromFiles(t *testing.T) {
|
||||
msgfs, err := filepath.Glob("testdata/*.msg")
|
||||
if err != nil {
|
||||
t.Fatalf("error finding test files: %v", err)
|
||||
}
|
||||
|
||||
for _, msgf := range msgfs {
|
||||
base := strings.TrimSuffix(msgf, filepath.Ext(msgf))
|
||||
t.Run(base, func(t *testing.T) { testOne(t, base) })
|
||||
}
|
||||
}
|
||||
|
||||
// This is the same as TestFromFiles, but it runs the private test files,
|
||||
// which are not included in the git repository.
|
||||
// This is useful for running tests on your own machine, with emails that you
|
||||
// don't necessarily want to share publicly.
|
||||
func TestFromPrivateFiles(t *testing.T) {
|
||||
msgfs, err := filepath.Glob("testdata/private/*/*.msg")
|
||||
if err != nil {
|
||||
t.Fatalf("error finding private test files: %v", err)
|
||||
}
|
||||
|
||||
for _, msgf := range msgfs {
|
||||
base := strings.TrimSuffix(msgf, filepath.Ext(msgf))
|
||||
t.Run(base, func(t *testing.T) { testOne(t, base) })
|
||||
}
|
||||
}
|
||||
|
||||
func testOne(t *testing.T, base string) {
|
||||
ctx := context.Background()
|
||||
ctx = WithTraceFunc(ctx, t.Logf)
|
||||
|
||||
ctx = loadDNS(t, ctx, base+".dns")
|
||||
msg := toCRLF(mustReadFile(t, base+".msg"))
|
||||
wantResult := loadResult(t, base+".result")
|
||||
wantError := loadError(t, base+".error")
|
||||
|
||||
t.Logf("Message: %.60q", msg)
|
||||
t.Logf("Want result: %+v", wantResult)
|
||||
t.Logf("Want error: %v", wantError)
|
||||
|
||||
res, err := VerifyMessage(ctx, msg)
|
||||
|
||||
// Write the results out for easy updating.
|
||||
writeResults(t, base, res, err)
|
||||
|
||||
diff := cmp.Diff(wantResult, res, cmp.Comparer(equalErrors))
|
||||
if diff != "" {
|
||||
t.Errorf("VerifyMessage result diff (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// We need to compare them by hand because cmp.Diff won't use our comparer
|
||||
// for top-level errors.
|
||||
if !equalErrors(wantError, err) {
|
||||
diff := cmp.Diff(wantError, err)
|
||||
t.Errorf("VerifyMessage error diff (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// Used to make cmp.Diff compare errors by their messages. This is obviously
|
||||
// not great, but it's good enough for this test.
|
||||
func equalErrors(a, b error) bool {
|
||||
if a == nil {
|
||||
return b == nil
|
||||
}
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
return a.Error() == b.Error()
|
||||
}
|
||||
|
||||
func mustReadFile(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
contents, err := os.ReadFile(path)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return ""
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("error reading %q: %v", path, err)
|
||||
}
|
||||
return string(contents)
|
||||
}
|
||||
|
||||
func loadDNS(t *testing.T, ctx context.Context, path string) context.Context {
|
||||
t.Helper()
|
||||
|
||||
results := map[string][]string{}
|
||||
errors := map[string]error{}
|
||||
txtFunc := func(ctx context.Context, domain string) ([]string, error) {
|
||||
return results[domain], errors[domain]
|
||||
}
|
||||
ctx = WithLookupTXTFunc(ctx, txtFunc)
|
||||
|
||||
c := mustReadFile(t, path)
|
||||
|
||||
// Unfold \-terminated lines.
|
||||
c = strings.ReplaceAll(c, "\\\n", "")
|
||||
|
||||
for _, line := range strings.Split(c, "\n") {
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
domain, txt, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
domain = strings.TrimSpace(domain)
|
||||
|
||||
switch strings.TrimSpace(txt) {
|
||||
case "TEMPERROR":
|
||||
errors[domain] = &net.DNSError{
|
||||
Err: "temporary error (for testing)",
|
||||
IsTemporary: true,
|
||||
}
|
||||
case "PERMERROR":
|
||||
errors[domain] = &net.DNSError{
|
||||
Err: "permanent error (for testing)",
|
||||
IsTemporary: false,
|
||||
}
|
||||
case "NOTFOUND":
|
||||
errors[domain] = &net.DNSError{
|
||||
Err: "domain not found (for testing)",
|
||||
IsNotFound: true,
|
||||
}
|
||||
default:
|
||||
results[domain] = append(results[domain], txt)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Loaded DNS results: %#v", results)
|
||||
t.Logf("Loaded DNS errors: %v", errors)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func loadResult(t *testing.T, path string) *VerifyResult {
|
||||
t.Helper()
|
||||
|
||||
res := &VerifyResult{}
|
||||
c := mustReadFile(t, path)
|
||||
if c == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(c), res)
|
||||
if err != nil {
|
||||
t.Fatalf("error unmarshalling %q: %v", path, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func loadError(t *testing.T, path string) error {
|
||||
t.Helper()
|
||||
|
||||
c := strings.TrimSpace(mustReadFile(t, path))
|
||||
if c == "" || c == "nil" || c == "<nil>" {
|
||||
return nil
|
||||
}
|
||||
return errors.New(c)
|
||||
}
|
||||
|
||||
func mustWriteFile(t *testing.T, path string, c []byte) {
|
||||
t.Helper()
|
||||
err := os.WriteFile(path, c, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("error writing %q: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeResults(t *testing.T, base string, res *VerifyResult, err error) {
|
||||
t.Helper()
|
||||
|
||||
mustWriteFile(t, base+".error.got", []byte(fmt.Sprintf("%v", err)))
|
||||
|
||||
c, err := json.MarshalIndent(res, "", "\t")
|
||||
if err != nil {
|
||||
t.Fatalf("error marshalling result: %v", err)
|
||||
}
|
||||
mustWriteFile(t, base+".result.got", c)
|
||||
}
|
||||
|
||||
// Custom json marshaller so we can write errors as strings.
|
||||
func (or *OneResult) MarshalJSON() ([]byte, error) {
|
||||
// We use an alias to avoid infinite recursion.
|
||||
type Alias OneResult
|
||||
aux := &struct {
|
||||
Error string `json:""`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(or),
|
||||
}
|
||||
if or.Error != nil {
|
||||
aux.Error = or.Error.Error()
|
||||
}
|
||||
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
// Custom json unmarshaller so we can read errors as strings.
|
||||
func (or *OneResult) UnmarshalJSON(b []byte) error {
|
||||
// We use an alias to avoid infinite recursion.
|
||||
type Alias OneResult
|
||||
aux := &struct {
|
||||
Error string `json:""`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(or),
|
||||
}
|
||||
if err := json.Unmarshal(b, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if aux.Error != "" {
|
||||
or.Error = errors.New(aux.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user