1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-17 01:27:01 +00:00

lua: Bind after_message_stored and before_mail_accepted (#322)

Signed-off-by: James Hillyerd <james@hillyerd.com>

Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
James Hillyerd
2023-01-24 16:37:26 -08:00
committed by GitHub
parent 55addbb556
commit 7f91c3e9cb
6 changed files with 472 additions and 0 deletions

View File

@@ -0,0 +1,84 @@
package luahost
import (
"net/mail"
lua "github.com/yuin/gopher-lua"
)
const mailAddressName = "address"
func registerMailAddressType(ls *lua.LState) {
mt := ls.NewTypeMetatable(mailAddressName)
ls.SetGlobal(mailAddressName, mt)
// Static attributes.
ls.SetField(mt, "new", ls.NewFunction(newMailAddress))
// Methods.
ls.SetField(mt, "__index", ls.SetFuncs(ls.NewTable(), mailAddressMethods))
}
var mailAddressMethods = map[string]lua.LGFunction{
"address": mailAddressGetSetAddress,
"name": mailAddressGetSetName,
}
func newMailAddress(ls *lua.LState) int {
val := &mail.Address{
Name: ls.CheckString(1),
Address: ls.CheckString(2),
}
ud := wrapMailAddress(ls, val)
ls.Push(ud)
return 1
}
func wrapMailAddress(ls *lua.LState, val *mail.Address) *lua.LUserData {
ud := ls.NewUserData()
ud.Value = val
ls.SetMetatable(ud, ls.GetTypeMetatable(mailAddressName))
return ud
}
func unwrapMailAddress(ud *lua.LUserData) (*mail.Address, bool) {
val, ok := ud.Value.(*mail.Address)
return val, ok
}
func checkMailAddress(ls *lua.LState) *mail.Address {
ud := ls.CheckUserData(1)
if val, ok := ud.Value.(*mail.Address); ok {
return val
}
ls.ArgError(1, mailAddressName+" expected")
return nil
}
func mailAddressGetSetAddress(ls *lua.LState) int {
val := checkMailAddress(ls)
if ls.GetTop() == 2 {
// Setter.
val.Address = ls.CheckString(2)
return 0
}
// Getter.
ls.Push(lua.LString(val.Address))
return 1
}
func mailAddressGetSetName(ls *lua.LState) int {
val := checkMailAddress(ls)
if ls.GetTop() == 2 {
// Setter.
val.Name = ls.CheckString(2)
return 0
}
// Getter.
ls.Push(lua.LString(val.Name))
return 1
}

View File

@@ -0,0 +1,161 @@
package luahost
import (
"net/mail"
"time"
"github.com/inbucket/inbucket/pkg/extension/event"
lua "github.com/yuin/gopher-lua"
)
const messageMetadataName = "message_metadata"
func registerMessageMetadataType(ls *lua.LState) {
mt := ls.NewTypeMetatable(messageMetadataName)
ls.SetGlobal(messageMetadataName, mt)
// Static attributes.
ls.SetField(mt, "new", ls.NewFunction(newMessageMetadata))
// Methods.
ls.SetField(mt, "__index", ls.SetFuncs(ls.NewTable(), messageMetadataMethods))
}
var messageMetadataMethods = map[string]lua.LGFunction{
"mailbox": messageMetadataGetSetMailbox,
"id": messageMetadataGetSetID,
"from": messageMetadataGetSetFrom,
"to": messageMetadataGetSetTo,
"subject": messageMetadataGetSetSubject,
"date": messageMetadataGetSetDate,
"size": messageMetadataGetSetSize,
}
func newMessageMetadata(ls *lua.LState) int {
val := &event.MessageMetadata{}
ud := wrapMessageMetadata(ls, val)
ls.Push(ud)
return 1
}
func wrapMessageMetadata(ls *lua.LState, val *event.MessageMetadata) *lua.LUserData {
ud := ls.NewUserData()
ud.Value = val
ls.SetMetatable(ud, ls.GetTypeMetatable(messageMetadataName))
return ud
}
func checkMessageMetadata(ls *lua.LState) *event.MessageMetadata {
ud := ls.CheckUserData(1)
if v, ok := ud.Value.(*event.MessageMetadata); ok {
return v
}
ls.ArgError(1, messageMetadataName+" expected")
return nil
}
func messageMetadataGetSetMailbox(ls *lua.LState) int {
val := checkMessageMetadata(ls)
if ls.GetTop() == 2 {
// Setter.
val.Mailbox = ls.CheckString(2)
return 0
}
// Getter.
ls.Push(lua.LString(val.Mailbox))
return 1
}
func messageMetadataGetSetID(ls *lua.LState) int {
val := checkMessageMetadata(ls)
if ls.GetTop() == 2 {
// Setter.
val.ID = ls.CheckString(2)
return 0
}
// Getter.
ls.Push(lua.LString(val.ID))
return 1
}
func messageMetadataGetSetFrom(ls *lua.LState) int {
val := checkMessageMetadata(ls)
if ls.GetTop() == 2 {
// Setter.
val.From = checkMailAddress(ls)
return 0
}
// Getter.
ls.Push(wrapMailAddress(ls, val.From))
return 1
}
func messageMetadataGetSetTo(ls *lua.LState) int {
val := checkMessageMetadata(ls)
if ls.GetTop() == 2 {
// Setter.
lt := ls.CheckTable(2)
to := make([]*mail.Address, lt.Len())
lt.ForEach(func(k, lv lua.LValue) {
if ud, ok := lv.(*lua.LUserData); ok {
if entry, ok := unwrapMailAddress(ud); ok {
to = append(to, entry)
}
}
})
val.To = to
return 0
}
// Getter.
lt := &lua.LTable{}
for _, v := range val.To {
lt.Append(wrapMailAddress(ls, v))
}
ls.Push(lt)
return 1
}
func messageMetadataGetSetSubject(ls *lua.LState) int {
val := checkMessageMetadata(ls)
if ls.GetTop() == 2 {
// Setter.
val.Subject = ls.CheckString(2)
return 0
}
// Getter.
ls.Push(lua.LString(val.Subject))
return 1
}
func messageMetadataGetSetDate(ls *lua.LState) int {
val := checkMessageMetadata(ls)
if ls.GetTop() == 2 {
// Setter.
val.Date = time.Unix(ls.CheckInt64(2), 0)
return 0
}
// Getter.
ls.Push(lua.LNumber(val.Date.Unix()))
return 1
}
func messageMetadataGetSetSize(ls *lua.LState) int {
val := checkMessageMetadata(ls)
if ls.GetTop() == 2 {
// Setter.
val.Size = ls.CheckInt64(2)
return 0
}
// Getter.
ls.Push(lua.LNumber(val.Size))
return 1
}

View File

@@ -0,0 +1,17 @@
package luahost
import (
lua "github.com/yuin/gopher-lua"
)
const policyName = "policy"
func registerPolicyType(ls *lua.LState) {
mt := ls.NewTypeMetatable(policyName)
ls.SetGlobal(policyName, mt)
// Static attributes.
ls.SetField(mt, "allow", lua.LTrue)
ls.SetField(mt, "deny", lua.LFalse)
ls.SetField(mt, "defer", lua.LNil)
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/inbucket/inbucket/pkg/config"
"github.com/inbucket/inbucket/pkg/extension"
"github.com/inbucket/inbucket/pkg/extension/event"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
lua "github.com/yuin/gopher-lua"
@@ -54,6 +55,7 @@ func New(conf config.Lua, extHost *extension.Host) (*Host, error) {
// The provided path is used in logging and error messages.
func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, error) {
logContext := log.With().Str("module", "lua")
logger := logContext.Str("phase", "startup").Str("path", path).Logger()
// Pre-parse, and compile script.
chunk, err := parse.Parse(r, path)
@@ -69,6 +71,8 @@ func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, er
pool := newStatePool(proto)
h := &Host{extHost: extHost, pool: pool, logContext: logContext}
if ls, err := pool.getState(); err == nil {
h.wireFunctions(logger, ls)
// State creation works, put it back.
pool.putState(ls)
} else {
@@ -83,3 +87,107 @@ func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, er
func (h *Host) CreateChannel(name string) chan lua.LValue {
return h.pool.createChannel(name)
}
const afterMessageStoredFnName string = "after_message_stored"
const beforeMailAcceptedFnName string = "before_mail_accepted"
// Detects global lua event listener functions and wires them up.
func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) {
detectFn := func(name string) bool {
lval := ls.GetGlobal(name)
switch lval.Type() {
case lua.LTFunction:
logger.Debug().Msgf("Detected %q function", name)
h.Functions = append(h.Functions, name)
return true
case lua.LTNil:
logger.Debug().Msgf("Did not detect %q function", name)
default:
logger.Fatal().Msgf("Found global named %q, but was a %v instead of a function",
name, lval.Type().String())
}
return false
}
events := h.extHost.Events
const listenerName string = "lua"
if detectFn(afterMessageStoredFnName) {
events.AfterMessageStored.AddListener(listenerName, h.handleAfterMessageStored)
}
if detectFn(beforeMailAcceptedFnName) {
events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted)
}
}
func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) *extension.Void {
logger, ls, lfunc, ok := h.prepareFuncCall(afterMessageStoredFnName)
if !ok {
return nil
}
defer h.pool.putState(ls)
// Call lua function.
logger.Debug().Msgf("Calling Lua function with %+v", msg)
if err := ls.CallByParam(
lua.P{Fn: lfunc, NRet: 0, Protect: true},
wrapMessageMetadata(ls, &msg),
); err != nil {
logger.Error().Err(err).Msg("Failed to call Lua function")
}
return nil
}
func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool {
logger, ls, lfunc, ok := h.prepareFuncCall(beforeMailAcceptedFnName)
if !ok {
return nil
}
defer h.pool.putState(ls)
logger.Debug().Msgf("Calling Lua function with %+v", addr)
if err := ls.CallByParam(
lua.P{Fn: lfunc, NRet: 1, Protect: true},
lua.LString(addr.Local),
lua.LString(addr.Domain),
); err != nil {
logger.Error().Err(err).Msg("Failed to call Lua function")
return nil
}
lval := ls.Get(1)
ls.Pop(1)
logger.Debug().Msgf("Lua function returned %q (%v)", lval, lval.Type().String())
if lval.Type() == lua.LTNil {
return nil
}
result := true
if lua.LVIsFalse(lval) {
result = false
}
return &result
}
// Common preparation for calling Lua functions.
func (h *Host) prepareFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, lfunc lua.LValue, ok bool) {
logger = h.logContext.Str("event", funcName).Logger()
ls, err := h.pool.getState()
if err != nil {
logger.Error().Err(err).Msg("Failed to get Lua state instance from pool")
return logger, nil, nil, false
}
lfunc = ls.GetGlobal(funcName)
if lfunc.Type() != lua.LTFunction {
logger.Error().Msgf("global %q is no longer a function", funcName)
return logger, nil, nil, false
}
return logger, ls, lfunc, true
}

