mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
Implement AUTH
This patch implements the AUTH SMTP command, using per-domain user databases. Note that we don't really use or check the validation for anything, this is just implementing the command itself.
This commit is contained in:
136
chasquid.go
136
chasquid.go
@@ -12,15 +12,18 @@ import (
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/auth"
|
||||
"blitiri.com.ar/go/chasquid/internal/config"
|
||||
"blitiri.com.ar/go/chasquid/internal/courier"
|
||||
"blitiri.com.ar/go/chasquid/internal/queue"
|
||||
"blitiri.com.ar/go/chasquid/internal/systemd"
|
||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||
"blitiri.com.ar/go/chasquid/internal/userdb"
|
||||
|
||||
_ "net/http/pprof"
|
||||
|
||||
@@ -40,6 +43,9 @@ var (
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
// Seed the PRNG, just to prevent for it to be totally predictable.
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
conf, err := config.Load(*configDir + "/chasquid.conf")
|
||||
if err != nil {
|
||||
glog.Fatalf("Error reading config")
|
||||
@@ -69,10 +75,9 @@ func main() {
|
||||
} else {
|
||||
glog.Infof("Domain config paths:")
|
||||
for _, info := range domainDirs {
|
||||
glog.Infof(" %s", info.Name())
|
||||
s.AddDomain(info.Name())
|
||||
dir := filepath.Join(*configDir, "domains", info.Name())
|
||||
s.AddCerts(dir+"/cert.pem", dir+"/key.pem")
|
||||
name := info.Name()
|
||||
dir := filepath.Join(*configDir, "domains", name)
|
||||
loadDomain(s, name, dir)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,6 +112,27 @@ func main() {
|
||||
s.ListenAndServe()
|
||||
}
|
||||
|
||||
// Helper to load a single domain configuration into the server.
|
||||
func loadDomain(s *Server, name, dir string) {
|
||||
glog.Infof(" %s", name)
|
||||
s.AddDomain(name)
|
||||
s.AddCerts(dir+"/cert.pem", dir+"/key.pem")
|
||||
|
||||
if _, err := os.Stat(dir + "/users"); err == nil {
|
||||
glog.Infof(" adding users")
|
||||
udb, warnings, err := userdb.Load(dir + "/users")
|
||||
if err != nil {
|
||||
glog.Errorf(" error: %v", err)
|
||||
} else {
|
||||
for _, w := range warnings {
|
||||
glog.Warningf(" %v", w)
|
||||
}
|
||||
s.AddUserDB(name, udb)
|
||||
// TODO: periodically reload the database.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
// Main hostname, used for display only.
|
||||
Hostname string
|
||||
@@ -129,6 +155,12 @@ type Server struct {
|
||||
// Local domains.
|
||||
localDomains map[string]bool
|
||||
|
||||
// User databases (per domain).
|
||||
userDBs map[string]*userdb.DB
|
||||
|
||||
// Local courier.
|
||||
localCourier courier.Courier
|
||||
|
||||
// Time before we give up on a connection, even if it's sending data.
|
||||
connTimeout time.Duration
|
||||
|
||||
@@ -144,6 +176,7 @@ func NewServer() *Server {
|
||||
connTimeout: 20 * time.Minute,
|
||||
commandTimeout: 1 * time.Minute,
|
||||
localDomains: map[string]bool{},
|
||||
userDBs: map[string]*userdb.DB{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,6 +197,10 @@ func (s *Server) AddDomain(d string) {
|
||||
s.localDomains[d] = true
|
||||
}
|
||||
|
||||
func (s *Server) AddUserDB(domain string, db *userdb.DB) {
|
||||
s.userDBs[domain] = db
|
||||
}
|
||||
|
||||
func (s *Server) getTLSConfig() (*tls.Config, error) {
|
||||
var err error
|
||||
conf := &tls.Config{}
|
||||
@@ -241,6 +278,7 @@ func (s *Server) serve(l net.Listener) {
|
||||
netconn: conn,
|
||||
tc: textproto.NewConn(conn),
|
||||
tlsConfig: s.tlsConfig,
|
||||
userDBs: s.userDBs,
|
||||
deadline: time.Now().Add(s.connTimeout),
|
||||
commandTimeout: s.commandTimeout,
|
||||
queue: s.queue,
|
||||
@@ -274,6 +312,19 @@ type Conn struct {
|
||||
// Are we using TLS?
|
||||
onTLS bool
|
||||
|
||||
// User databases - taken from the server at creation time.
|
||||
userDBs map[string]*userdb.DB
|
||||
|
||||
// Have we successfully completed AUTH?
|
||||
completedAuth bool
|
||||
|
||||
// How many times have we attempted AUTH?
|
||||
authAttempts int
|
||||
|
||||
// Authenticated user and domain, empty if !completedAuth.
|
||||
authUser string
|
||||
authDomain string
|
||||
|
||||
// When we should close this connection, no matter what.
|
||||
deadline time.Time
|
||||
|
||||
@@ -341,6 +392,8 @@ loop:
|
||||
code, msg = c.DATA(params, tr)
|
||||
case "STARTTLS":
|
||||
code, msg = c.STARTTLS(params, tr)
|
||||
case "AUTH":
|
||||
code, msg = c.AUTH(params, tr)
|
||||
case "QUIT":
|
||||
c.writeResponse(221, "Be seeing you...")
|
||||
break loop
|
||||
@@ -383,7 +436,11 @@ func (c *Conn) EHLO(params string) (code int, msg string) {
|
||||
fmt.Fprintf(buf, "8BITMIME\n")
|
||||
fmt.Fprintf(buf, "PIPELINING\n")
|
||||
fmt.Fprintf(buf, "SIZE %d\n", c.maxDataSize)
|
||||
if c.onTLS {
|
||||
fmt.Fprintf(buf, "AUTH PLAIN\n")
|
||||
} else {
|
||||
fmt.Fprintf(buf, "STARTTLS\n")
|
||||
}
|
||||
fmt.Fprintf(buf, "HELP\n")
|
||||
return 250, buf.String()
|
||||
}
|
||||
@@ -582,6 +639,73 @@ func (c *Conn) STARTTLS(params string, tr *trace.Trace) (code int, msg string) {
|
||||
return 0, ""
|
||||
}
|
||||
|
||||
func (c *Conn) AUTH(params string, tr *trace.Trace) (code int, msg string) {
|
||||
if !c.onTLS {
|
||||
return 503, "You feel vulnerable"
|
||||
}
|
||||
|
||||
if c.completedAuth {
|
||||
// After a successful AUTH command completes, a server MUST reject
|
||||
// any further AUTH commands with a 503 reply.
|
||||
// https://tools.ietf.org/html/rfc4954#section-4
|
||||
return 503, "You are already wearing that!"
|
||||
}
|
||||
|
||||
if c.authAttempts > 3 {
|
||||
// TODO: close the connection?
|
||||
return 503, "Too many attempts - go away"
|
||||
}
|
||||
c.authAttempts++
|
||||
|
||||
// We only support PLAIN for now, so no need to make this too complicated.
|
||||
// Params should be either "PLAIN" or "PLAIN <response>".
|
||||
// If the response is not there, we reply with 334, and expect the
|
||||
// response back from the client in the next message.
|
||||
|
||||
sp := strings.SplitN(params, " ", 2)
|
||||
if len(sp) < 1 || sp[0] != "PLAIN" {
|
||||
// As we only offer plain, this should not really happen.
|
||||
return 534, "Asmodeus demands 534 zorkmids for safe passage"
|
||||
}
|
||||
|
||||
// Note we use more "serious" error messages from now own, as these may
|
||||
// find their way to the users in some circumstances.
|
||||
|
||||
// Get the response, either from the message or interactively.
|
||||
response := ""
|
||||
if len(sp) == 2 {
|
||||
response = sp[1]
|
||||
} else {
|
||||
// Reply 334 and expect the user to provide it.
|
||||
// In this case, the text IS relevant, as it is taken as the
|
||||
// server-side SASL challenge (empty for PLAIN).
|
||||
// https://tools.ietf.org/html/rfc4954#section-4
|
||||
err := c.writeResponse(334, "")
|
||||
if err != nil {
|
||||
return 554, fmt.Sprintf("error writing AUTH 334: %v", err)
|
||||
}
|
||||
|
||||
response, err = c.readLine()
|
||||
if err != nil {
|
||||
return 554, fmt.Sprintf("error reading AUTH response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
user, domain, passwd, err := auth.DecodeResponse(response)
|
||||
if err != nil {
|
||||
return 535, fmt.Sprintf("error decoding AUTH response: %v", err)
|
||||
}
|
||||
|
||||
if auth.Authenticate(c.userDBs[domain], user, passwd) {
|
||||
c.authUser = user
|
||||
c.authDomain = domain
|
||||
c.completedAuth = true
|
||||
return 235, ""
|
||||
} else {
|
||||
return 535, "Incorrect user or password"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) resetEnvelope() {
|
||||
c.mail_from = ""
|
||||
c.rcpt_to = nil
|
||||
@@ -605,6 +729,10 @@ func (c *Conn) readCommand() (cmd, params string, err error) {
|
||||
return cmd, params, err
|
||||
}
|
||||
|
||||
func (c *Conn) readLine() (line string, err error) {
|
||||
return c.tc.ReadLine()
|
||||
}
|
||||
|
||||
func (c *Conn) writeResponse(code int, msg string) error {
|
||||
defer c.tc.W.Flush()
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/userdb"
|
||||
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
@@ -66,8 +68,18 @@ func mustDial(tb testing.TB, useTLS bool) *smtp.Client {
|
||||
}
|
||||
|
||||
func sendEmail(tb testing.TB, c *smtp.Client) {
|
||||
sendEmailWithAuth(tb, c, nil)
|
||||
}
|
||||
|
||||
func sendEmailWithAuth(tb testing.TB, c *smtp.Client, auth smtp.Auth) {
|
||||
var err error
|
||||
|
||||
if auth != nil {
|
||||
if err = c.Auth(auth); err != nil {
|
||||
tb.Errorf("Auth: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = c.Mail("from@from"); err != nil {
|
||||
tb.Errorf("Mail: %v", err)
|
||||
}
|
||||
@@ -111,6 +123,14 @@ func TestManyEmails(t *testing.T) {
|
||||
sendEmail(t, c)
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
c := mustDial(t, true)
|
||||
defer c.Close()
|
||||
|
||||
auth := smtp.PlainAuth("", "testuser@localhost", "testpasswd", "127.0.0.1")
|
||||
sendEmailWithAuth(t, c, auth)
|
||||
}
|
||||
|
||||
func TestWrongMailParsing(t *testing.T) {
|
||||
c := mustDial(t, false)
|
||||
defer c.Close()
|
||||
@@ -360,6 +380,11 @@ func realMain(m *testing.M) int {
|
||||
s.MaxDataSize = 50 * 1024 * 1025
|
||||
s.AddCerts(tmpDir+"/cert.pem", tmpDir+"/key.pem")
|
||||
s.AddAddr(srvAddr)
|
||||
|
||||
udb := userdb.New("/dev/null")
|
||||
udb.AddUser("testuser", "testpasswd")
|
||||
s.AddUserDB("localhost", udb)
|
||||
|
||||
go s.ListenAndServe()
|
||||
}
|
||||
|
||||
|
||||
110
internal/auth/auth.go
Normal file
110
internal/auth/auth.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/userdb"
|
||||
)
|
||||
|
||||
// DecodeResponse decodes a plain auth response.
|
||||
//
|
||||
// It must be a a base64-encoded string of the form:
|
||||
// <authorization id> NUL <authentication id> NUL <password>
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc4954#section-4.1.
|
||||
//
|
||||
// Either both ID match, or one of them is empty.
|
||||
// We expect the ID to be "user@domain", which is NOT an RFC requirement but
|
||||
// our own.
|
||||
func DecodeResponse(response string) (user, domain, passwd string, err error) {
|
||||
buf, err := base64.StdEncoding.DecodeString(response)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
bufsp := bytes.SplitN(buf, []byte{0}, 3)
|
||||
if len(bufsp) != 3 {
|
||||
err = fmt.Errorf("Response pieces != 3, as per RFC")
|
||||
return
|
||||
}
|
||||
|
||||
identity := ""
|
||||
passwd = string(bufsp[2])
|
||||
|
||||
{
|
||||
// We don't make the distinction between the two IDs, as long as one is
|
||||
// empty, or they're the same.
|
||||
z := string(bufsp[0])
|
||||
c := string(bufsp[1])
|
||||
|
||||
// If neither is empty, then they must be the same.
|
||||
if (z != "" && c != "") && (z != c) {
|
||||
err = fmt.Errorf("Auth IDs do not match")
|
||||
return
|
||||
}
|
||||
|
||||
if z != "" {
|
||||
identity = z
|
||||
}
|
||||
if c != "" {
|
||||
identity = c
|
||||
}
|
||||
}
|
||||
|
||||
if identity == "" {
|
||||
err = fmt.Errorf("Empty identity, must be in the form user@domain")
|
||||
return
|
||||
}
|
||||
|
||||
// Identity must be in the form "user@domain".
|
||||
// This is NOT an RFC requirement, it's our own.
|
||||
idsp := strings.SplitN(identity, "@", 2)
|
||||
if len(idsp) != 2 {
|
||||
err = fmt.Errorf("Identity must be in the form user@domain")
|
||||
return
|
||||
}
|
||||
|
||||
user = idsp[0]
|
||||
domain = idsp[1]
|
||||
|
||||
// TODO: Quedamos aca. Validar dominio no (solo) como utf8, sino ver que
|
||||
// no contenga ni "/" ni "..". Podemos usar golang.org/x/net/idna para
|
||||
// convertirlo a unicode primero, o al reves. No se que queremos.
|
||||
if !utf8.ValidString(user) || !utf8.ValidString(domain) {
|
||||
err = fmt.Errorf("User/domain is not valid UTF-8")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// How long Authenticate calls should last, approximately.
|
||||
// This will be applied both for successful and unsuccessful attempts.
|
||||
// We will increase this number by 0-20%.
|
||||
var AuthenticateTime = 100 * time.Millisecond
|
||||
|
||||
// Authenticate user/password on the given database.
|
||||
func Authenticate(udb *userdb.DB, user, passwd string) bool {
|
||||
defer func(start time.Time) {
|
||||
elapsed := time.Since(start)
|
||||
delay := AuthenticateTime - elapsed
|
||||
if delay > 0 {
|
||||
maxDelta := int64(float64(delay) * 0.2)
|
||||
delay += time.Duration(rand.Int63n(maxDelta))
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
// Note that the database CAN be nil, to simplify callers.
|
||||
if udb == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return udb.Authenticate(user, passwd)
|
||||
}
|
||||
93
internal/auth/auth_test.go
Normal file
93
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/userdb"
|
||||
)
|
||||
|
||||
func TestDecodeResponse(t *testing.T) {
|
||||
// Successful cases. Note we hard-code the response for extra assurance.
|
||||
cases := []struct {
|
||||
response, user, domain, passwd string
|
||||
}{
|
||||
{"dUBkAHVAZABwYXNz", "u", "d", "pass"}, // u@d\0u@d\0pass
|
||||
{"dUBkAABwYXNz", "u", "d", "pass"}, // u@d\0\0pass
|
||||
{"AHVAZABwYXNz", "u", "d", "pass"}, // \0u@d\0pass
|
||||
{"dUBkAABwYXNz/w==", "u", "d", "pass\xff"}, // u@d\0\0pass\xff
|
||||
|
||||
// "ñaca@ñeque\0\0clavaré"
|
||||
{"w7FhY2FAw7FlcXVlAABjbGF2YXLDqQ==", "ñaca", "ñeque", "clavaré"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
u, d, p, err := DecodeResponse(c.response)
|
||||
if err != nil {
|
||||
t.Errorf("Error in case %v: %v", c, err)
|
||||
}
|
||||
|
||||
if u != c.user || d != c.domain || p != c.passwd {
|
||||
t.Errorf("Expected %q %q %q ; got %q %q %q",
|
||||
c.user, c.domain, c.passwd, u, d, p)
|
||||
}
|
||||
}
|
||||
|
||||
_, _, _, err := DecodeResponse("this is not base64 encoded")
|
||||
if err == nil {
|
||||
t.Errorf("invalid base64 did not fail as expected")
|
||||
}
|
||||
|
||||
failedCases := []string{
|
||||
"", "\x00", "\x00\x00", "\x00\x00\x00", "\x00\x00\x00\x00",
|
||||
"a\x00b", "a\x00b\x00c", "a@a\x00b@b\x00pass", "a\x00a\x00pass",
|
||||
"\xffa@b\x00\xffa@b\x00pass",
|
||||
}
|
||||
for _, c := range failedCases {
|
||||
r := base64.StdEncoding.EncodeToString([]byte(c))
|
||||
_, _, _, err := DecodeResponse(r)
|
||||
if err == nil {
|
||||
t.Errorf("Expected case %q to fail, but succeeded", c)
|
||||
} else {
|
||||
t.Logf("OK: %q failed with %v", c, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate(t *testing.T) {
|
||||
db := userdb.New("/dev/null")
|
||||
db.AddUser("user", "password")
|
||||
|
||||
// Test the correct case first
|
||||
ts := time.Now()
|
||||
if !Authenticate(db, "user", "password") {
|
||||
t.Errorf("failed valid authentication for user/password")
|
||||
}
|
||||
if time.Since(ts) < AuthenticateTime {
|
||||
t.Errorf("authentication was too fast")
|
||||
}
|
||||
|
||||
// Incorrect cases.
|
||||
cases := []struct{ user, password string }{
|
||||
{"user", "incorrect"},
|
||||
{"invalid", "p"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
ts = time.Now()
|
||||
if Authenticate(db, c.user, c.password) {
|
||||
t.Errorf("successful auth on %v", c)
|
||||
}
|
||||
if time.Since(ts) < AuthenticateTime {
|
||||
t.Errorf("authentication was too fast")
|
||||
}
|
||||
}
|
||||
|
||||
// And the special case of a nil userdb.
|
||||
ts = time.Now()
|
||||
if Authenticate(nil, "user", "password") {
|
||||
t.Errorf("successful auth on a nil userdb")
|
||||
}
|
||||
if time.Since(ts) < AuthenticateTime {
|
||||
t.Errorf("authentication was too fast")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user