1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-17 17:47:03 +00:00

lua: bind BeforeMessageStored function (#418)

* lua: Restore missing test log output
* lua: Use logger for test assert_async output
* lua: add InboundMessage bindings
* lua: bind BeforeMessageStored function

---------

Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
James Hillyerd
2023-11-06 18:10:02 -08:00
committed by GitHub
parent 01fb161df8
commit 4a6b727cbc
7 changed files with 355 additions and 13 deletions

View File

@@ -0,0 +1,134 @@
package luahost
import (
"fmt"
"net/mail"
"github.com/inbucket/inbucket/v3/pkg/extension/event"
lua "github.com/yuin/gopher-lua"
)
const inboundMessageName = "inbound_message"
func registerInboundMessageType(ls *lua.LState) {
mt := ls.NewTypeMetatable(inboundMessageName)
ls.SetGlobal(inboundMessageName, mt)
// Static attributes.
ls.SetField(mt, "new", ls.NewFunction(newInboundMessage))
// Methods.
ls.SetField(mt, "__index", ls.NewFunction(inboundMessageIndex))
ls.SetField(mt, "__newindex", ls.NewFunction(inboundMessageNewIndex))
}
func newInboundMessage(ls *lua.LState) int {
val := &event.InboundMessage{}
ud := wrapInboundMessage(ls, val)
ls.Push(ud)
return 1
}
func wrapInboundMessage(ls *lua.LState, val *event.InboundMessage) *lua.LUserData {
ud := ls.NewUserData()
ud.Value = val
ls.SetMetatable(ud, ls.GetTypeMetatable(inboundMessageName))
return ud
}
// Checks there is an InboundMessage at stack position `pos`, else throws Lua error.
func checkInboundMessage(ls *lua.LState, pos int) *event.InboundMessage {
ud := ls.CheckUserData(pos)
if v, ok := ud.Value.(*event.InboundMessage); ok {
return v
}
ls.ArgError(pos, inboundMessageName+" expected")
return nil
}
func unwrapInboundMessage(lv lua.LValue) (*event.InboundMessage, error) {
if ud, ok := lv.(*lua.LUserData); ok {
if v, ok := ud.Value.(*event.InboundMessage); ok {
return v, nil
}
}
return nil, fmt.Errorf("Expected InboundMessage, got %q", lv.Type().String())
}
// Gets a field value from InboundMessage user object. This emulates a Lua table,
// allowing `msg.subject` instead of a Lua object syntax of `msg:subject()`.
func inboundMessageIndex(ls *lua.LState) int {
m := checkInboundMessage(ls, 1)
field := ls.CheckString(2)
// Push the requested field's value onto the stack.
switch field {
case "mailboxes":
lt := &lua.LTable{}
for _, v := range m.Mailboxes {
lt.Append(lua.LString(v))
}
ls.Push(lt)
case "from":
ls.Push(wrapMailAddress(ls, &m.From))
case "to":
lt := &lua.LTable{}
for _, v := range m.To {
addr := v
lt.Append(wrapMailAddress(ls, &addr))
}
ls.Push(lt)
case "subject":
ls.Push(lua.LString(m.Subject))
case "size":
ls.Push(lua.LNumber(m.Size))
default:
// Unknown field.
ls.Push(lua.LNil)
}
return 1
}
// Sets a field value on InboundMessage user object. This emulates a Lua table,
// allowing `msg.subject = x` instead of a Lua object syntax of `msg:subject(x)`.
func inboundMessageNewIndex(ls *lua.LState) int {
m := checkInboundMessage(ls, 1)
index := ls.CheckString(2)
switch index {
case "mailboxes":
lt := ls.CheckTable(3)
mailboxes := make([]string, 0, 16)
lt.ForEach(func(k, lv lua.LValue) {
if mb, ok := lv.(lua.LString); ok {
mailboxes = append(mailboxes, string(mb))
}
})
m.Mailboxes = mailboxes
case "from":
m.From = *checkMailAddress(ls, 3)
case "to":
lt := ls.CheckTable(3)
to := make([]mail.Address, 0, 16)
lt.ForEach(func(k, lv lua.LValue) {
if ud, ok := lv.(*lua.LUserData); ok {
if entry, ok := unwrapMailAddress(ud); ok {
to = append(to, *entry)
}
}
})
m.To = to
case "subject":
m.Subject = ls.CheckString(3)
case "size":
ls.RaiseError("size is read-only")
default:
ls.RaiseError("invalid index %q", index)
}
return 0
}

View File

@@ -0,0 +1,93 @@
package luahost
import (
"net/mail"
"testing"
"github.com/inbucket/inbucket/v3/pkg/extension/event"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
lua "github.com/yuin/gopher-lua"
)
// LuaInit holds useful test globals.
const LuaInit = `
function assert_eq(got, want)
if type(got) == "table" and type(want) == "table" then
assert(#got == #want, string.format("got %d element(s), wanted %d", #got, #want))
for i, gotv in ipairs(got) do
local wantv = want[i]
assert_eq(gotv, wantv, "got[%d] = %q, wanted %q", gotv, wantv)
end
return
end
assert(got == want, string.format("got %q, wanted %q", got, want))
end
`
func TestInboundMessageGetters(t *testing.T) {
want := &event.InboundMessage{
Mailboxes: []string{"mb1", "mb2"},
From: mail.Address{Name: "name1", Address: "addr1"},
To: []mail.Address{
{Name: "name2", Address: "addr2"},
{Name: "name3", Address: "addr3"},
},
Subject: "subj1",
Size: 42,
}
script := `
assert(msg, "msg should not be nil")
assert_eq(msg.mailboxes, {"mb1", "mb2"})
assert_eq(msg.subject, "subj1")
assert_eq(msg.size, 42)
assert_eq(msg.from.name, "name1")
assert_eq(msg.from.address, "addr1")
assert_eq(#msg.to, 2)
assert_eq(msg.to[1].name, "name2")
assert_eq(msg.to[1].address, "addr2")
assert_eq(msg.to[2].name, "name3")
assert_eq(msg.to[2].address, "addr3")
`
ls := lua.NewState()
registerInboundMessageType(ls)
registerMailAddressType(ls)
ls.SetGlobal("msg", wrapInboundMessage(ls, want))
require.NoError(t, ls.DoString(LuaInit+script))
}
func TestInboundMessageSetters(t *testing.T) {
want := &event.InboundMessage{
Mailboxes: []string{"mb1", "mb2"},
From: mail.Address{Name: "name1", Address: "addr1"},
To: []mail.Address{
{Name: "name2", Address: "addr2"},
{Name: "name3", Address: "addr3"},
},
Subject: "subj1",
}
script := `
assert(msg, "msg should not be nil")
msg.mailboxes = {"mb1", "mb2"}
msg.subject = "subj1"
msg.from = address.new("name1", "addr1")
msg.to = { address.new("name2", "addr2"), address.new("name3", "addr3") }
`
got := &event.InboundMessage{}
ls := lua.NewState()
registerInboundMessageType(ls)
registerMailAddressType(ls)
ls.SetGlobal("msg", wrapInboundMessage(ls, got))
require.NoError(t, ls.DoString(script))
assert.Equal(t, want, got)
}

View File

@@ -29,7 +29,8 @@ type InbucketAfterFuncs struct {
// InbucketBeforeFuncs holds references to Lua extension functions to be called // InbucketBeforeFuncs holds references to Lua extension functions to be called
// before Inbucket handles an event. // before Inbucket handles an event.
type InbucketBeforeFuncs struct { type InbucketBeforeFuncs struct {
MailAccepted *lua.LFunction MailAccepted *lua.LFunction
MessageStored *lua.LFunction
} }
func registerInbucketTypes(ls *lua.LState) { func registerInbucketTypes(ls *lua.LState) {
@@ -186,6 +187,8 @@ func inbucketBeforeIndex(ls *lua.LState) int {
switch field { switch field {
case "mail_accepted": case "mail_accepted":
ls.Push(funcOrNil(before.MailAccepted)) ls.Push(funcOrNil(before.MailAccepted))
case "message_stored":
ls.Push(funcOrNil(before.MessageStored))
default: default:
// Unknown field. // Unknown field.
ls.Push(lua.LNil) ls.Push(lua.LNil)
@@ -202,6 +205,8 @@ func inbucketBeforeNewIndex(ls *lua.LState) int {
switch index { switch index {
case "mail_accepted": case "mail_accepted":
m.MailAccepted = ls.CheckFunction(3) m.MailAccepted = ls.CheckFunction(3)
case "message_stored":
m.MessageStored = ls.CheckFunction(3)
default: default:
ls.RaiseError("invalid inbucket.before index %q", index) ls.RaiseError("invalid inbucket.before index %q", index)
} }

View File

@@ -62,7 +62,7 @@ func TestInbucketBeforeFuncs(t *testing.T) {
assert(inbucket, "inbucket should not be nil") assert(inbucket, "inbucket should not be nil")
assert(inbucket.before, "inbucket.before should not be nil") assert(inbucket.before, "inbucket.before should not be nil")
local fns = { "mail_accepted" } local fns = { "mail_accepted", "message_stored" }
-- Verify functions start off nil. -- Verify functions start off nil.
for i, name in ipairs(fns) do for i, name in ipairs(fns) do

View File

@@ -105,6 +105,9 @@ func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) {
if ib.Before.MailAccepted != nil { if ib.Before.MailAccepted != nil {
events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted) events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted)
} }
if ib.Before.MessageStored != nil {
events.BeforeMessageStored.AddListener(listenerName, h.handleBeforeMessageStored)
}
} }
func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) { func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) {
@@ -174,6 +177,38 @@ func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool {
return &result return &result
} }
func (h *Host) handleBeforeMessageStored(msg event.InboundMessage) *event.InboundMessage {
logger, ls, ib, ok := h.prepareInbucketFuncCall("before.message_stored")
if !ok {
return nil
}
defer h.pool.putState(ls)
logger.Debug().Msgf("Calling Lua function with %+v", msg)
if err := ls.CallByParam(
lua.P{Fn: ib.Before.MessageStored, NRet: 1, Protect: true},
wrapInboundMessage(ls, &msg),
); err != nil {
logger.Error().Err(err).Msg("Failed to call Lua function")
return nil
}
lval := ls.Get(1)
logger.Debug().Msgf("Lua function returned %q (%v)", lval, lval.Type().String())
if lval.Type() == lua.LTNil || lua.LVIsFalse(lval) {
return nil
}
result, err := unwrapInboundMessage(lval)
if err != nil {
logger.Error().Err(err).Msg("Bad response from Lua Function")
}
ls.Pop(1)
return result
}
// Common preparation for calling Lua functions. // Common preparation for calling Lua functions.
func (h *Host) prepareInbucketFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, ib *Inbucket, ok bool) { func (h *Host) prepareInbucketFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, ib *Inbucket, ok bool) {
logger = h.logContext.Str("event", funcName).Logger() logger = h.logContext.Str("event", funcName).Logger()

View File

@@ -17,14 +17,16 @@ import (
// LuaInit holds useful test globals. // LuaInit holds useful test globals.
const LuaInit = ` const LuaInit = `
local logger = require("logger")
async = false async = false
test_ok = true test_ok = true
-- Sends marks tests failed instead of erroring when enabled. -- Sends marks tests as failed instead of erroring when enabled.
function assert_async(value, message) function assert_async(value, message)
if not value then if not value then
if async then if async then
print(message) logger.error(message, {from = "assert_async"})
test_ok = false test_ok = false
else else
error(message) error(message)
@@ -32,7 +34,7 @@ const LuaInit = `
end end
end end
-- Tests plain values and list-style tables. -- Verifies plain values and list-style tables.
function assert_eq(got, want) function assert_eq(got, want)
if type(got) == "table" and type(want) == "table" then if type(got) == "table" and type(want) == "table" then
assert_async(#got == #want, string.format("got %d elements, wanted %d", #got, #want)) assert_async(#got == #want, string.format("got %d elements, wanted %d", #got, #want))
@@ -48,17 +50,20 @@ const LuaInit = `
assert_async(got == want, string.format("got %q, wanted %q", got, want)) assert_async(got == want, string.format("got %q, wanted %q", got, want))
end end
-- Verifies string want contains string got.
function assert_contains(got, want) function assert_contains(got, want)
assert_async(string.find(got, want), assert_async(string.find(got, want),
string.format("got %q, wanted it to contain %q", got, want)) string.format("got %q, wanted it to contain %q", got, want))
end end
` `
var consoleLogger = zerolog.New(zerolog.NewConsoleWriter())
func TestEmptyScript(t *testing.T) { func TestEmptyScript(t *testing.T) {
script := "" script := ""
extHost := extension.NewHost() extHost := extension.NewHost()
_, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(script), "test.lua") _, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(script), "test.lua")
require.NoError(t, err) require.NoError(t, err)
} }
@@ -91,7 +96,7 @@ func TestAfterMessageDeleted(t *testing.T) {
end end
` `
extHost := extension.NewHost() extHost := extension.NewHost()
luaHost, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(LuaInit+script), "test.lua") luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua")
require.NoError(t, err) require.NoError(t, err)
notify := luaHost.CreateChannel("notify") notify := luaHost.CreateChannel("notify")
@@ -122,7 +127,7 @@ func TestAfterMessageStored(t *testing.T) {
end end
` `
extHost := extension.NewHost() extHost := extension.NewHost()
luaHost, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(LuaInit+script), "test.lua") luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua")
require.NoError(t, err) require.NoError(t, err)
notify := luaHost.CreateChannel("notify") notify := luaHost.CreateChannel("notify")
@@ -148,14 +153,14 @@ func TestBeforeMailAccepted(t *testing.T) {
end end
` `
extHost := extension.NewHost() extHost := extension.NewHost()
_, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(script), "test.lua") _, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(script), "test.lua")
require.NoError(t, err) require.NoError(t, err)
// Send event to be accepted. // Send event to be accepted.
addr := &event.AddressParts{Local: "from", Domain: "test"} addr := &event.AddressParts{Local: "from", Domain: "test"}
got := extHost.Events.BeforeMailAccepted.Emit(addr) got := extHost.Events.BeforeMailAccepted.Emit(addr)
want := true want := true
require.NotNil(t, got) require.NotNil(t, got, "Expected result from Emit()")
if *got != want { if *got != want {
t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr) t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr)
} }
@@ -164,12 +169,81 @@ func TestBeforeMailAccepted(t *testing.T) {
addr = &event.AddressParts{Local: "reject", Domain: "me"} addr = &event.AddressParts{Local: "reject", Domain: "me"}
got = extHost.Events.BeforeMailAccepted.Emit(addr) got = extHost.Events.BeforeMailAccepted.Emit(addr)
want = false want = false
require.NotNil(t, got) require.NotNil(t, got, "Expected result from Emit()")
if *got != want { if *got != want {
t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr) t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr)
} }
} }
func TestBeforeMessageStored(t *testing.T) {
// Event to send.
msg := event.InboundMessage{
Mailboxes: []string{"one", "two"},
From: mail.Address{Name: "From Name", Address: "from@example.com"},
To: []mail.Address{
{Name: "To1 Name", Address: "to1@example.com"},
{Name: "To2 Name", Address: "to2@example.com"},
},
Subject: "inbound subj",
Size: 42,
}
// Register lua event listener.
script := `
async = true
function inbucket.before.message_stored(msg)
-- Verify incoming values.
assert_eq(msg.mailboxes, {"one", "two"})
assert_eq(msg.from.name, "From Name")
assert_eq(msg.from.address, "from@example.com")
assert_eq(2, #msg.to)
assert_eq(msg.to[1].name, "To1 Name")
assert_eq(msg.to[1].address, "to1@example.com")
assert_eq(msg.to[2].name, "To2 Name")
assert_eq(msg.to[2].address, "to2@example.com")
assert_eq(msg.subject, "inbound subj")
assert_eq(msg.size, 42)
notify:send(test_ok)
-- Generate response.
res = inbound_message.new()
res.mailboxes = {"resone", "restwo"}
res.from = address.new("Res From", "res@example.com")
res.to = {
address.new("To1 Res", "res1@example.com"),
address.new("To2 Res", "res2@example.com"),
}
res.subject = "res subj"
return res
end
`
extHost := extension.NewHost()
luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua")
require.NoError(t, err)
notify := luaHost.CreateChannel("notify")
// Send event to be accepted.
got := extHost.Events.BeforeMessageStored.Emit(&msg)
require.NotNil(t, got, "Expected result from Emit()")
// Verify Lua assertions passed.
assertNotified(t, notify)
// Verify response values.
want := &event.InboundMessage{
Mailboxes: []string{"resone", "restwo"},
From: mail.Address{Name: "Res From", Address: "res@example.com"},
To: []mail.Address{
{Name: "To1 Res", Address: "res1@example.com"},
{Name: "To2 Res", Address: "res2@example.com"},
},
Subject: "res subj",
Size: 0,
}
assert.Equal(t, want, got, "Response InboundMessage did not match")
}
func assertNotified(t *testing.T, notify chan lua.LValue) { func assertNotified(t *testing.T, notify chan lua.LValue) {
t.Helper() t.Helper()
select { select {

View File

@@ -6,7 +6,7 @@ import (
"github.com/cjoudrey/gluahttp" "github.com/cjoudrey/gluahttp"
"github.com/cosmotek/loguago" "github.com/cosmotek/loguago"
"github.com/inbucket/gopher-json" json "github.com/inbucket/gopher-json"
"github.com/rs/zerolog" "github.com/rs/zerolog"
lua "github.com/yuin/gopher-lua" lua "github.com/yuin/gopher-lua"
) )
@@ -44,9 +44,10 @@ func (lp *statePool) newState() (*lua.LState, error) {
} }
// Register custom types. // Register custom types.
registerInboundMessageType(ls)
registerInbucketTypes(ls) registerInbucketTypes(ls)
registerMessageMetadataType(ls)
registerMailAddressType(ls) registerMailAddressType(ls)
registerMessageMetadataType(ls)
registerPolicyType(ls) registerPolicyType(ls)
// Run compiled script. // Run compiled script.