View File

@@ -1,12 +1,16 @@
package luahost_test
import (
"net/mail"
"strings"
"testing"
"time"
"github.com/inbucket/inbucket/pkg/extension"
"github.com/inbucket/inbucket/pkg/extension/event"
"github.com/inbucket/inbucket/pkg/extension/luahost"
"github.com/stretchr/testify/require"
lua "github.com/yuin/gopher-lua"
)
func TestEmptyScript(t *testing.T) {
@@ -16,3 +20,96 @@ func TestEmptyScript(t *testing.T) {
_, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua")
require.NoError(t, err)
}
func TestAfterMessageStored(t *testing.T) {
// Register lua event listener, setup notify channel.
script := `
local test_ok = true
function assert_eq(got, want)
if got ~= want then
-- Incorrect value, schedule test to fail.
print("got '" .. got .. "', wanted '" .. want .. "'")
test_ok = false
end
end
function after_message_stored(msg)
assert_eq(msg:mailbox(), "mb1")
assert_eq(msg:id(), "id1")
assert_eq(msg:subject(), "subj1")
assert_eq(msg:size(), 42)
assert_eq(msg:from():name(), "name1")
assert_eq(msg:from():address(), "addr1")
assert_eq(table.getn(msg:to()), 1)
assert_eq(msg:to()[1]:name(), "name2")
assert_eq(msg:to()[1]:address(), "addr2")
assert_eq(msg:date(), 981173106)
notify:send(test_ok)
end
`
extHost := extension.NewHost()
luaHost, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua")
require.NoError(t, err)
notify := luaHost.CreateChannel("notify")
// Send event, check channel response is true.
msg := &event.MessageMetadata{
Mailbox: "mb1",
ID: "id1",
From: &mail.Address{Name: "name1", Address: "addr1"},
To: []*mail.Address{{Name: "name2", Address: "addr2"}},
Date: time.Date(2001, time.February, 3, 4, 5, 6, 0, time.UTC),
Subject: "subj1",
Size: 42,
}
go extHost.Events.AfterMessageStored.Emit(msg)
assertNotified(t, notify)
}
func TestBeforeMailAccepted(t *testing.T) {
// Register lua event listener.
script := `
function before_mail_accepted(localpart, domain)
return localpart == "from" and domain == "test"
end
`
extHost := extension.NewHost()
_, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua")
require.NoError(t, err)
// Send event to be accepted.
addr := &event.AddressParts{Local: "from", Domain: "test"}
got := extHost.Events.BeforeMailAccepted.Emit(addr)
want := true
require.NotNil(t, got)
if *got != want {
t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr)
}
// Send event to be denied.
addr = &event.AddressParts{Local: "reject", Domain: "me"}
got = extHost.Events.BeforeMailAccepted.Emit(addr)
want = false
require.NotNil(t, got)
if *got != want {
t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr)
}
}
func assertNotified(t *testing.T, notify chan lua.LValue) {
t.Helper()
select {
case reslv := <-notify:
// Lua function received event.
if lua.LVIsFalse(reslv) {
t.Error("Lua responsed with false, wanted true")
}
case <-time.After(2 * time.Second):
t.Fatal("Lua did not respond to event within timeout")
}
}

View File

@@ -29,6 +29,11 @@ func (lp *statePool) newState() (*lua.LState, error) {
ls.SetGlobal(name, lua.LChannel(ch))
}
// Register custom types.
registerMessageMetadataType(ls)
registerMailAddressType(ls)
registerPolicyType(ls)
// Run compiled script.
ls.Push(ls.NewFunctionFromProto(lp.funcProto))
if err := ls.PCall(0, lua.MultRet, nil); err != nil {