From 4a6b727cbcbe1c563a249561ecca654a1e2478df Mon Sep 17 00:00:00 2001 From: James Hillyerd Date: Mon, 6 Nov 2023 18:10:02 -0800 Subject: [PATCH] lua: bind BeforeMessageStored function (#418) * lua: Restore missing test log output * lua: Use logger for test assert_async output * lua: add InboundMessage bindings * lua: bind BeforeMessageStored function --------- Signed-off-by: James Hillyerd --- pkg/extension/luahost/bind_inboundmessage.go | 134 ++++++++++++++++++ .../luahost/bind_inboundmessage_test.go | 93 ++++++++++++ pkg/extension/luahost/bind_inbucket.go | 7 +- pkg/extension/luahost/bind_inbucket_test.go | 2 +- pkg/extension/luahost/lua.go | 35 +++++ pkg/extension/luahost/lua_test.go | 92 ++++++++++-- pkg/extension/luahost/pool.go | 5 +- 7 files changed, 355 insertions(+), 13 deletions(-) create mode 100644 pkg/extension/luahost/bind_inboundmessage.go create mode 100644 pkg/extension/luahost/bind_inboundmessage_test.go diff --git a/pkg/extension/luahost/bind_inboundmessage.go b/pkg/extension/luahost/bind_inboundmessage.go new file mode 100644 index 0000000..12dc6cd --- /dev/null +++ b/pkg/extension/luahost/bind_inboundmessage.go @@ -0,0 +1,134 @@ +package luahost + +import ( + "fmt" + "net/mail" + + "github.com/inbucket/inbucket/v3/pkg/extension/event" + lua "github.com/yuin/gopher-lua" +) + +const inboundMessageName = "inbound_message" + +func registerInboundMessageType(ls *lua.LState) { + mt := ls.NewTypeMetatable(inboundMessageName) + ls.SetGlobal(inboundMessageName, mt) + + // Static attributes. + ls.SetField(mt, "new", ls.NewFunction(newInboundMessage)) + + // Methods. + ls.SetField(mt, "__index", ls.NewFunction(inboundMessageIndex)) + ls.SetField(mt, "__newindex", ls.NewFunction(inboundMessageNewIndex)) +} + +func newInboundMessage(ls *lua.LState) int { + val := &event.InboundMessage{} + ud := wrapInboundMessage(ls, val) + ls.Push(ud) + + return 1 +} + +func wrapInboundMessage(ls *lua.LState, val *event.InboundMessage) *lua.LUserData { + ud := ls.NewUserData() + ud.Value = val + ls.SetMetatable(ud, ls.GetTypeMetatable(inboundMessageName)) + + return ud +} + +// Checks there is an InboundMessage at stack position `pos`, else throws Lua error. +func checkInboundMessage(ls *lua.LState, pos int) *event.InboundMessage { + ud := ls.CheckUserData(pos) + if v, ok := ud.Value.(*event.InboundMessage); ok { + return v + } + ls.ArgError(pos, inboundMessageName+" expected") + return nil +} + +func unwrapInboundMessage(lv lua.LValue) (*event.InboundMessage, error) { + if ud, ok := lv.(*lua.LUserData); ok { + if v, ok := ud.Value.(*event.InboundMessage); ok { + return v, nil + } + } + + return nil, fmt.Errorf("Expected InboundMessage, got %q", lv.Type().String()) +} + +// Gets a field value from InboundMessage user object. This emulates a Lua table, +// allowing `msg.subject` instead of a Lua object syntax of `msg:subject()`. +func inboundMessageIndex(ls *lua.LState) int { + m := checkInboundMessage(ls, 1) + field := ls.CheckString(2) + + // Push the requested field's value onto the stack. + switch field { + case "mailboxes": + lt := &lua.LTable{} + for _, v := range m.Mailboxes { + lt.Append(lua.LString(v)) + } + ls.Push(lt) + case "from": + ls.Push(wrapMailAddress(ls, &m.From)) + case "to": + lt := &lua.LTable{} + for _, v := range m.To { + addr := v + lt.Append(wrapMailAddress(ls, &addr)) + } + ls.Push(lt) + case "subject": + ls.Push(lua.LString(m.Subject)) + case "size": + ls.Push(lua.LNumber(m.Size)) + default: + // Unknown field. + ls.Push(lua.LNil) + } + + return 1 +} + +// Sets a field value on InboundMessage user object. This emulates a Lua table, +// allowing `msg.subject = x` instead of a Lua object syntax of `msg:subject(x)`. +func inboundMessageNewIndex(ls *lua.LState) int { + m := checkInboundMessage(ls, 1) + index := ls.CheckString(2) + + switch index { + case "mailboxes": + lt := ls.CheckTable(3) + mailboxes := make([]string, 0, 16) + lt.ForEach(func(k, lv lua.LValue) { + if mb, ok := lv.(lua.LString); ok { + mailboxes = append(mailboxes, string(mb)) + } + }) + m.Mailboxes = mailboxes + case "from": + m.From = *checkMailAddress(ls, 3) + case "to": + lt := ls.CheckTable(3) + to := make([]mail.Address, 0, 16) + lt.ForEach(func(k, lv lua.LValue) { + if ud, ok := lv.(*lua.LUserData); ok { + if entry, ok := unwrapMailAddress(ud); ok { + to = append(to, *entry) + } + } + }) + m.To = to + case "subject": + m.Subject = ls.CheckString(3) + case "size": + ls.RaiseError("size is read-only") + default: + ls.RaiseError("invalid index %q", index) + } + + return 0 +} diff --git a/pkg/extension/luahost/bind_inboundmessage_test.go b/pkg/extension/luahost/bind_inboundmessage_test.go new file mode 100644 index 0000000..715f4b3 --- /dev/null +++ b/pkg/extension/luahost/bind_inboundmessage_test.go @@ -0,0 +1,93 @@ +package luahost + +import ( + "net/mail" + "testing" + + "github.com/inbucket/inbucket/v3/pkg/extension/event" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + lua "github.com/yuin/gopher-lua" +) + +// LuaInit holds useful test globals. +const LuaInit = ` + function assert_eq(got, want) + if type(got) == "table" and type(want) == "table" then + assert(#got == #want, string.format("got %d element(s), wanted %d", #got, #want)) + + for i, gotv in ipairs(got) do + local wantv = want[i] + assert_eq(gotv, wantv, "got[%d] = %q, wanted %q", gotv, wantv) + end + + return + end + + assert(got == want, string.format("got %q, wanted %q", got, want)) + end +` + +func TestInboundMessageGetters(t *testing.T) { + want := &event.InboundMessage{ + Mailboxes: []string{"mb1", "mb2"}, + From: mail.Address{Name: "name1", Address: "addr1"}, + To: []mail.Address{ + {Name: "name2", Address: "addr2"}, + {Name: "name3", Address: "addr3"}, + }, + Subject: "subj1", + Size: 42, + } + script := ` + assert(msg, "msg should not be nil") + + assert_eq(msg.mailboxes, {"mb1", "mb2"}) + assert_eq(msg.subject, "subj1") + assert_eq(msg.size, 42) + + assert_eq(msg.from.name, "name1") + assert_eq(msg.from.address, "addr1") + + assert_eq(#msg.to, 2) + assert_eq(msg.to[1].name, "name2") + assert_eq(msg.to[1].address, "addr2") + assert_eq(msg.to[2].name, "name3") + assert_eq(msg.to[2].address, "addr3") + ` + + ls := lua.NewState() + registerInboundMessageType(ls) + registerMailAddressType(ls) + ls.SetGlobal("msg", wrapInboundMessage(ls, want)) + require.NoError(t, ls.DoString(LuaInit+script)) +} + +func TestInboundMessageSetters(t *testing.T) { + want := &event.InboundMessage{ + Mailboxes: []string{"mb1", "mb2"}, + From: mail.Address{Name: "name1", Address: "addr1"}, + To: []mail.Address{ + {Name: "name2", Address: "addr2"}, + {Name: "name3", Address: "addr3"}, + }, + Subject: "subj1", + } + script := ` + assert(msg, "msg should not be nil") + + msg.mailboxes = {"mb1", "mb2"} + msg.subject = "subj1" + msg.from = address.new("name1", "addr1") + msg.to = { address.new("name2", "addr2"), address.new("name3", "addr3") } + ` + + got := &event.InboundMessage{} + ls := lua.NewState() + registerInboundMessageType(ls) + registerMailAddressType(ls) + ls.SetGlobal("msg", wrapInboundMessage(ls, got)) + require.NoError(t, ls.DoString(script)) + + assert.Equal(t, want, got) +} diff --git a/pkg/extension/luahost/bind_inbucket.go b/pkg/extension/luahost/bind_inbucket.go index e2273d7..c641051 100644 --- a/pkg/extension/luahost/bind_inbucket.go +++ b/pkg/extension/luahost/bind_inbucket.go @@ -29,7 +29,8 @@ type InbucketAfterFuncs struct { // InbucketBeforeFuncs holds references to Lua extension functions to be called // before Inbucket handles an event. type InbucketBeforeFuncs struct { - MailAccepted *lua.LFunction + MailAccepted *lua.LFunction + MessageStored *lua.LFunction } func registerInbucketTypes(ls *lua.LState) { @@ -186,6 +187,8 @@ func inbucketBeforeIndex(ls *lua.LState) int { switch field { case "mail_accepted": ls.Push(funcOrNil(before.MailAccepted)) + case "message_stored": + ls.Push(funcOrNil(before.MessageStored)) default: // Unknown field. ls.Push(lua.LNil) @@ -202,6 +205,8 @@ func inbucketBeforeNewIndex(ls *lua.LState) int { switch index { case "mail_accepted": m.MailAccepted = ls.CheckFunction(3) + case "message_stored": + m.MessageStored = ls.CheckFunction(3) default: ls.RaiseError("invalid inbucket.before index %q", index) } diff --git a/pkg/extension/luahost/bind_inbucket_test.go b/pkg/extension/luahost/bind_inbucket_test.go index 6f7f007..ac391a5 100644 --- a/pkg/extension/luahost/bind_inbucket_test.go +++ b/pkg/extension/luahost/bind_inbucket_test.go @@ -62,7 +62,7 @@ func TestInbucketBeforeFuncs(t *testing.T) { assert(inbucket, "inbucket should not be nil") assert(inbucket.before, "inbucket.before should not be nil") - local fns = { "mail_accepted" } + local fns = { "mail_accepted", "message_stored" } -- Verify functions start off nil. for i, name in ipairs(fns) do diff --git a/pkg/extension/luahost/lua.go b/pkg/extension/luahost/lua.go index cf9615a..c8cac20 100644 --- a/pkg/extension/luahost/lua.go +++ b/pkg/extension/luahost/lua.go @@ -105,6 +105,9 @@ func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) { if ib.Before.MailAccepted != nil { events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted) } + if ib.Before.MessageStored != nil { + events.BeforeMessageStored.AddListener(listenerName, h.handleBeforeMessageStored) + } } func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) { @@ -174,6 +177,38 @@ func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool { return &result } +func (h *Host) handleBeforeMessageStored(msg event.InboundMessage) *event.InboundMessage { + logger, ls, ib, ok := h.prepareInbucketFuncCall("before.message_stored") + if !ok { + return nil + } + defer h.pool.putState(ls) + + logger.Debug().Msgf("Calling Lua function with %+v", msg) + if err := ls.CallByParam( + lua.P{Fn: ib.Before.MessageStored, NRet: 1, Protect: true}, + wrapInboundMessage(ls, &msg), + ); err != nil { + logger.Error().Err(err).Msg("Failed to call Lua function") + return nil + } + + lval := ls.Get(1) + logger.Debug().Msgf("Lua function returned %q (%v)", lval, lval.Type().String()) + + if lval.Type() == lua.LTNil || lua.LVIsFalse(lval) { + return nil + } + + result, err := unwrapInboundMessage(lval) + if err != nil { + logger.Error().Err(err).Msg("Bad response from Lua Function") + } + ls.Pop(1) + + return result +} + // Common preparation for calling Lua functions. func (h *Host) prepareInbucketFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, ib *Inbucket, ok bool) { logger = h.logContext.Str("event", funcName).Logger() diff --git a/pkg/extension/luahost/lua_test.go b/pkg/extension/luahost/lua_test.go index bbef805..c72ddd0 100644 --- a/pkg/extension/luahost/lua_test.go +++ b/pkg/extension/luahost/lua_test.go @@ -17,14 +17,16 @@ import ( // LuaInit holds useful test globals. const LuaInit = ` + local logger = require("logger") + async = false test_ok = true - -- Sends marks tests failed instead of erroring when enabled. + -- Sends marks tests as failed instead of erroring when enabled. function assert_async(value, message) if not value then if async then - print(message) + logger.error(message, {from = "assert_async"}) test_ok = false else error(message) @@ -32,7 +34,7 @@ const LuaInit = ` end end - -- Tests plain values and list-style tables. + -- Verifies plain values and list-style tables. function assert_eq(got, want) if type(got) == "table" and type(want) == "table" then assert_async(#got == #want, string.format("got %d elements, wanted %d", #got, #want)) @@ -48,17 +50,20 @@ const LuaInit = ` assert_async(got == want, string.format("got %q, wanted %q", got, want)) end + -- Verifies string want contains string got. function assert_contains(got, want) assert_async(string.find(got, want), string.format("got %q, wanted it to contain %q", got, want)) end ` +var consoleLogger = zerolog.New(zerolog.NewConsoleWriter()) + func TestEmptyScript(t *testing.T) { script := "" extHost := extension.NewHost() - _, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(script), "test.lua") + _, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(script), "test.lua") require.NoError(t, err) } @@ -91,7 +96,7 @@ func TestAfterMessageDeleted(t *testing.T) { end ` extHost := extension.NewHost() - luaHost, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(LuaInit+script), "test.lua") + luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua") require.NoError(t, err) notify := luaHost.CreateChannel("notify") @@ -122,7 +127,7 @@ func TestAfterMessageStored(t *testing.T) { end ` extHost := extension.NewHost() - luaHost, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(LuaInit+script), "test.lua") + luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua") require.NoError(t, err) notify := luaHost.CreateChannel("notify") @@ -148,14 +153,14 @@ func TestBeforeMailAccepted(t *testing.T) { end ` extHost := extension.NewHost() - _, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(script), "test.lua") + _, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(script), "test.lua") require.NoError(t, err) // Send event to be accepted. addr := &event.AddressParts{Local: "from", Domain: "test"} got := extHost.Events.BeforeMailAccepted.Emit(addr) want := true - require.NotNil(t, got) + require.NotNil(t, got, "Expected result from Emit()") if *got != want { t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr) } @@ -164,12 +169,81 @@ func TestBeforeMailAccepted(t *testing.T) { addr = &event.AddressParts{Local: "reject", Domain: "me"} got = extHost.Events.BeforeMailAccepted.Emit(addr) want = false - require.NotNil(t, got) + require.NotNil(t, got, "Expected result from Emit()") if *got != want { t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr) } } +func TestBeforeMessageStored(t *testing.T) { + // Event to send. + msg := event.InboundMessage{ + Mailboxes: []string{"one", "two"}, + From: mail.Address{Name: "From Name", Address: "from@example.com"}, + To: []mail.Address{ + {Name: "To1 Name", Address: "to1@example.com"}, + {Name: "To2 Name", Address: "to2@example.com"}, + }, + Subject: "inbound subj", + Size: 42, + } + + // Register lua event listener. + script := ` + async = true + + function inbucket.before.message_stored(msg) + -- Verify incoming values. + assert_eq(msg.mailboxes, {"one", "two"}) + assert_eq(msg.from.name, "From Name") + assert_eq(msg.from.address, "from@example.com") + assert_eq(2, #msg.to) + assert_eq(msg.to[1].name, "To1 Name") + assert_eq(msg.to[1].address, "to1@example.com") + assert_eq(msg.to[2].name, "To2 Name") + assert_eq(msg.to[2].address, "to2@example.com") + assert_eq(msg.subject, "inbound subj") + assert_eq(msg.size, 42) + notify:send(test_ok) + + -- Generate response. + res = inbound_message.new() + res.mailboxes = {"resone", "restwo"} + res.from = address.new("Res From", "res@example.com") + res.to = { + address.new("To1 Res", "res1@example.com"), + address.new("To2 Res", "res2@example.com"), + } + res.subject = "res subj" + return res + end + ` + extHost := extension.NewHost() + luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua") + require.NoError(t, err) + notify := luaHost.CreateChannel("notify") + + // Send event to be accepted. + got := extHost.Events.BeforeMessageStored.Emit(&msg) + require.NotNil(t, got, "Expected result from Emit()") + + // Verify Lua assertions passed. + assertNotified(t, notify) + + // Verify response values. + want := &event.InboundMessage{ + Mailboxes: []string{"resone", "restwo"}, + From: mail.Address{Name: "Res From", Address: "res@example.com"}, + To: []mail.Address{ + {Name: "To1 Res", Address: "res1@example.com"}, + {Name: "To2 Res", Address: "res2@example.com"}, + }, + Subject: "res subj", + Size: 0, + } + assert.Equal(t, want, got, "Response InboundMessage did not match") +} + func assertNotified(t *testing.T, notify chan lua.LValue) { t.Helper() select { diff --git a/pkg/extension/luahost/pool.go b/pkg/extension/luahost/pool.go index b89b250..65a379d 100644 --- a/pkg/extension/luahost/pool.go +++ b/pkg/extension/luahost/pool.go @@ -6,7 +6,7 @@ import ( "github.com/cjoudrey/gluahttp" "github.com/cosmotek/loguago" - "github.com/inbucket/gopher-json" + json "github.com/inbucket/gopher-json" "github.com/rs/zerolog" lua "github.com/yuin/gopher-lua" ) @@ -44,9 +44,10 @@ func (lp *statePool) newState() (*lua.LState, error) { } // Register custom types. + registerInboundMessageType(ls) registerInbucketTypes(ls) - registerMessageMetadataType(ls) registerMailAddressType(ls) + registerMessageMetadataType(ls) registerPolicyType(ls) // Run compiled script.