mirror of
https://github.com/jhillyerd/inbucket.git
synced 2025-12-17 17:47:03 +00:00
lua: Bind after_message_stored and before_mail_accepted (#322)
Signed-off-by: James Hillyerd <james@hillyerd.com> Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
84
pkg/extension/luahost/bind_address.go
Normal file
84
pkg/extension/luahost/bind_address.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package luahost
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const mailAddressName = "address"
|
||||
|
||||
func registerMailAddressType(ls *lua.LState) {
|
||||
mt := ls.NewTypeMetatable(mailAddressName)
|
||||
ls.SetGlobal(mailAddressName, mt)
|
||||
|
||||
// Static attributes.
|
||||
ls.SetField(mt, "new", ls.NewFunction(newMailAddress))
|
||||
|
||||
// Methods.
|
||||
ls.SetField(mt, "__index", ls.SetFuncs(ls.NewTable(), mailAddressMethods))
|
||||
}
|
||||
|
||||
var mailAddressMethods = map[string]lua.LGFunction{
|
||||
"address": mailAddressGetSetAddress,
|
||||
"name": mailAddressGetSetName,
|
||||
}
|
||||
|
||||
func newMailAddress(ls *lua.LState) int {
|
||||
val := &mail.Address{
|
||||
Name: ls.CheckString(1),
|
||||
Address: ls.CheckString(2),
|
||||
}
|
||||
ud := wrapMailAddress(ls, val)
|
||||
ls.Push(ud)
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func wrapMailAddress(ls *lua.LState, val *mail.Address) *lua.LUserData {
|
||||
ud := ls.NewUserData()
|
||||
ud.Value = val
|
||||
ls.SetMetatable(ud, ls.GetTypeMetatable(mailAddressName))
|
||||
|
||||
return ud
|
||||
}
|
||||
|
||||
func unwrapMailAddress(ud *lua.LUserData) (*mail.Address, bool) {
|
||||
val, ok := ud.Value.(*mail.Address)
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func checkMailAddress(ls *lua.LState) *mail.Address {
|
||||
ud := ls.CheckUserData(1)
|
||||
if val, ok := ud.Value.(*mail.Address); ok {
|
||||
return val
|
||||
}
|
||||
ls.ArgError(1, mailAddressName+" expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func mailAddressGetSetAddress(ls *lua.LState) int {
|
||||
val := checkMailAddress(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
val.Address = ls.CheckString(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
ls.Push(lua.LString(val.Address))
|
||||
return 1
|
||||
}
|
||||
|
||||
func mailAddressGetSetName(ls *lua.LState) int {
|
||||
val := checkMailAddress(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
val.Name = ls.CheckString(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
ls.Push(lua.LString(val.Name))
|
||||
return 1
|
||||
}
|
||||
161
pkg/extension/luahost/bind_message.go
Normal file
161
pkg/extension/luahost/bind_message.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package luahost
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"time"
|
||||
|
||||
"github.com/inbucket/inbucket/pkg/extension/event"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const messageMetadataName = "message_metadata"
|
||||
|
||||
func registerMessageMetadataType(ls *lua.LState) {
|
||||
mt := ls.NewTypeMetatable(messageMetadataName)
|
||||
ls.SetGlobal(messageMetadataName, mt)
|
||||
|
||||
// Static attributes.
|
||||
ls.SetField(mt, "new", ls.NewFunction(newMessageMetadata))
|
||||
|
||||
// Methods.
|
||||
ls.SetField(mt, "__index", ls.SetFuncs(ls.NewTable(), messageMetadataMethods))
|
||||
}
|
||||
|
||||
var messageMetadataMethods = map[string]lua.LGFunction{
|
||||
"mailbox": messageMetadataGetSetMailbox,
|
||||
"id": messageMetadataGetSetID,
|
||||
"from": messageMetadataGetSetFrom,
|
||||
"to": messageMetadataGetSetTo,
|
||||
"subject": messageMetadataGetSetSubject,
|
||||
"date": messageMetadataGetSetDate,
|
||||
"size": messageMetadataGetSetSize,
|
||||
}
|
||||
|
||||
func newMessageMetadata(ls *lua.LState) int {
|
||||
val := &event.MessageMetadata{}
|
||||
ud := wrapMessageMetadata(ls, val)
|
||||
ls.Push(ud)
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func wrapMessageMetadata(ls *lua.LState, val *event.MessageMetadata) *lua.LUserData {
|
||||
ud := ls.NewUserData()
|
||||
ud.Value = val
|
||||
ls.SetMetatable(ud, ls.GetTypeMetatable(messageMetadataName))
|
||||
|
||||
return ud
|
||||
}
|
||||
|
||||
func checkMessageMetadata(ls *lua.LState) *event.MessageMetadata {
|
||||
ud := ls.CheckUserData(1)
|
||||
if v, ok := ud.Value.(*event.MessageMetadata); ok {
|
||||
return v
|
||||
}
|
||||
ls.ArgError(1, messageMetadataName+" expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func messageMetadataGetSetMailbox(ls *lua.LState) int {
|
||||
val := checkMessageMetadata(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
val.Mailbox = ls.CheckString(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
ls.Push(lua.LString(val.Mailbox))
|
||||
return 1
|
||||
}
|
||||
|
||||
func messageMetadataGetSetID(ls *lua.LState) int {
|
||||
val := checkMessageMetadata(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
val.ID = ls.CheckString(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
ls.Push(lua.LString(val.ID))
|
||||
return 1
|
||||
}
|
||||
|
||||
func messageMetadataGetSetFrom(ls *lua.LState) int {
|
||||
val := checkMessageMetadata(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
val.From = checkMailAddress(ls)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
ls.Push(wrapMailAddress(ls, val.From))
|
||||
return 1
|
||||
}
|
||||
|
||||
func messageMetadataGetSetTo(ls *lua.LState) int {
|
||||
val := checkMessageMetadata(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
lt := ls.CheckTable(2)
|
||||
to := make([]*mail.Address, lt.Len())
|
||||
lt.ForEach(func(k, lv lua.LValue) {
|
||||
if ud, ok := lv.(*lua.LUserData); ok {
|
||||
if entry, ok := unwrapMailAddress(ud); ok {
|
||||
to = append(to, entry)
|
||||
}
|
||||
}
|
||||
})
|
||||
val.To = to
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
lt := &lua.LTable{}
|
||||
for _, v := range val.To {
|
||||
lt.Append(wrapMailAddress(ls, v))
|
||||
}
|
||||
ls.Push(lt)
|
||||
return 1
|
||||
}
|
||||
|
||||
func messageMetadataGetSetSubject(ls *lua.LState) int {
|
||||
val := checkMessageMetadata(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
val.Subject = ls.CheckString(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
ls.Push(lua.LString(val.Subject))
|
||||
return 1
|
||||
}
|
||||
|
||||
func messageMetadataGetSetDate(ls *lua.LState) int {
|
||||
val := checkMessageMetadata(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
val.Date = time.Unix(ls.CheckInt64(2), 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
ls.Push(lua.LNumber(val.Date.Unix()))
|
||||
return 1
|
||||
}
|
||||
|
||||
func messageMetadataGetSetSize(ls *lua.LState) int {
|
||||
val := checkMessageMetadata(ls)
|
||||
if ls.GetTop() == 2 {
|
||||
// Setter.
|
||||
val.Size = ls.CheckInt64(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
// Getter.
|
||||
ls.Push(lua.LNumber(val.Size))
|
||||
return 1
|
||||
}
|
||||
17
pkg/extension/luahost/bind_policy.go
Normal file
17
pkg/extension/luahost/bind_policy.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package luahost
|
||||
|
||||
import (
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const policyName = "policy"
|
||||
|
||||
func registerPolicyType(ls *lua.LState) {
|
||||
mt := ls.NewTypeMetatable(policyName)
|
||||
ls.SetGlobal(policyName, mt)
|
||||
|
||||
// Static attributes.
|
||||
ls.SetField(mt, "allow", lua.LTrue)
|
||||
ls.SetField(mt, "deny", lua.LFalse)
|
||||
ls.SetField(mt, "defer", lua.LNil)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/inbucket/inbucket/pkg/config"
|
||||
"github.com/inbucket/inbucket/pkg/extension"
|
||||
"github.com/inbucket/inbucket/pkg/extension/event"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
@@ -54,6 +55,7 @@ func New(conf config.Lua, extHost *extension.Host) (*Host, error) {
|
||||
// The provided path is used in logging and error messages.
|
||||
func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, error) {
|
||||
logContext := log.With().Str("module", "lua")
|
||||
logger := logContext.Str("phase", "startup").Str("path", path).Logger()
|
||||
|
||||
// Pre-parse, and compile script.
|
||||
chunk, err := parse.Parse(r, path)
|
||||
@@ -69,6 +71,8 @@ func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, er
|
||||
pool := newStatePool(proto)
|
||||
h := &Host{extHost: extHost, pool: pool, logContext: logContext}
|
||||
if ls, err := pool.getState(); err == nil {
|
||||
h.wireFunctions(logger, ls)
|
||||
|
||||
// State creation works, put it back.
|
||||
pool.putState(ls)
|
||||
} else {
|
||||
@@ -83,3 +87,107 @@ func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, er
|
||||
func (h *Host) CreateChannel(name string) chan lua.LValue {
|
||||
return h.pool.createChannel(name)
|
||||
}
|
||||
|
||||
const afterMessageStoredFnName string = "after_message_stored"
|
||||
const beforeMailAcceptedFnName string = "before_mail_accepted"
|
||||
|
||||
// Detects global lua event listener functions and wires them up.
|
||||
func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) {
|
||||
detectFn := func(name string) bool {
|
||||
lval := ls.GetGlobal(name)
|
||||
switch lval.Type() {
|
||||
case lua.LTFunction:
|
||||
logger.Debug().Msgf("Detected %q function", name)
|
||||
h.Functions = append(h.Functions, name)
|
||||
return true
|
||||
case lua.LTNil:
|
||||
logger.Debug().Msgf("Did not detect %q function", name)
|
||||
default:
|
||||
logger.Fatal().Msgf("Found global named %q, but was a %v instead of a function",
|
||||
name, lval.Type().String())
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
events := h.extHost.Events
|
||||
const listenerName string = "lua"
|
||||
|
||||
if detectFn(afterMessageStoredFnName) {
|
||||
events.AfterMessageStored.AddListener(listenerName, h.handleAfterMessageStored)
|
||||
}
|
||||
if detectFn(beforeMailAcceptedFnName) {
|
||||
events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) *extension.Void {
|
||||
logger, ls, lfunc, ok := h.prepareFuncCall(afterMessageStoredFnName)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
defer h.pool.putState(ls)
|
||||
|
||||
// Call lua function.
|
||||
logger.Debug().Msgf("Calling Lua function with %+v", msg)
|
||||
if err := ls.CallByParam(
|
||||
lua.P{Fn: lfunc, NRet: 0, Protect: true},
|
||||
wrapMessageMetadata(ls, &msg),
|
||||
); err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to call Lua function")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool {
|
||||
logger, ls, lfunc, ok := h.prepareFuncCall(beforeMailAcceptedFnName)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
defer h.pool.putState(ls)
|
||||
|
||||
logger.Debug().Msgf("Calling Lua function with %+v", addr)
|
||||
if err := ls.CallByParam(
|
||||
lua.P{Fn: lfunc, NRet: 1, Protect: true},
|
||||
lua.LString(addr.Local),
|
||||
lua.LString(addr.Domain),
|
||||
); err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to call Lua function")
|
||||
return nil
|
||||
}
|
||||
|
||||
lval := ls.Get(1)
|
||||
ls.Pop(1)
|
||||
logger.Debug().Msgf("Lua function returned %q (%v)", lval, lval.Type().String())
|
||||
|
||||
if lval.Type() == lua.LTNil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := true
|
||||
if lua.LVIsFalse(lval) {
|
||||
result = false
|
||||
}
|
||||
|
||||
return &result
|
||||
}
|
||||
|
||||
// Common preparation for calling Lua functions.
|
||||
func (h *Host) prepareFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, lfunc lua.LValue, ok bool) {
|
||||
logger = h.logContext.Str("event", funcName).Logger()
|
||||
|
||||
ls, err := h.pool.getState()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to get Lua state instance from pool")
|
||||
return logger, nil, nil, false
|
||||
}
|
||||
|
||||
lfunc = ls.GetGlobal(funcName)
|
||||
if lfunc.Type() != lua.LTFunction {
|
||||
logger.Error().Msgf("global %q is no longer a function", funcName)
|
||||
return logger, nil, nil, false
|
||||
}
|
||||
|
||||
return logger, ls, lfunc, true
|
||||
}
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package luahost_test
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/inbucket/inbucket/pkg/extension"
|
||||
"github.com/inbucket/inbucket/pkg/extension/event"
|
||||
"github.com/inbucket/inbucket/pkg/extension/luahost"
|
||||
"github.com/stretchr/testify/require"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
func TestEmptyScript(t *testing.T) {
|
||||
@@ -16,3 +20,96 @@ func TestEmptyScript(t *testing.T) {
|
||||
_, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAfterMessageStored(t *testing.T) {
|
||||
// Register lua event listener, setup notify channel.
|
||||
script := `
|
||||
local test_ok = true
|
||||
|
||||
function assert_eq(got, want)
|
||||
if got ~= want then
|
||||
-- Incorrect value, schedule test to fail.
|
||||
print("got '" .. got .. "', wanted '" .. want .. "'")
|
||||
test_ok = false
|
||||
end
|
||||
end
|
||||
|
||||
function after_message_stored(msg)
|
||||
assert_eq(msg:mailbox(), "mb1")
|
||||
assert_eq(msg:id(), "id1")
|
||||
assert_eq(msg:subject(), "subj1")
|
||||
assert_eq(msg:size(), 42)
|
||||
|
||||
assert_eq(msg:from():name(), "name1")
|
||||
assert_eq(msg:from():address(), "addr1")
|
||||
|
||||
assert_eq(table.getn(msg:to()), 1)
|
||||
assert_eq(msg:to()[1]:name(), "name2")
|
||||
assert_eq(msg:to()[1]:address(), "addr2")
|
||||
|
||||
assert_eq(msg:date(), 981173106)
|
||||
|
||||
notify:send(test_ok)
|
||||
end
|
||||
`
|
||||
extHost := extension.NewHost()
|
||||
luaHost, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua")
|
||||
require.NoError(t, err)
|
||||
notify := luaHost.CreateChannel("notify")
|
||||
|
||||
// Send event, check channel response is true.
|
||||
msg := &event.MessageMetadata{
|
||||
Mailbox: "mb1",
|
||||
ID: "id1",
|
||||
From: &mail.Address{Name: "name1", Address: "addr1"},
|
||||
To: []*mail.Address{{Name: "name2", Address: "addr2"}},
|
||||
Date: time.Date(2001, time.February, 3, 4, 5, 6, 0, time.UTC),
|
||||
Subject: "subj1",
|
||||
Size: 42,
|
||||
}
|
||||
go extHost.Events.AfterMessageStored.Emit(msg)
|
||||
assertNotified(t, notify)
|
||||
}
|
||||
|
||||
func TestBeforeMailAccepted(t *testing.T) {
|
||||
// Register lua event listener.
|
||||
script := `
|
||||
function before_mail_accepted(localpart, domain)
|
||||
return localpart == "from" and domain == "test"
|
||||
end
|
||||
`
|
||||
extHost := extension.NewHost()
|
||||
_, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send event to be accepted.
|
||||
addr := &event.AddressParts{Local: "from", Domain: "test"}
|
||||
got := extHost.Events.BeforeMailAccepted.Emit(addr)
|
||||
want := true
|
||||
require.NotNil(t, got)
|
||||
if *got != want {
|
||||
t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr)
|
||||
}
|
||||
|
||||
// Send event to be denied.
|
||||
addr = &event.AddressParts{Local: "reject", Domain: "me"}
|
||||
got = extHost.Events.BeforeMailAccepted.Emit(addr)
|
||||
want = false
|
||||
require.NotNil(t, got)
|
||||
if *got != want {
|
||||
t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func assertNotified(t *testing.T, notify chan lua.LValue) {
|
||||
t.Helper()
|
||||
select {
|
||||
case reslv := <-notify:
|
||||
// Lua function received event.
|
||||
if lua.LVIsFalse(reslv) {
|
||||
t.Error("Lua responsed with false, wanted true")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Lua did not respond to event within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,11 @@ func (lp *statePool) newState() (*lua.LState, error) {
|
||||
ls.SetGlobal(name, lua.LChannel(ch))
|
||||
}
|
||||
|
||||
// Register custom types.
|
||||
registerMessageMetadataType(ls)
|
||||
registerMailAddressType(ls)
|
||||
registerPolicyType(ls)
|
||||
|
||||
// Run compiled script.
|
||||
ls.Push(ls.NewFunctionFromProto(lp.funcProto))
|
||||
if err := ls.PCall(0, lua.MultRet, nil); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user