mirror of
https://github.com/jhillyerd/inbucket.git
synced 2025-12-17 09:37:02 +00:00
lua: bind BeforeMessageStored function (#418)
* lua: Restore missing test log output * lua: Use logger for test assert_async output * lua: add InboundMessage bindings * lua: bind BeforeMessageStored function --------- Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
134
pkg/extension/luahost/bind_inboundmessage.go
Normal file
134
pkg/extension/luahost/bind_inboundmessage.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package luahost
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
|
||||
"github.com/inbucket/inbucket/v3/pkg/extension/event"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const inboundMessageName = "inbound_message"
|
||||
|
||||
func registerInboundMessageType(ls *lua.LState) {
|
||||
mt := ls.NewTypeMetatable(inboundMessageName)
|
||||
ls.SetGlobal(inboundMessageName, mt)
|
||||
|
||||
// Static attributes.
|
||||
ls.SetField(mt, "new", ls.NewFunction(newInboundMessage))
|
||||
|
||||
// Methods.
|
||||
ls.SetField(mt, "__index", ls.NewFunction(inboundMessageIndex))
|
||||
ls.SetField(mt, "__newindex", ls.NewFunction(inboundMessageNewIndex))
|
||||
}
|
||||
|
||||
func newInboundMessage(ls *lua.LState) int {
|
||||
val := &event.InboundMessage{}
|
||||
ud := wrapInboundMessage(ls, val)
|
||||
ls.Push(ud)
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func wrapInboundMessage(ls *lua.LState, val *event.InboundMessage) *lua.LUserData {
|
||||
ud := ls.NewUserData()
|
||||
ud.Value = val
|
||||
ls.SetMetatable(ud, ls.GetTypeMetatable(inboundMessageName))
|
||||
|
||||
return ud
|
||||
}
|
||||
|
||||
// Checks there is an InboundMessage at stack position `pos`, else throws Lua error.
|
||||
func checkInboundMessage(ls *lua.LState, pos int) *event.InboundMessage {
|
||||
ud := ls.CheckUserData(pos)
|
||||
if v, ok := ud.Value.(*event.InboundMessage); ok {
|
||||
return v
|
||||
}
|
||||
ls.ArgError(pos, inboundMessageName+" expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func unwrapInboundMessage(lv lua.LValue) (*event.InboundMessage, error) {
|
||||
if ud, ok := lv.(*lua.LUserData); ok {
|
||||
if v, ok := ud.Value.(*event.InboundMessage); ok {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Expected InboundMessage, got %q", lv.Type().String())
|
||||
}
|
||||
|
||||
// Gets a field value from InboundMessage user object. This emulates a Lua table,
|
||||
// allowing `msg.subject` instead of a Lua object syntax of `msg:subject()`.
|
||||
func inboundMessageIndex(ls *lua.LState) int {
|
||||
m := checkInboundMessage(ls, 1)
|
||||
field := ls.CheckString(2)
|
||||
|
||||
// Push the requested field's value onto the stack.
|
||||
switch field {
|
||||
case "mailboxes":
|
||||
lt := &lua.LTable{}
|
||||
for _, v := range m.Mailboxes {
|
||||
lt.Append(lua.LString(v))
|
||||
}
|
||||
ls.Push(lt)
|
||||
case "from":
|
||||
ls.Push(wrapMailAddress(ls, &m.From))
|
||||
case "to":
|
||||
lt := &lua.LTable{}
|
||||
for _, v := range m.To {
|
||||
addr := v
|
||||
lt.Append(wrapMailAddress(ls, &addr))
|
||||
}
|
||||
ls.Push(lt)
|
||||
case "subject":
|
||||
ls.Push(lua.LString(m.Subject))
|
||||
case "size":
|
||||
ls.Push(lua.LNumber(m.Size))
|
||||
default:
|
||||
// Unknown field.
|
||||
ls.Push(lua.LNil)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// Sets a field value on InboundMessage user object. This emulates a Lua table,
|
||||
// allowing `msg.subject = x` instead of a Lua object syntax of `msg:subject(x)`.
|
||||
func inboundMessageNewIndex(ls *lua.LState) int {
|
||||
m := checkInboundMessage(ls, 1)
|
||||
index := ls.CheckString(2)
|
||||
|
||||
switch index {
|
||||
case "mailboxes":
|
||||
lt := ls.CheckTable(3)
|
||||
mailboxes := make([]string, 0, 16)
|
||||
lt.ForEach(func(k, lv lua.LValue) {
|
||||
if mb, ok := lv.(lua.LString); ok {
|
||||
mailboxes = append(mailboxes, string(mb))
|
||||
}
|
||||
})
|
||||
m.Mailboxes = mailboxes
|
||||
case "from":
|
||||
m.From = *checkMailAddress(ls, 3)
|
||||
case "to":
|
||||
lt := ls.CheckTable(3)
|
||||
to := make([]mail.Address, 0, 16)
|
||||
lt.ForEach(func(k, lv lua.LValue) {
|
||||
if ud, ok := lv.(*lua.LUserData); ok {
|
||||
if entry, ok := unwrapMailAddress(ud); ok {
|
||||
to = append(to, *entry)
|
||||
}
|
||||
}
|
||||
})
|
||||
m.To = to
|
||||
case "subject":
|
||||
m.Subject = ls.CheckString(3)
|
||||
case "size":
|
||||
ls.RaiseError("size is read-only")
|
||||
default:
|
||||
ls.RaiseError("invalid index %q", index)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
93
pkg/extension/luahost/bind_inboundmessage_test.go
Normal file
93
pkg/extension/luahost/bind_inboundmessage_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package luahost
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"testing"
|
||||
|
||||
"github.com/inbucket/inbucket/v3/pkg/extension/event"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// LuaInit holds useful test globals.
|
||||
const LuaInit = `
|
||||
function assert_eq(got, want)
|
||||
if type(got) == "table" and type(want) == "table" then
|
||||
assert(#got == #want, string.format("got %d element(s), wanted %d", #got, #want))
|
||||
|
||||
for i, gotv in ipairs(got) do
|
||||
local wantv = want[i]
|
||||
assert_eq(gotv, wantv, "got[%d] = %q, wanted %q", gotv, wantv)
|
||||
end
|
||||
|
||||
return
|
||||
end
|
||||
|
||||
assert(got == want, string.format("got %q, wanted %q", got, want))
|
||||
end
|
||||
`
|
||||
|
||||
func TestInboundMessageGetters(t *testing.T) {
|
||||
want := &event.InboundMessage{
|
||||
Mailboxes: []string{"mb1", "mb2"},
|
||||
From: mail.Address{Name: "name1", Address: "addr1"},
|
||||
To: []mail.Address{
|
||||
{Name: "name2", Address: "addr2"},
|
||||
{Name: "name3", Address: "addr3"},
|
||||
},
|
||||
Subject: "subj1",
|
||||
Size: 42,
|
||||
}
|
||||
script := `
|
||||
assert(msg, "msg should not be nil")
|
||||
|
||||
assert_eq(msg.mailboxes, {"mb1", "mb2"})
|
||||
assert_eq(msg.subject, "subj1")
|
||||
assert_eq(msg.size, 42)
|
||||
|
||||
assert_eq(msg.from.name, "name1")
|
||||
assert_eq(msg.from.address, "addr1")
|
||||
|
||||
assert_eq(#msg.to, 2)
|
||||
assert_eq(msg.to[1].name, "name2")
|
||||
assert_eq(msg.to[1].address, "addr2")
|
||||
assert_eq(msg.to[2].name, "name3")
|
||||
assert_eq(msg.to[2].address, "addr3")
|
||||
`
|
||||
|
||||
ls := lua.NewState()
|
||||
registerInboundMessageType(ls)
|
||||
registerMailAddressType(ls)
|
||||
ls.SetGlobal("msg", wrapInboundMessage(ls, want))
|
||||
require.NoError(t, ls.DoString(LuaInit+script))
|
||||
}
|
||||
|
||||
func TestInboundMessageSetters(t *testing.T) {
|
||||
want := &event.InboundMessage{
|
||||
Mailboxes: []string{"mb1", "mb2"},
|
||||
From: mail.Address{Name: "name1", Address: "addr1"},
|
||||
To: []mail.Address{
|
||||
{Name: "name2", Address: "addr2"},
|
||||
{Name: "name3", Address: "addr3"},
|
||||
},
|
||||
Subject: "subj1",
|
||||
}
|
||||
script := `
|
||||
assert(msg, "msg should not be nil")
|
||||
|
||||
msg.mailboxes = {"mb1", "mb2"}
|
||||
msg.subject = "subj1"
|
||||
msg.from = address.new("name1", "addr1")
|
||||
msg.to = { address.new("name2", "addr2"), address.new("name3", "addr3") }
|
||||
`
|
||||
|
||||
got := &event.InboundMessage{}
|
||||
ls := lua.NewState()
|
||||
registerInboundMessageType(ls)
|
||||
registerMailAddressType(ls)
|
||||
ls.SetGlobal("msg", wrapInboundMessage(ls, got))
|
||||
require.NoError(t, ls.DoString(script))
|
||||
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
@@ -29,7 +29,8 @@ type InbucketAfterFuncs struct {
|
||||
// InbucketBeforeFuncs holds references to Lua extension functions to be called
|
||||
// before Inbucket handles an event.
|
||||
type InbucketBeforeFuncs struct {
|
||||
MailAccepted *lua.LFunction
|
||||
MailAccepted *lua.LFunction
|
||||
MessageStored *lua.LFunction
|
||||
}
|
||||
|
||||
func registerInbucketTypes(ls *lua.LState) {
|
||||
@@ -186,6 +187,8 @@ func inbucketBeforeIndex(ls *lua.LState) int {
|
||||
switch field {
|
||||
case "mail_accepted":
|
||||
ls.Push(funcOrNil(before.MailAccepted))
|
||||
case "message_stored":
|
||||
ls.Push(funcOrNil(before.MessageStored))
|
||||
default:
|
||||
// Unknown field.
|
||||
ls.Push(lua.LNil)
|
||||
@@ -202,6 +205,8 @@ func inbucketBeforeNewIndex(ls *lua.LState) int {
|
||||
switch index {
|
||||
case "mail_accepted":
|
||||
m.MailAccepted = ls.CheckFunction(3)
|
||||
case "message_stored":
|
||||
m.MessageStored = ls.CheckFunction(3)
|
||||
default:
|
||||
ls.RaiseError("invalid inbucket.before index %q", index)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestInbucketBeforeFuncs(t *testing.T) {
|
||||
assert(inbucket, "inbucket should not be nil")
|
||||
assert(inbucket.before, "inbucket.before should not be nil")
|
||||
|
||||
local fns = { "mail_accepted" }
|
||||
local fns = { "mail_accepted", "message_stored" }
|
||||
|
||||
-- Verify functions start off nil.
|
||||
for i, name in ipairs(fns) do
|
||||
|
||||
@@ -105,6 +105,9 @@ func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) {
|
||||
if ib.Before.MailAccepted != nil {
|
||||
events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted)
|
||||
}
|
||||
if ib.Before.MessageStored != nil {
|
||||
events.BeforeMessageStored.AddListener(listenerName, h.handleBeforeMessageStored)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) {
|
||||
@@ -174,6 +177,38 @@ func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool {
|
||||
return &result
|
||||
}
|
||||
|
||||
func (h *Host) handleBeforeMessageStored(msg event.InboundMessage) *event.InboundMessage {
|
||||
logger, ls, ib, ok := h.prepareInbucketFuncCall("before.message_stored")
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
defer h.pool.putState(ls)
|
||||
|
||||
logger.Debug().Msgf("Calling Lua function with %+v", msg)
|
||||
if err := ls.CallByParam(
|
||||
lua.P{Fn: ib.Before.MessageStored, NRet: 1, Protect: true},
|
||||
wrapInboundMessage(ls, &msg),
|
||||
); err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to call Lua function")
|
||||
return nil
|
||||
}
|
||||
|
||||
lval := ls.Get(1)
|
||||
logger.Debug().Msgf("Lua function returned %q (%v)", lval, lval.Type().String())
|
||||
|
||||
if lval.Type() == lua.LTNil || lua.LVIsFalse(lval) {
|
||||
return nil
|
||||
}
|
||||
|
||||
result, err := unwrapInboundMessage(lval)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Bad response from Lua Function")
|
||||
}
|
||||
ls.Pop(1)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Common preparation for calling Lua functions.
|
||||
func (h *Host) prepareInbucketFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, ib *Inbucket, ok bool) {
|
||||
logger = h.logContext.Str("event", funcName).Logger()
|
||||
|
||||
@@ -17,14 +17,16 @@ import (
|
||||
|
||||
// LuaInit holds useful test globals.
|
||||
const LuaInit = `
|
||||
local logger = require("logger")
|
||||
|
||||
async = false
|
||||
test_ok = true
|
||||
|
||||
-- Sends marks tests failed instead of erroring when enabled.
|
||||
-- Sends marks tests as failed instead of erroring when enabled.
|
||||
function assert_async(value, message)
|
||||
if not value then
|
||||
if async then
|
||||
print(message)
|
||||
logger.error(message, {from = "assert_async"})
|
||||
test_ok = false
|
||||
else
|
||||
error(message)
|
||||
@@ -32,7 +34,7 @@ const LuaInit = `
|
||||
end
|
||||
end
|
||||
|
||||
-- Tests plain values and list-style tables.
|
||||
-- Verifies plain values and list-style tables.
|
||||
function assert_eq(got, want)
|
||||
if type(got) == "table" and type(want) == "table" then
|
||||
assert_async(#got == #want, string.format("got %d elements, wanted %d", #got, #want))
|
||||
@@ -48,17 +50,20 @@ const LuaInit = `
|
||||
assert_async(got == want, string.format("got %q, wanted %q", got, want))
|
||||
end
|
||||
|
||||
-- Verifies string want contains string got.
|
||||
function assert_contains(got, want)
|
||||
assert_async(string.find(got, want),
|
||||
string.format("got %q, wanted it to contain %q", got, want))
|
||||
end
|
||||
`
|
||||
|
||||
var consoleLogger = zerolog.New(zerolog.NewConsoleWriter())
|
||||
|
||||
func TestEmptyScript(t *testing.T) {
|
||||
script := ""
|
||||
extHost := extension.NewHost()
|
||||
|
||||
_, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(script), "test.lua")
|
||||
_, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(script), "test.lua")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -91,7 +96,7 @@ func TestAfterMessageDeleted(t *testing.T) {
|
||||
end
|
||||
`
|
||||
extHost := extension.NewHost()
|
||||
luaHost, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(LuaInit+script), "test.lua")
|
||||
luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua")
|
||||
require.NoError(t, err)
|
||||
notify := luaHost.CreateChannel("notify")
|
||||
|
||||
@@ -122,7 +127,7 @@ func TestAfterMessageStored(t *testing.T) {
|
||||
end
|
||||
`
|
||||
extHost := extension.NewHost()
|
||||
luaHost, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(LuaInit+script), "test.lua")
|
||||
luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua")
|
||||
require.NoError(t, err)
|
||||
notify := luaHost.CreateChannel("notify")
|
||||
|
||||
@@ -148,14 +153,14 @@ func TestBeforeMailAccepted(t *testing.T) {
|
||||
end
|
||||
`
|
||||
extHost := extension.NewHost()
|
||||
_, err := luahost.NewFromReader(zerolog.Nop(), extHost, strings.NewReader(script), "test.lua")
|
||||
_, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(script), "test.lua")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send event to be accepted.
|
||||
addr := &event.AddressParts{Local: "from", Domain: "test"}
|
||||
got := extHost.Events.BeforeMailAccepted.Emit(addr)
|
||||
want := true
|
||||
require.NotNil(t, got)
|
||||
require.NotNil(t, got, "Expected result from Emit()")
|
||||
if *got != want {
|
||||
t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr)
|
||||
}
|
||||
@@ -164,12 +169,81 @@ func TestBeforeMailAccepted(t *testing.T) {
|
||||
addr = &event.AddressParts{Local: "reject", Domain: "me"}
|
||||
got = extHost.Events.BeforeMailAccepted.Emit(addr)
|
||||
want = false
|
||||
require.NotNil(t, got)
|
||||
require.NotNil(t, got, "Expected result from Emit()")
|
||||
if *got != want {
|
||||
t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeforeMessageStored(t *testing.T) {
|
||||
// Event to send.
|
||||
msg := event.InboundMessage{
|
||||
Mailboxes: []string{"one", "two"},
|
||||
From: mail.Address{Name: "From Name", Address: "from@example.com"},
|
||||
To: []mail.Address{
|
||||
{Name: "To1 Name", Address: "to1@example.com"},
|
||||
{Name: "To2 Name", Address: "to2@example.com"},
|
||||
},
|
||||
Subject: "inbound subj",
|
||||
Size: 42,
|
||||
}
|
||||
|
||||
// Register lua event listener.
|
||||
script := `
|
||||
async = true
|
||||
|
||||
function inbucket.before.message_stored(msg)
|
||||
-- Verify incoming values.
|
||||
assert_eq(msg.mailboxes, {"one", "two"})
|
||||
assert_eq(msg.from.name, "From Name")
|
||||
assert_eq(msg.from.address, "from@example.com")
|
||||
assert_eq(2, #msg.to)
|
||||
assert_eq(msg.to[1].name, "To1 Name")
|
||||
assert_eq(msg.to[1].address, "to1@example.com")
|
||||
assert_eq(msg.to[2].name, "To2 Name")
|
||||
assert_eq(msg.to[2].address, "to2@example.com")
|
||||
assert_eq(msg.subject, "inbound subj")
|
||||
assert_eq(msg.size, 42)
|
||||
notify:send(test_ok)
|
||||
|
||||
-- Generate response.
|
||||
res = inbound_message.new()
|
||||
res.mailboxes = {"resone", "restwo"}
|
||||
res.from = address.new("Res From", "res@example.com")
|
||||
res.to = {
|
||||
address.new("To1 Res", "res1@example.com"),
|
||||
address.new("To2 Res", "res2@example.com"),
|
||||
}
|
||||
res.subject = "res subj"
|
||||
return res
|
||||
end
|
||||
`
|
||||
extHost := extension.NewHost()
|
||||
luaHost, err := luahost.NewFromReader(consoleLogger, extHost, strings.NewReader(LuaInit+script), "test.lua")
|
||||
require.NoError(t, err)
|
||||
notify := luaHost.CreateChannel("notify")
|
||||
|
||||
// Send event to be accepted.
|
||||
got := extHost.Events.BeforeMessageStored.Emit(&msg)
|
||||
require.NotNil(t, got, "Expected result from Emit()")
|
||||
|
||||
// Verify Lua assertions passed.
|
||||
assertNotified(t, notify)
|
||||
|
||||
// Verify response values.
|
||||
want := &event.InboundMessage{
|
||||
Mailboxes: []string{"resone", "restwo"},
|
||||
From: mail.Address{Name: "Res From", Address: "res@example.com"},
|
||||
To: []mail.Address{
|
||||
{Name: "To1 Res", Address: "res1@example.com"},
|
||||
{Name: "To2 Res", Address: "res2@example.com"},
|
||||
},
|
||||
Subject: "res subj",
|
||||
Size: 0,
|
||||
}
|
||||
assert.Equal(t, want, got, "Response InboundMessage did not match")
|
||||
}
|
||||
|
||||
func assertNotified(t *testing.T, notify chan lua.LValue) {
|
||||
t.Helper()
|
||||
select {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/cjoudrey/gluahttp"
|
||||
"github.com/cosmotek/loguago"
|
||||
"github.com/inbucket/gopher-json"
|
||||
json "github.com/inbucket/gopher-json"
|
||||
"github.com/rs/zerolog"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
@@ -44,9 +44,10 @@ func (lp *statePool) newState() (*lua.LState, error) {
|
||||
}
|
||||
|
||||
// Register custom types.
|
||||
registerInboundMessageType(ls)
|
||||
registerInbucketTypes(ls)
|
||||
registerMessageMetadataType(ls)
|
||||
registerMailAddressType(ls)
|
||||
registerMessageMetadataType(ls)
|
||||
registerPolicyType(ls)
|
||||
|
||||
// Run compiled script.
|
||||
|
||||
Reference in New Issue
Block a user