1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-18 18:17:03 +00:00

feat: Add RemoteAddr to SMTPSession (#548)

Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
James Hillyerd
2024-10-20 11:49:55 -07:00
committed by GitHub
parent 78d4c4f4e7
commit 15d1970dbe
5 changed files with 42 additions and 20 deletions

View File

@@ -50,6 +50,7 @@ type SMTPResponse struct {
// SMTPSession captures SMTP `MAIL FROM` & `RCPT TO` values prior to mail DATA being received. // SMTPSession captures SMTP `MAIL FROM` & `RCPT TO` values prior to mail DATA being received.
type SMTPSession struct { type SMTPSession struct {
From *mail.Address From *mail.Address
To []*mail.Address To []*mail.Address
RemoteAddr string
} }

View File

@@ -47,20 +47,22 @@ func checkSMTPSession(ls *lua.LState, pos int) *event.SMTPSession {
// Gets a field value from SMTPSession user object. This emulates a Lua table, // Gets a field value from SMTPSession user object. This emulates a Lua table,
// allowing `msg.subject` instead of a Lua object syntax of `msg:subject()`. // allowing `msg.subject` instead of a Lua object syntax of `msg:subject()`.
func smtpSessionIndex(ls *lua.LState) int { func smtpSessionIndex(ls *lua.LState) int {
m := checkSMTPSession(ls, 1) session := checkSMTPSession(ls, 1)
field := ls.CheckString(2) field := ls.CheckString(2)
// Push the requested field's value onto the stack. // Push the requested field's value onto the stack.
switch field { switch field {
case "from": case "from":
ls.Push(wrapMailAddress(ls, m.From)) ls.Push(wrapMailAddress(ls, session.From))
case "to": case "to":
lt := &lua.LTable{} lt := &lua.LTable{}
for _, v := range m.To { for _, v := range session.To {
addr := v addr := v
lt.Append(wrapMailAddress(ls, addr)) lt.Append(wrapMailAddress(ls, addr))
} }
ls.Push(lt) ls.Push(lt)
case "remote_addr":
ls.Push(lua.LString(session.RemoteAddr))
default: default:
// Unknown field. // Unknown field.
ls.Push(lua.LNil) ls.Push(lua.LNil)

View File

@@ -16,23 +16,26 @@ func TestSMTPSessionGetters(t *testing.T) {
{Name: "name2", Address: "addr2"}, {Name: "name2", Address: "addr2"},
{Name: "name3", Address: "addr3"}, {Name: "name3", Address: "addr3"},
}, },
RemoteAddr: "1.2.3.4",
} }
script := ` script := `
assert(msg, "msg should not be nil") assert(session, "session should not be nil")
assert_eq(msg.from.name, "name1", "from.name") assert_eq(session.from.name, "name1", "from.name")
assert_eq(msg.from.address, "addr1", "from.address") assert_eq(session.from.address, "addr1", "from.address")
assert_eq(#msg.to, 2, "#msg.to") assert_eq(#session.to, 2, "#session.to")
assert_eq(msg.to[1].name, "name2", "to[1].name") assert_eq(session.to[1].name, "name2", "to[1].name")
assert_eq(msg.to[1].address, "addr2", "to[1].address") assert_eq(session.to[1].address, "addr2", "to[1].address")
assert_eq(msg.to[2].name, "name3", "to[2].name") assert_eq(session.to[2].name, "name3", "to[2].name")
assert_eq(msg.to[2].address, "addr3", "to[2].address") assert_eq(session.to[2].address, "addr3", "to[2].address")
assert_eq(session.remote_addr, "1.2.3.4")
` `
ls, _ := test.NewLuaState() ls, _ := test.NewLuaState()
registerSMTPSessionType(ls) registerSMTPSessionType(ls)
registerMailAddressType(ls) registerMailAddressType(ls)
ls.SetGlobal("msg", wrapSMTPSession(ls, want)) ls.SetGlobal("session", wrapSMTPSession(ls, want))
require.NoError(t, ls.DoString(script)) require.NoError(t, ls.DoString(script))
} }

View File

@@ -116,7 +116,11 @@ type Session struct {
// NewSession creates a new Session for the given connection // NewSession creates a new Session for the given connection
func NewSession(server *Server, id int, conn net.Conn, logger zerolog.Logger) *Session { func NewSession(server *Server, id int, conn net.Conn, logger zerolog.Logger) *Session {
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
remoteHost := conn.RemoteAddr().String()
if host, _, err := net.SplitHostPort(remoteHost); err == nil {
remoteHost = host
}
session := &Session{ session := &Session{
Server: server, Server: server,
@@ -124,7 +128,7 @@ func NewSession(server *Server, id int, conn net.Conn, logger zerolog.Logger) *S
conn: conn, conn: conn,
state: GREET, state: GREET,
reader: reader, reader: reader,
remoteHost: host, remoteHost: remoteHost,
recipients: make([]*policy.Recipient, 0), recipients: make([]*policy.Recipient, 0),
logger: logger, logger: logger,
debug: server.config.Debug, debug: server.config.Debug,
@@ -454,9 +458,14 @@ func (s *Session) parseMailFromCmd(arg string) {
return return
} }
// Add from to extSession for inspection.
extSession := s.extSession()
addrCopy := origin.Address
extSession.From = &addrCopy
// Process through extensions. // Process through extensions.
extAction := event.ActionDefer extAction := event.ActionDefer
extResult := s.extHost.Events.BeforeMailFromAccepted.Emit(&event.SMTPSession{From: &origin.Address}) extResult := s.extHost.Events.BeforeMailFromAccepted.Emit(extSession)
if extResult != nil { if extResult != nil {
extAction = extResult.Action extAction = extResult.Action
} }
@@ -711,7 +720,11 @@ func (s *Session) ooSeq(cmd string) {
// extSession builds an SMTPSession for extensions. // extSession builds an SMTPSession for extensions.
func (s *Session) extSession() *event.SMTPSession { func (s *Session) extSession() *event.SMTPSession {
from := s.from.Address var from *mail.Address
if s.from != nil {
addr := s.from.Address
from = &addr
}
to := make([]*mail.Address, 0, len(s.recipients)) to := make([]*mail.Address, 0, len(s.recipients))
for _, recip := range s.recipients { for _, recip := range s.recipients {
addr := recip.Address addr := recip.Address
@@ -719,7 +732,8 @@ func (s *Session) extSession() *event.SMTPSession {
} }
return &event.SMTPSession{ return &event.SMTPSession{
From: &from, From: from,
To: to, To: to,
RemoteAddr: s.remoteHost,
} }
} }

View File

@@ -408,6 +408,7 @@ func TestBeforeMailFromAcceptedEventEmitted(t *testing.T) {
assert.NotNil(t, got, "BeforeMailListener did not receive Address") assert.NotNil(t, got, "BeforeMailListener did not receive Address")
assert.Equal(t, "john@gmail.com", got.From.Address, "Address had wrong value") assert.Equal(t, "john@gmail.com", got.From.Address, "Address had wrong value")
assert.Equal(t, "pipe", got.RemoteAddr, "RemoteAddr had wrong value")
} }
// Test "MAIL FROM" acts on BeforeMailFromAccepted event result. // Test "MAIL FROM" acts on BeforeMailFromAccepted event result.
@@ -492,6 +493,7 @@ func TestBeforeRcptToAcceptedSingleEventEmitted(t *testing.T) {
require.NotNil(t, got, "BeforeRcptToListener did not receive SMTPSession") require.NotNil(t, got, "BeforeRcptToListener did not receive SMTPSession")
require.NotNil(t, got.From) require.NotNil(t, got.From)
require.NotNil(t, got.To) require.NotNil(t, got.To)
assert.Equal(t, "pipe", got.RemoteAddr, "RemoteAddr had wrong value")
assert.Equal(t, "john@gmail.com", got.From.Address) assert.Equal(t, "john@gmail.com", got.From.Address)
assert.Len(t, got.To, 1) assert.Len(t, got.To, 1)
assert.Equal(t, "user@gmail.com", got.To[0].Address) assert.Equal(t, "user@gmail.com", got.To[0].Address)