diff --git a/pkg/extension/host.go b/pkg/extension/host.go index c61d531..e2a5f68 100644 --- a/pkg/extension/host.go +++ b/pkg/extension/host.go @@ -20,11 +20,11 @@ type Host struct { // processed asynchronously with respect to the rest of Inbuckets operation. However, an event // listener will not be called until the one before it completes. type Events struct { - AfterMessageDeleted AsyncEventBroker[event.MessageMetadata] - AfterMessageStored AsyncEventBroker[event.MessageMetadata] - BeforeMailAccepted EventBroker[event.AddressParts, event.SMTPResponse] - BeforeMessageStored EventBroker[event.InboundMessage, event.InboundMessage] - BeforeRcptToAccepted EventBroker[event.SMTPSession, event.SMTPResponse] + AfterMessageDeleted AsyncEventBroker[event.MessageMetadata] + AfterMessageStored AsyncEventBroker[event.MessageMetadata] + BeforeMailFromAccepted EventBroker[event.SMTPSession, event.SMTPResponse] + BeforeMessageStored EventBroker[event.InboundMessage, event.InboundMessage] + BeforeRcptToAccepted EventBroker[event.SMTPSession, event.SMTPResponse] } // Void indicates the event emitter will ignore any value returned by listeners. diff --git a/pkg/extension/luahost/bind_inbucket.go b/pkg/extension/luahost/bind_inbucket.go index e01cfe0..64fa18e 100644 --- a/pkg/extension/luahost/bind_inbucket.go +++ b/pkg/extension/luahost/bind_inbucket.go @@ -29,9 +29,9 @@ type InbucketAfterFuncs struct { // InbucketBeforeFuncs holds references to Lua extension functions to be called // before Inbucket handles an event. type InbucketBeforeFuncs struct { - MailAccepted *lua.LFunction - MessageStored *lua.LFunction - RcptToAccepted *lua.LFunction + MailFromAccepted *lua.LFunction + MessageStored *lua.LFunction + RcptToAccepted *lua.LFunction } func registerInbucketTypes(ls *lua.LState) { @@ -186,8 +186,8 @@ func inbucketBeforeIndex(ls *lua.LState) int { // Push the requested field's value onto the stack. switch field { - case "mail_accepted": - ls.Push(funcOrNil(before.MailAccepted)) + case "mail_from_accepted": + ls.Push(funcOrNil(before.MailFromAccepted)) case "message_stored": ls.Push(funcOrNil(before.MessageStored)) case "rcpt_to_accepted": @@ -206,8 +206,8 @@ func inbucketBeforeNewIndex(ls *lua.LState) int { index := ls.CheckString(2) switch index { - case "mail_accepted": - m.MailAccepted = ls.CheckFunction(3) + case "mail_from_accepted": + m.MailFromAccepted = ls.CheckFunction(3) case "message_stored": m.MessageStored = ls.CheckFunction(3) case "rcpt_to_accepted": diff --git a/pkg/extension/luahost/bind_inbucket_test.go b/pkg/extension/luahost/bind_inbucket_test.go index 5214034..714196f 100644 --- a/pkg/extension/luahost/bind_inbucket_test.go +++ b/pkg/extension/luahost/bind_inbucket_test.go @@ -62,7 +62,7 @@ func TestInbucketBeforeFuncs(t *testing.T) { assert(inbucket, "inbucket should not be nil") assert(inbucket.before, "inbucket.before should not be nil") - local fns = { "mail_accepted", "message_stored" } + local fns = { "mail_from_accepted", "message_stored", "rcpt_to_accepted" } -- Verify functions start off nil. for i, name in ipairs(fns) do diff --git a/pkg/extension/luahost/lua.go b/pkg/extension/luahost/lua.go index d0db001..ad46273 100644 --- a/pkg/extension/luahost/lua.go +++ b/pkg/extension/luahost/lua.go @@ -106,8 +106,8 @@ func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) { if ib.After.MessageStored != nil { events.AfterMessageStored.AddListener(listenerName, h.handleAfterMessageStored) } - if ib.Before.MailAccepted != nil { - events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted) + if ib.Before.MailFromAccepted != nil { + events.BeforeMailFromAccepted.AddListener(listenerName, h.handleBeforeMailFromAccepted) } if ib.Before.MessageStored != nil { events.BeforeMessageStored.AddListener(listenerName, h.handleBeforeMessageStored) @@ -151,18 +151,17 @@ func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) { } } -func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *event.SMTPResponse { - logger, ls, ib, ok := h.prepareInbucketFuncCall("before.mail_accepted") +func (h *Host) handleBeforeMailFromAccepted(session event.SMTPSession) *event.SMTPResponse { + logger, ls, ib, ok := h.prepareInbucketFuncCall("before.mail_from_accepted") if !ok { return nil } defer h.pool.putState(ls) - logger.Debug().Msgf("Calling Lua function with %+v", addr) + logger.Debug().Msgf("Calling Lua function with %+v", session) if err := ls.CallByParam( - lua.P{Fn: ib.Before.MailAccepted, NRet: 1, Protect: true}, - lua.LString(addr.Local), - lua.LString(addr.Domain), + lua.P{Fn: ib.Before.MailFromAccepted, NRet: 1, Protect: true}, + wrapSMTPSession(ls, &session), ); err != nil { logger.Error().Err(err).Msg("Failed to call Lua function") return nil diff --git a/pkg/extension/luahost/lua_test.go b/pkg/extension/luahost/lua_test.go index 84cd104..f16d452 100644 --- a/pkg/extension/luahost/lua_test.go +++ b/pkg/extension/luahost/lua_test.go @@ -105,11 +105,11 @@ func TestAfterMessageStored(t *testing.T) { test.AssertNotified(t, notify) } -func TestBeforeMailAccepted(t *testing.T) { +func TestBeforeMailFromAccepted(t *testing.T) { // Register lua event listener. script := ` - function inbucket.before.mail_accepted(localpart, domain) - if localpart == "from" and domain == "test" then + function inbucket.before.mail_from_accepted(session) + if session.from.address == "from@example.com" then logger.info("allowing message", {}) return smtp.allow() else @@ -123,22 +123,30 @@ func TestBeforeMailAccepted(t *testing.T) { consoleLogger, extHost, strings.NewReader(test.LuaInit+script), "test.lua") require.NoError(t, err) - // Send event to be accepted. - addr := &event.AddressParts{Local: "from", Domain: "test"} - got := extHost.Events.BeforeMailAccepted.Emit(addr) - want := event.ActionAllow - require.NotNil(t, got, "Expected result from Emit()") - if got.Action != want { - t.Errorf("Got %v, wanted %v for addr %v", got.Action, want, addr) + { + // Send event to be accepted. + session := event.SMTPSession{ + From: &mail.Address{Name: "", Address: "from@example.com"}, + } + got := extHost.Events.BeforeMailFromAccepted.Emit(&session) + want := event.ActionAllow + require.NotNil(t, got, "Expected result from Emit()") + if got.Action != want { + t.Errorf("Got %v, wanted %v for addr %v", got.Action, want, session.From) + } } - // Send event to be denied. - addr = &event.AddressParts{Local: "reject", Domain: "me"} - got = extHost.Events.BeforeMailAccepted.Emit(addr) - want = event.ActionDeny - require.NotNil(t, got, "Expected result from Emit()") - if got.Action != want { - t.Errorf("Got %v, wanted %v for addr %v", got.Action, want, addr) + { + // Send event to be denied. + session := event.SMTPSession{ + From: &mail.Address{Name: "", Address: "from@reject.com"}, + } + got := extHost.Events.BeforeMailFromAccepted.Emit(&session) + want := event.ActionDeny + require.NotNil(t, got, "Expected result from Emit()") + if got.Action != want { + t.Errorf("Got %v, wanted %v for addr %v", got.Action, want, session.From) + } } } diff --git a/pkg/server/smtp/handler.go b/pkg/server/smtp/handler.go index ecbc62a..da35965 100644 --- a/pkg/server/smtp/handler.go +++ b/pkg/server/smtp/handler.go @@ -408,7 +408,7 @@ func (s *Session) parseMailFromCmd(arg string) { s.logger.Debug().Msgf("Mail sender is %v", from) // Parse from address. - localpart, domain, err := policy.ParseEmailAddress(from) + _, domain, err := policy.ParseEmailAddress(from) s.logger.Debug().Msgf("Origin domain is %v", domain) if from != "" && err != nil { s.send("501 Bad sender address syntax") @@ -446,10 +446,17 @@ func (s *Session) parseMailFromCmd(arg string) { } } + // Parse origin (from) address. + origin, err := s.addrPolicy.ParseOrigin(from) + if err != nil { + s.send("501 Bad origin address syntax") + s.logger.Warn().Str("from", from).Err(err).Msg("Bad address as MAIL arg") + return + } + // Process through extensions. extAction := event.ActionDefer - extResult := s.extHost.Events.BeforeMailAccepted.Emit( - &event.AddressParts{Local: localpart, Domain: domain}) + extResult := s.extHost.Events.BeforeMailFromAccepted.Emit(&event.SMTPSession{From: &origin.Address}) if extResult != nil { extAction = extResult.Action } @@ -460,12 +467,6 @@ func (s *Session) parseMailFromCmd(arg string) { } // Sender was permitted by an extension, or no extension rejected it. - origin, err := s.addrPolicy.ParseOrigin(from) - if err != nil { - s.send("501 Bad origin address syntax") - s.logger.Warn().Str("from", from).Err(err).Msg("Bad address as MAIL arg") - return - } s.from = origin // Ignore ShouldAccept if extensions explicitly allowed this From. if extAction == event.ActionDefer && !s.from.ShouldAccept() { diff --git a/pkg/server/smtp/handler_test.go b/pkg/server/smtp/handler_test.go index e34f3d7..813d208 100644 --- a/pkg/server/smtp/handler_test.go +++ b/pkg/server/smtp/handler_test.go @@ -385,17 +385,17 @@ Hi! _, _, _ = c.ReadCodeLine(221) } -// Tests "MAIL FROM" emits BeforeMailAccepted event. -func TestBeforeMailAcceptedEventEmitted(t *testing.T) { +// Tests "MAIL FROM" emits BeforeMailFromAccepted event. +func TestBeforeMailFromAcceptedEventEmitted(t *testing.T) { ds := test.NewStore() extHost := extension.NewHost() server := setupSMTPServer(ds, extHost) - var got *event.AddressParts - extHost.Events.BeforeMailAccepted.AddListener( + var got *event.SMTPSession + extHost.Events.BeforeMailFromAccepted.AddListener( "test", - func(addr event.AddressParts) *event.SMTPResponse { - got = &addr + func(session event.SMTPSession) *event.SMTPResponse { + got = &session return &event.SMTPResponse{Action: event.ActionDefer} }) @@ -407,22 +407,22 @@ func TestBeforeMailAcceptedEventEmitted(t *testing.T) { playSession(t, server, script) assert.NotNil(t, got, "BeforeMailListener did not receive Address") - assert.Equal(t, "john", got.Local, "Address local part had wrong value") - assert.Equal(t, "gmail.com", got.Domain, "Address domain part had wrong value") + assert.Equal(t, "john@gmail.com", got.From.Address, "Address had wrong value") } -// Test "MAIL FROM" acts on BeforeMailAccepted event result. -func TestBeforeMailAcceptedEventResponse(t *testing.T) { +// Test "MAIL FROM" acts on BeforeMailFromAccepted event result. +func TestBeforeMailFromAcceptedEventResponse(t *testing.T) { ds := test.NewStore() extHost := extension.NewHost() server := setupSMTPServer(ds, extHost) var shouldReturn *event.SMTPResponse - var gotEvent *event.AddressParts - extHost.Events.BeforeMailAccepted.AddListener( + var gotEvent *event.SMTPSession + + extHost.Events.BeforeMailFromAccepted.AddListener( "test", - func(addr event.AddressParts) *event.SMTPResponse { - gotEvent = &addr + func(session event.SMTPSession) *event.SMTPResponse { + gotEvent = &session return shouldReturn }) @@ -462,7 +462,7 @@ func TestBeforeMailAcceptedEventResponse(t *testing.T) { {"QUIT", 221}} playSession(t, server, script) - assert.NotNil(t, gotEvent, "BeforeMailListener did not receive Address") + assert.NotNil(t, gotEvent, "BeforeMailFromAccepted did not receive event") }) } }