diff --git a/pkg/extension/event/events.go b/pkg/extension/event/events.go index 4f83e3b..38fcd61 100644 --- a/pkg/extension/event/events.go +++ b/pkg/extension/event/events.go @@ -5,6 +5,12 @@ import ( "time" ) +// AddressParts contains the local and domain parts of an email address. +type AddressParts struct { + Local string + Domain string +} + // MessageMetadata contains the basic header data for a message event. type MessageMetadata struct { Mailbox string diff --git a/pkg/extension/host.go b/pkg/extension/host.go index 7889050..5d3ab3e 100644 --- a/pkg/extension/host.go +++ b/pkg/extension/host.go @@ -21,6 +21,7 @@ type Host struct { // listener will not be called until the one before it complets. type Events struct { AfterMessageStored EventBroker[event.MessageMetadata, Void] + BeforeMailAccepted EventBroker[event.AddressParts, bool] } // Void indicates the event emitter will ignore any value returned by listeners. diff --git a/pkg/server/lifecycle.go b/pkg/server/lifecycle.go index fa5be3c..d4d4f3d 100644 --- a/pkg/server/lifecycle.go +++ b/pkg/server/lifecycle.go @@ -56,7 +56,7 @@ func FullAssembly(conf *config.Root) (*Services, error) { webServer := web.NewServer(conf, mmanager, msgHub) pop3Server := pop3.NewServer(conf.POP3, store) - smtpServer := smtp.NewServer(conf.SMTP, mmanager, addrPolicy) + smtpServer := smtp.NewServer(conf.SMTP, mmanager, addrPolicy, extHost) return &Services{ MsgHub: msgHub, diff --git a/pkg/server/smtp/handler.go b/pkg/server/smtp/handler.go index 98e08da..01f2fc8 100644 --- a/pkg/server/smtp/handler.go +++ b/pkg/server/smtp/handler.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/inbucket/inbucket/pkg/extension/event" "github.com/inbucket/inbucket/pkg/policy" "github.com/rs/zerolog" ) @@ -383,7 +384,8 @@ func (s *Session) readyHandler(cmd string, arg string) { return } from := m[1] - if _, _, err := policy.ParseEmailAddress(from); from != "" && err != nil { + localpart, domain, err := policy.ParseEmailAddress(from) + if from != "" && err != nil { s.send("501 Bad sender address syntax") s.logger.Warn().Msgf("Bad address as MAIL arg: %q, %s", from, err) return @@ -415,10 +417,22 @@ func (s *Session) readyHandler(cmd string, arg string) { } } } - s.from = from - s.logger.Info().Msgf("Mail from: %v", from) - s.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from)) - s.enterState(MAIL) + + // Process through extensions. + extResult := s.extHost.Events.BeforeMailAccepted.Emit( + &event.AddressParts{Local: localpart, Domain: domain}) + + if extResult == nil || *extResult { + // Permitted by extension, or none had an opinion. + s.from = from + s.logger.Info().Msgf("Mail from: %v", from) + s.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from)) + s.enterState(MAIL) + } else { + s.send("550 Mail denied by policy") + s.logger.Warn().Msgf("Extension denied mail from <%v>", from) + return + } } else if cmd == "EHLO" { // Reset session s.logger.Debug().Msgf("Resetting session state on EHLO request") @@ -614,7 +628,9 @@ func (s *Session) parseCmd(line string) (cmd string, arg string, ok bool) { // parseArgs takes the arguments proceeding a command and files them // into a map[string]string after uppercasing each key. Sample arg // string: -// " BODY=8BITMIME SIZE=1024" +// +// " BODY=8BITMIME SIZE=1024" +// // The leading space is mandatory. func (s *Session) parseArgs(arg string) (args map[string]string, ok bool) { args = make(map[string]string) diff --git a/pkg/server/smtp/handler_test.go b/pkg/server/smtp/handler_test.go index 22c574c..62d794f 100644 --- a/pkg/server/smtp/handler_test.go +++ b/pkg/server/smtp/handler_test.go @@ -10,11 +10,14 @@ import ( "time" "github.com/inbucket/inbucket/pkg/config" + "github.com/inbucket/inbucket/pkg/extension" + "github.com/inbucket/inbucket/pkg/extension/event" "github.com/inbucket/inbucket/pkg/message" "github.com/inbucket/inbucket/pkg/policy" "github.com/inbucket/inbucket/pkg/storage" "github.com/inbucket/inbucket/pkg/test" "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" ) type scriptStep struct { @@ -25,7 +28,7 @@ type scriptStep struct { // Test valid commands in GREET state. func TestGreetStateValidCommands(t *testing.T) { ds := test.NewStore() - server := setupSMTPServer(ds) + server := setupSMTPServer(ds, extension.NewHost()) tests := []scriptStep{ {"HELO mydomain", 250}, @@ -56,7 +59,7 @@ func TestGreetStateValidCommands(t *testing.T) { // Test invalid commands in GREET state. func TestGreetState(t *testing.T) { ds := test.NewStore() - server := setupSMTPServer(ds) + server := setupSMTPServer(ds, extension.NewHost()) defer server.Drain() // Required to prevent test logging data race. tests := []scriptStep{ @@ -83,7 +86,7 @@ func TestGreetState(t *testing.T) { func TestEmptyEnvelope(t *testing.T) { ds := test.NewStore() - server := setupSMTPServer(ds) + server := setupSMTPServer(ds, extension.NewHost()) defer server.Drain() // Test out some empty envelope without blanks @@ -108,7 +111,7 @@ func TestEmptyEnvelope(t *testing.T) { // Test AUTH commands. func TestAuth(t *testing.T) { ds := test.NewStore() - server := setupSMTPServer(ds) + server := setupSMTPServer(ds, extension.NewHost()) defer server.Drain() // PLAIN AUTH @@ -145,7 +148,7 @@ func TestAuth(t *testing.T) { // Test TLS commands. func TestTLS(t *testing.T) { ds := test.NewStore() - server := setupSMTPServer(ds) + server := setupSMTPServer(ds, extension.NewHost()) defer server.Drain() // Test Start TLS parsing. @@ -162,7 +165,7 @@ func TestTLS(t *testing.T) { // Test valid commands in READY state. func TestReadyStateValidCommands(t *testing.T) { ds := test.NewStore() - server := setupSMTPServer(ds) + server := setupSMTPServer(ds, extension.NewHost()) // Test out some valid MAIL commands tests := []scriptStep{ @@ -198,7 +201,7 @@ func TestReadyStateValidCommands(t *testing.T) { // Test invalid commands in READY state. func TestReadyStateInvalidCommands(t *testing.T) { ds := test.NewStore() - server := setupSMTPServer(ds) + server := setupSMTPServer(ds, extension.NewHost()) tests := []scriptStep{ {"FOOB", 500}, @@ -231,7 +234,7 @@ func TestReadyStateInvalidCommands(t *testing.T) { // Test commands in MAIL state func TestMailState(t *testing.T) { mds := test.NewStore() - server := setupSMTPServer(mds) + server := setupSMTPServer(mds, extension.NewHost()) defer server.Drain() // Test out some mangled READY commands @@ -338,7 +341,7 @@ func TestMailState(t *testing.T) { // Test commands in DATA state func TestDataState(t *testing.T) { mds := test.NewStore() - server := setupSMTPServer(mds) + server := setupSMTPServer(mds, extension.NewHost()) defer server.Drain() var script []scriptStep @@ -448,6 +451,93 @@ func playScriptAgainst(t *testing.T, c *textproto.Conn, script []scriptStep) err return nil } +// Tests "MAIL FROM" emits BeforeMailAccepted event. +func TestBeforeMailAcceptedEventEmitted(t *testing.T) { + ds := test.NewStore() + extHost := extension.NewHost() + server := setupSMTPServer(ds, extHost) + defer server.Drain() + + var got *event.AddressParts + extHost.Events.BeforeMailAccepted.AddListener( + "test", + func(addr event.AddressParts) *bool { + got = &addr + return nil + }) + + // Play and verify SMTP session. + script := []scriptStep{ + {"HELO localhost", 250}, + {"MAIL FROM:", 250}, + {"QUIT", 221}} + if err := playSession(t, server, script); err != nil { + t.Error(err) + } + + assert.NotNil(t, got, "BeforeMailListener did not receive Address") + assert.Equal(t, "john", got.Local, "Address local part had wrong value") + assert.Equal(t, "gmail.com", got.Domain, "Address domain part had wrong value") +} + +// Test "MAIL FROM" acts on BeforeMailAccepted event result. +func TestBeforeMailAcceptedEventResponse(t *testing.T) { + ds := test.NewStore() + extHost := extension.NewHost() + server := setupSMTPServer(ds, extHost) + defer server.Drain() + + var shouldReturn *bool + var gotEvent *event.AddressParts + extHost.Events.BeforeMailAccepted.AddListener( + "test", + func(addr event.AddressParts) *bool { + gotEvent = &addr + return shouldReturn + }) + + allowRes := true + denyRes := false + tcs := map[string]struct { + script scriptStep // Command to send and SMTP code expected. + eventRes *bool // Response to send from event listener. + }{ + "allow": { + script: scriptStep{"MAIL FROM:", 250}, + eventRes: &allowRes, + }, + "deny": { + script: scriptStep{"MAIL FROM:", 550}, + eventRes: &denyRes, + }, + "defer": { + script: scriptStep{"MAIL FROM:", 250}, + eventRes: nil, + }, + } + + for name, tc := range tcs { + tc := tc + t.Run(name, func(t *testing.T) { + // Reset event listener. + shouldReturn = tc.eventRes + gotEvent = nil + + // Play and verify SMTP session. + script := []scriptStep{ + {"HELO localhost", 250}, + tc.script, + {"QUIT", 221}} + if err := playSession(t, server, script); err != nil { + t.Error(err) + } + + assert.NotNil(t, gotEvent, "BeforeMailListener did not receive Address") + }) + } + +} + // net.Pipe does not implement deadlines type mockConn struct { net.Conn @@ -457,7 +547,7 @@ func (m *mockConn) SetDeadline(t time.Time) error { return nil } func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } -func setupSMTPServer(ds storage.Store) *Server { +func setupSMTPServer(ds storage.Store, extHost *extension.Host) *Server { cfg := &config.Root{ MailboxNaming: config.FullNaming, SMTP: config.SMTP{ @@ -475,7 +565,7 @@ func setupSMTPServer(ds storage.Store) *Server { addrPolicy := &policy.Addressing{Config: cfg} manager := &message.StoreManager{Store: ds} - return NewServer(cfg.SMTP, manager, addrPolicy) + return NewServer(cfg.SMTP, manager, addrPolicy, extHost) } var sessionNum int diff --git a/pkg/server/smtp/listener.go b/pkg/server/smtp/listener.go index 5d1cc07..382f1e4 100644 --- a/pkg/server/smtp/listener.go +++ b/pkg/server/smtp/listener.go @@ -10,6 +10,7 @@ import ( "time" "github.com/inbucket/inbucket/pkg/config" + "github.com/inbucket/inbucket/pkg/extension" "github.com/inbucket/inbucket/pkg/message" "github.com/inbucket/inbucket/pkg/metric" "github.com/inbucket/inbucket/pkg/policy" @@ -59,11 +60,12 @@ func init() { // Server holds the configuration and state of our SMTP server. type Server struct { config config.SMTP // SMTP configuration. + tlsConfig *tls.Config // TLS encryption configuration. addrPolicy *policy.Addressing // Address policy. manager message.Manager // Used to deliver messages. + extHost *extension.Host // Extension event processor. listener net.Listener // Incoming network connections. wg *sync.WaitGroup // Waitgroup tracks individual sessions. - tlsConfig *tls.Config // TLS encryption configuration. notify chan error // Notify on fatal error. } @@ -72,6 +74,7 @@ func NewServer( smtpConfig config.SMTP, manager message.Manager, apolicy *policy.Addressing, + extHost *extension.Host, ) *Server { slog := log.With().Str("module", "smtp").Str("phase", "tls").Logger() tlsConfig := &tls.Config{} @@ -90,10 +93,11 @@ func NewServer( return &Server{ config: smtpConfig, + tlsConfig: tlsConfig, manager: manager, addrPolicy: apolicy, + extHost: extHost, wg: new(sync.WaitGroup), - tlsConfig: tlsConfig, notify: make(chan error, 1), } } diff --git a/pkg/test/integration_test.go b/pkg/test/integration_test.go index ae0b01a..4f54ef7 100644 --- a/pkg/test/integration_test.go +++ b/pkg/test/integration_test.go @@ -243,7 +243,7 @@ func startServer() (func(), error) { go webServer.Start(svcCtx, func() {}) // Start SMTP server. - smtpServer := smtp.NewServer(conf.SMTP, mmanager, addrPolicy) + smtpServer := smtp.NewServer(conf.SMTP, mmanager, addrPolicy, extHost) go smtpServer.Start(svcCtx, func() {}) // TODO Use a readyFunc to determine server readiness.