diff --git a/smtpd/utils.go b/smtpd/utils.go index 9af1d3f..c06d11c 100644 --- a/smtpd/utils.go +++ b/smtpd/utils.go @@ -91,26 +91,35 @@ func ValidateDomainPart(domain string) bool { return true } -// ValidateLocalPart returns true if the string complies with RFC3696 recommendations -func ValidateLocalPart(local string) bool { - length := len(local) - if 1 > length || length > 64 { - // Invalid length - return false +// ParseEmailAddress unescapes an email address, and splits the local part from the domain part. +// An error is returned if the local or domain parts fail validation following the guidelines +// in RFC3696. +func ParseEmailAddress(address string) (local string, domain string, err error) { + if address == "" { + return "", "", fmt.Errorf("Empty address") } - if local[length-1] == '.' { - // Cannot end with a period - return false + if len(address) > 320 { + return "", "", fmt.Errorf("Address exceeds 320 characters") + } + if address[0] == '@' { + return "", "", fmt.Errorf("Address cannot start with @ symbol") + } + if address[0] == '.' { + return "", "", fmt.Errorf("Address cannot start with a period") } + // Loop over address parsing out local part + buf := new(bytes.Buffer) prev := byte('.') inCharQuote := false inStringQuote := false - for i := 0; i < length; i++ { - c := local[i] +LOOP: + for i := 0; i < len(address); i++ { + c := address[i] switch { case ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'): // Letters are OK + buf.WriteByte(c) inCharQuote = false case '0' <= c && c <= '9': // Numbers are OK @@ -122,31 +131,57 @@ func ValidateLocalPart(local string) bool { // A single period is OK if prev == '.' { // Sequence of periods is not permitted - return false + return "", "", fmt.Errorf("Sequence of periods is not permitted") } case c == '\\': inCharQuote = true case c == '"': if inCharQuote { inCharQuote = false + } else if inStringQuote { + inStringQuote = false } else { - inStringQuote = !inStringQuote + if i == 0 { + inStringQuote = true + } else { + return "", "", fmt.Errorf("Quoted string can only begin at start of address") + } + } + case c == '@': + if inCharQuote || inStringQuote { + inCharQuote = false + } else { + // End of local-part + if i > 63 { + return "", "", fmt.Errorf("Local part must not exceed 64 characters") + } + if prev == '.' { + return "", "", fmt.Errorf("Local part cannot end with a period") + } + domain = address[i+1:] + break LOOP } case c > 127: - return false + return "", "", fmt.Errorf("Characters outside of US-ASCII range not permitted") default: if inCharQuote || inStringQuote { inCharQuote = false - return true + } else { + return "", "", fmt.Errorf("Character %q must be quoted", c) } - return false } prev = c } - if inCharQuote || inStringQuote { - // Can't end with unused backslash quote or unterminated string quote - return false + if inCharQuote { + return "", "", fmt.Errorf("Cannot end address with unterminated quoted-pair") + } + if inStringQuote { + return "", "", fmt.Errorf("Cannot end address with unterminated string quote") } - return true + if !ValidateDomainPart(domain) { + return "", "", fmt.Errorf("Domain part validation failed") + } + + return buf.String(), domain, nil } diff --git a/smtpd/utils_test.go b/smtpd/utils_test.go index b0d4ad4..8812396 100644 --- a/smtpd/utils_test.go +++ b/smtpd/utils_test.go @@ -53,9 +53,9 @@ func TestValidateDomain(t *testing.T) { func TestValidateLocal(t *testing.T) { var testTable = []struct { - input string + input string expect bool - msg string + msg string }{ {"", false, "Empty local is not valid"}, {"a", true, "Single letter should be fine"}, @@ -99,8 +99,36 @@ func TestValidateLocal(t *testing.T) { } for _, tt := range testTable { - if ValidateLocalPart(tt.input) != tt.expect { + _, _, err := ParseEmailAddress(tt.input + "@domain.com") + if (err != nil) == tt.expect { + if err != nil { + t.Logf("Got error: %s", err) + } t.Errorf("Expected %v for %q: %s", tt.expect, tt.input, tt.msg) } } } + +func TestParseEmailAddress(t *testing.T) { + var testTable = []struct { + input, local, domain string + }{ + {"root@localhost", "root", "localhost"}, + } + + for _, tt := range testTable { + local, domain, err := ParseEmailAddress(tt.input) + if err != nil { + t.Errorf("Error when parsing %q: %s", tt.input, err) + } else { + if tt.local != local { + t.Errorf("When parsing %q, expected local %q, got %q instead", + tt.input, tt.local, local) + } + if tt.domain != domain { + t.Errorf("When parsing %q, expected domain %q, got %q instead", + tt.input, tt.domain, domain) + } + } + } +}