From 15d1970dbecd68396065adbbbcc8dc1680fd998c Mon Sep 17 00:00:00 2001 From: James Hillyerd Date: Sun, 20 Oct 2024 11:49:55 -0700 Subject: [PATCH] feat: Add RemoteAddr to SMTPSession (#548) Signed-off-by: James Hillyerd --- pkg/extension/event/events.go | 5 ++-- pkg/extension/luahost/bind_smtpsession.go | 8 +++--- .../luahost/bind_smtpsession_test.go | 21 ++++++++------- pkg/server/smtp/handler.go | 26 ++++++++++++++----- pkg/server/smtp/handler_test.go | 2 ++ 5 files changed, 42 insertions(+), 20 deletions(-) diff --git a/pkg/extension/event/events.go b/pkg/extension/event/events.go index 6b02803..52a07aa 100644 --- a/pkg/extension/event/events.go +++ b/pkg/extension/event/events.go @@ -50,6 +50,7 @@ type SMTPResponse struct { // SMTPSession captures SMTP `MAIL FROM` & `RCPT TO` values prior to mail DATA being received. type SMTPSession struct { - From *mail.Address - To []*mail.Address + From *mail.Address + To []*mail.Address + RemoteAddr string } diff --git a/pkg/extension/luahost/bind_smtpsession.go b/pkg/extension/luahost/bind_smtpsession.go index b3cf504..4b56780 100644 --- a/pkg/extension/luahost/bind_smtpsession.go +++ b/pkg/extension/luahost/bind_smtpsession.go @@ -47,20 +47,22 @@ func checkSMTPSession(ls *lua.LState, pos int) *event.SMTPSession { // Gets a field value from SMTPSession user object. This emulates a Lua table, // allowing `msg.subject` instead of a Lua object syntax of `msg:subject()`. func smtpSessionIndex(ls *lua.LState) int { - m := checkSMTPSession(ls, 1) + session := checkSMTPSession(ls, 1) field := ls.CheckString(2) // Push the requested field's value onto the stack. switch field { case "from": - ls.Push(wrapMailAddress(ls, m.From)) + ls.Push(wrapMailAddress(ls, session.From)) case "to": lt := &lua.LTable{} - for _, v := range m.To { + for _, v := range session.To { addr := v lt.Append(wrapMailAddress(ls, addr)) } ls.Push(lt) + case "remote_addr": + ls.Push(lua.LString(session.RemoteAddr)) default: // Unknown field. ls.Push(lua.LNil) diff --git a/pkg/extension/luahost/bind_smtpsession_test.go b/pkg/extension/luahost/bind_smtpsession_test.go index b495428..21b00f4 100644 --- a/pkg/extension/luahost/bind_smtpsession_test.go +++ b/pkg/extension/luahost/bind_smtpsession_test.go @@ -16,23 +16,26 @@ func TestSMTPSessionGetters(t *testing.T) { {Name: "name2", Address: "addr2"}, {Name: "name3", Address: "addr3"}, }, + RemoteAddr: "1.2.3.4", } script := ` - assert(msg, "msg should not be nil") + assert(session, "session should not be nil") - assert_eq(msg.from.name, "name1", "from.name") - assert_eq(msg.from.address, "addr1", "from.address") + assert_eq(session.from.name, "name1", "from.name") + assert_eq(session.from.address, "addr1", "from.address") - assert_eq(#msg.to, 2, "#msg.to") - assert_eq(msg.to[1].name, "name2", "to[1].name") - assert_eq(msg.to[1].address, "addr2", "to[1].address") - assert_eq(msg.to[2].name, "name3", "to[2].name") - assert_eq(msg.to[2].address, "addr3", "to[2].address") + assert_eq(#session.to, 2, "#session.to") + assert_eq(session.to[1].name, "name2", "to[1].name") + assert_eq(session.to[1].address, "addr2", "to[1].address") + assert_eq(session.to[2].name, "name3", "to[2].name") + assert_eq(session.to[2].address, "addr3", "to[2].address") + + assert_eq(session.remote_addr, "1.2.3.4") ` ls, _ := test.NewLuaState() registerSMTPSessionType(ls) registerMailAddressType(ls) - ls.SetGlobal("msg", wrapSMTPSession(ls, want)) + ls.SetGlobal("session", wrapSMTPSession(ls, want)) require.NoError(t, ls.DoString(script)) } diff --git a/pkg/server/smtp/handler.go b/pkg/server/smtp/handler.go index da35965..8643a33 100644 --- a/pkg/server/smtp/handler.go +++ b/pkg/server/smtp/handler.go @@ -116,7 +116,11 @@ type Session struct { // NewSession creates a new Session for the given connection func NewSession(server *Server, id int, conn net.Conn, logger zerolog.Logger) *Session { reader := bufio.NewReader(conn) - host, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) + + remoteHost := conn.RemoteAddr().String() + if host, _, err := net.SplitHostPort(remoteHost); err == nil { + remoteHost = host + } session := &Session{ Server: server, @@ -124,7 +128,7 @@ func NewSession(server *Server, id int, conn net.Conn, logger zerolog.Logger) *S conn: conn, state: GREET, reader: reader, - remoteHost: host, + remoteHost: remoteHost, recipients: make([]*policy.Recipient, 0), logger: logger, debug: server.config.Debug, @@ -454,9 +458,14 @@ func (s *Session) parseMailFromCmd(arg string) { return } + // Add from to extSession for inspection. + extSession := s.extSession() + addrCopy := origin.Address + extSession.From = &addrCopy + // Process through extensions. extAction := event.ActionDefer - extResult := s.extHost.Events.BeforeMailFromAccepted.Emit(&event.SMTPSession{From: &origin.Address}) + extResult := s.extHost.Events.BeforeMailFromAccepted.Emit(extSession) if extResult != nil { extAction = extResult.Action } @@ -711,7 +720,11 @@ func (s *Session) ooSeq(cmd string) { // extSession builds an SMTPSession for extensions. func (s *Session) extSession() *event.SMTPSession { - from := s.from.Address + var from *mail.Address + if s.from != nil { + addr := s.from.Address + from = &addr + } to := make([]*mail.Address, 0, len(s.recipients)) for _, recip := range s.recipients { addr := recip.Address @@ -719,7 +732,8 @@ func (s *Session) extSession() *event.SMTPSession { } return &event.SMTPSession{ - From: &from, - To: to, + From: from, + To: to, + RemoteAddr: s.remoteHost, } } diff --git a/pkg/server/smtp/handler_test.go b/pkg/server/smtp/handler_test.go index 813d208..8658701 100644 --- a/pkg/server/smtp/handler_test.go +++ b/pkg/server/smtp/handler_test.go @@ -408,6 +408,7 @@ func TestBeforeMailFromAcceptedEventEmitted(t *testing.T) { assert.NotNil(t, got, "BeforeMailListener did not receive Address") assert.Equal(t, "john@gmail.com", got.From.Address, "Address had wrong value") + assert.Equal(t, "pipe", got.RemoteAddr, "RemoteAddr had wrong value") } // Test "MAIL FROM" acts on BeforeMailFromAccepted event result. @@ -492,6 +493,7 @@ func TestBeforeRcptToAcceptedSingleEventEmitted(t *testing.T) { require.NotNil(t, got, "BeforeRcptToListener did not receive SMTPSession") require.NotNil(t, got.From) require.NotNil(t, got.To) + assert.Equal(t, "pipe", got.RemoteAddr, "RemoteAddr had wrong value") assert.Equal(t, "john@gmail.com", got.From.Address) assert.Len(t, got.To, 1) assert.Equal(t, "user@gmail.com", got.To[0].Address)