From 7f91c3e9cba545b6b56a3f6b060a34c7334bafc4 Mon Sep 17 00:00:00 2001 From: James Hillyerd Date: Tue, 24 Jan 2023 16:37:26 -0800 Subject: [PATCH] lua: Bind after_message_stored and before_mail_accepted (#322) Signed-off-by: James Hillyerd Signed-off-by: James Hillyerd --- pkg/extension/luahost/bind_address.go | 84 ++++++++++++++ pkg/extension/luahost/bind_message.go | 161 ++++++++++++++++++++++++++ pkg/extension/luahost/bind_policy.go | 17 +++ pkg/extension/luahost/lua.go | 108 +++++++++++++++++ pkg/extension/luahost/lua_test.go | 97 ++++++++++++++++ pkg/extension/luahost/pool.go | 5 + 6 files changed, 472 insertions(+) create mode 100644 pkg/extension/luahost/bind_address.go create mode 100644 pkg/extension/luahost/bind_message.go create mode 100644 pkg/extension/luahost/bind_policy.go diff --git a/pkg/extension/luahost/bind_address.go b/pkg/extension/luahost/bind_address.go new file mode 100644 index 0000000..e3ed4c3 --- /dev/null +++ b/pkg/extension/luahost/bind_address.go @@ -0,0 +1,84 @@ +package luahost + +import ( + "net/mail" + + lua "github.com/yuin/gopher-lua" +) + +const mailAddressName = "address" + +func registerMailAddressType(ls *lua.LState) { + mt := ls.NewTypeMetatable(mailAddressName) + ls.SetGlobal(mailAddressName, mt) + + // Static attributes. + ls.SetField(mt, "new", ls.NewFunction(newMailAddress)) + + // Methods. + ls.SetField(mt, "__index", ls.SetFuncs(ls.NewTable(), mailAddressMethods)) +} + +var mailAddressMethods = map[string]lua.LGFunction{ + "address": mailAddressGetSetAddress, + "name": mailAddressGetSetName, +} + +func newMailAddress(ls *lua.LState) int { + val := &mail.Address{ + Name: ls.CheckString(1), + Address: ls.CheckString(2), + } + ud := wrapMailAddress(ls, val) + ls.Push(ud) + + return 1 +} + +func wrapMailAddress(ls *lua.LState, val *mail.Address) *lua.LUserData { + ud := ls.NewUserData() + ud.Value = val + ls.SetMetatable(ud, ls.GetTypeMetatable(mailAddressName)) + + return ud +} + +func unwrapMailAddress(ud *lua.LUserData) (*mail.Address, bool) { + val, ok := ud.Value.(*mail.Address) + return val, ok +} + +func checkMailAddress(ls *lua.LState) *mail.Address { + ud := ls.CheckUserData(1) + if val, ok := ud.Value.(*mail.Address); ok { + return val + } + ls.ArgError(1, mailAddressName+" expected") + return nil +} + +func mailAddressGetSetAddress(ls *lua.LState) int { + val := checkMailAddress(ls) + if ls.GetTop() == 2 { + // Setter. + val.Address = ls.CheckString(2) + return 0 + } + + // Getter. + ls.Push(lua.LString(val.Address)) + return 1 +} + +func mailAddressGetSetName(ls *lua.LState) int { + val := checkMailAddress(ls) + if ls.GetTop() == 2 { + // Setter. + val.Name = ls.CheckString(2) + return 0 + } + + // Getter. + ls.Push(lua.LString(val.Name)) + return 1 +} diff --git a/pkg/extension/luahost/bind_message.go b/pkg/extension/luahost/bind_message.go new file mode 100644 index 0000000..28ecfa2 --- /dev/null +++ b/pkg/extension/luahost/bind_message.go @@ -0,0 +1,161 @@ +package luahost + +import ( + "net/mail" + "time" + + "github.com/inbucket/inbucket/pkg/extension/event" + lua "github.com/yuin/gopher-lua" +) + +const messageMetadataName = "message_metadata" + +func registerMessageMetadataType(ls *lua.LState) { + mt := ls.NewTypeMetatable(messageMetadataName) + ls.SetGlobal(messageMetadataName, mt) + + // Static attributes. + ls.SetField(mt, "new", ls.NewFunction(newMessageMetadata)) + + // Methods. + ls.SetField(mt, "__index", ls.SetFuncs(ls.NewTable(), messageMetadataMethods)) +} + +var messageMetadataMethods = map[string]lua.LGFunction{ + "mailbox": messageMetadataGetSetMailbox, + "id": messageMetadataGetSetID, + "from": messageMetadataGetSetFrom, + "to": messageMetadataGetSetTo, + "subject": messageMetadataGetSetSubject, + "date": messageMetadataGetSetDate, + "size": messageMetadataGetSetSize, +} + +func newMessageMetadata(ls *lua.LState) int { + val := &event.MessageMetadata{} + ud := wrapMessageMetadata(ls, val) + ls.Push(ud) + + return 1 +} + +func wrapMessageMetadata(ls *lua.LState, val *event.MessageMetadata) *lua.LUserData { + ud := ls.NewUserData() + ud.Value = val + ls.SetMetatable(ud, ls.GetTypeMetatable(messageMetadataName)) + + return ud +} + +func checkMessageMetadata(ls *lua.LState) *event.MessageMetadata { + ud := ls.CheckUserData(1) + if v, ok := ud.Value.(*event.MessageMetadata); ok { + return v + } + ls.ArgError(1, messageMetadataName+" expected") + return nil +} + +func messageMetadataGetSetMailbox(ls *lua.LState) int { + val := checkMessageMetadata(ls) + if ls.GetTop() == 2 { + // Setter. + val.Mailbox = ls.CheckString(2) + return 0 + } + + // Getter. + ls.Push(lua.LString(val.Mailbox)) + return 1 +} + +func messageMetadataGetSetID(ls *lua.LState) int { + val := checkMessageMetadata(ls) + if ls.GetTop() == 2 { + // Setter. + val.ID = ls.CheckString(2) + return 0 + } + + // Getter. + ls.Push(lua.LString(val.ID)) + return 1 +} + +func messageMetadataGetSetFrom(ls *lua.LState) int { + val := checkMessageMetadata(ls) + if ls.GetTop() == 2 { + // Setter. + val.From = checkMailAddress(ls) + return 0 + } + + // Getter. + ls.Push(wrapMailAddress(ls, val.From)) + return 1 +} + +func messageMetadataGetSetTo(ls *lua.LState) int { + val := checkMessageMetadata(ls) + if ls.GetTop() == 2 { + // Setter. + lt := ls.CheckTable(2) + to := make([]*mail.Address, lt.Len()) + lt.ForEach(func(k, lv lua.LValue) { + if ud, ok := lv.(*lua.LUserData); ok { + if entry, ok := unwrapMailAddress(ud); ok { + to = append(to, entry) + } + } + }) + val.To = to + return 0 + } + + // Getter. + lt := &lua.LTable{} + for _, v := range val.To { + lt.Append(wrapMailAddress(ls, v)) + } + ls.Push(lt) + return 1 +} + +func messageMetadataGetSetSubject(ls *lua.LState) int { + val := checkMessageMetadata(ls) + if ls.GetTop() == 2 { + // Setter. + val.Subject = ls.CheckString(2) + return 0 + } + + // Getter. + ls.Push(lua.LString(val.Subject)) + return 1 +} + +func messageMetadataGetSetDate(ls *lua.LState) int { + val := checkMessageMetadata(ls) + if ls.GetTop() == 2 { + // Setter. + val.Date = time.Unix(ls.CheckInt64(2), 0) + return 0 + } + + // Getter. + ls.Push(lua.LNumber(val.Date.Unix())) + return 1 +} + +func messageMetadataGetSetSize(ls *lua.LState) int { + val := checkMessageMetadata(ls) + if ls.GetTop() == 2 { + // Setter. + val.Size = ls.CheckInt64(2) + return 0 + } + + // Getter. + ls.Push(lua.LNumber(val.Size)) + return 1 +} diff --git a/pkg/extension/luahost/bind_policy.go b/pkg/extension/luahost/bind_policy.go new file mode 100644 index 0000000..8a42060 --- /dev/null +++ b/pkg/extension/luahost/bind_policy.go @@ -0,0 +1,17 @@ +package luahost + +import ( + lua "github.com/yuin/gopher-lua" +) + +const policyName = "policy" + +func registerPolicyType(ls *lua.LState) { + mt := ls.NewTypeMetatable(policyName) + ls.SetGlobal(policyName, mt) + + // Static attributes. + ls.SetField(mt, "allow", lua.LTrue) + ls.SetField(mt, "deny", lua.LFalse) + ls.SetField(mt, "defer", lua.LNil) +} diff --git a/pkg/extension/luahost/lua.go b/pkg/extension/luahost/lua.go index 346a9c3..eb74e7b 100644 --- a/pkg/extension/luahost/lua.go +++ b/pkg/extension/luahost/lua.go @@ -8,6 +8,7 @@ import ( "github.com/inbucket/inbucket/pkg/config" "github.com/inbucket/inbucket/pkg/extension" + "github.com/inbucket/inbucket/pkg/extension/event" "github.com/rs/zerolog" "github.com/rs/zerolog/log" lua "github.com/yuin/gopher-lua" @@ -54,6 +55,7 @@ func New(conf config.Lua, extHost *extension.Host) (*Host, error) { // The provided path is used in logging and error messages. func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, error) { logContext := log.With().Str("module", "lua") + logger := logContext.Str("phase", "startup").Str("path", path).Logger() // Pre-parse, and compile script. chunk, err := parse.Parse(r, path) @@ -69,6 +71,8 @@ func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, er pool := newStatePool(proto) h := &Host{extHost: extHost, pool: pool, logContext: logContext} if ls, err := pool.getState(); err == nil { + h.wireFunctions(logger, ls) + // State creation works, put it back. pool.putState(ls) } else { @@ -83,3 +87,107 @@ func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, er func (h *Host) CreateChannel(name string) chan lua.LValue { return h.pool.createChannel(name) } + +const afterMessageStoredFnName string = "after_message_stored" +const beforeMailAcceptedFnName string = "before_mail_accepted" + +// Detects global lua event listener functions and wires them up. +func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) { + detectFn := func(name string) bool { + lval := ls.GetGlobal(name) + switch lval.Type() { + case lua.LTFunction: + logger.Debug().Msgf("Detected %q function", name) + h.Functions = append(h.Functions, name) + return true + case lua.LTNil: + logger.Debug().Msgf("Did not detect %q function", name) + default: + logger.Fatal().Msgf("Found global named %q, but was a %v instead of a function", + name, lval.Type().String()) + } + + return false + } + + events := h.extHost.Events + const listenerName string = "lua" + + if detectFn(afterMessageStoredFnName) { + events.AfterMessageStored.AddListener(listenerName, h.handleAfterMessageStored) + } + if detectFn(beforeMailAcceptedFnName) { + events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted) + } +} + +func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) *extension.Void { + logger, ls, lfunc, ok := h.prepareFuncCall(afterMessageStoredFnName) + if !ok { + return nil + } + defer h.pool.putState(ls) + + // Call lua function. + logger.Debug().Msgf("Calling Lua function with %+v", msg) + if err := ls.CallByParam( + lua.P{Fn: lfunc, NRet: 0, Protect: true}, + wrapMessageMetadata(ls, &msg), + ); err != nil { + logger.Error().Err(err).Msg("Failed to call Lua function") + } + + return nil +} + +func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool { + logger, ls, lfunc, ok := h.prepareFuncCall(beforeMailAcceptedFnName) + if !ok { + return nil + } + defer h.pool.putState(ls) + + logger.Debug().Msgf("Calling Lua function with %+v", addr) + if err := ls.CallByParam( + lua.P{Fn: lfunc, NRet: 1, Protect: true}, + lua.LString(addr.Local), + lua.LString(addr.Domain), + ); err != nil { + logger.Error().Err(err).Msg("Failed to call Lua function") + return nil + } + + lval := ls.Get(1) + ls.Pop(1) + logger.Debug().Msgf("Lua function returned %q (%v)", lval, lval.Type().String()) + + if lval.Type() == lua.LTNil { + return nil + } + + result := true + if lua.LVIsFalse(lval) { + result = false + } + + return &result +} + +// Common preparation for calling Lua functions. +func (h *Host) prepareFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, lfunc lua.LValue, ok bool) { + logger = h.logContext.Str("event", funcName).Logger() + + ls, err := h.pool.getState() + if err != nil { + logger.Error().Err(err).Msg("Failed to get Lua state instance from pool") + return logger, nil, nil, false + } + + lfunc = ls.GetGlobal(funcName) + if lfunc.Type() != lua.LTFunction { + logger.Error().Msgf("global %q is no longer a function", funcName) + return logger, nil, nil, false + } + + return logger, ls, lfunc, true +} diff --git a/pkg/extension/luahost/lua_test.go b/pkg/extension/luahost/lua_test.go index 4cac14f..012af5d 100644 --- a/pkg/extension/luahost/lua_test.go +++ b/pkg/extension/luahost/lua_test.go @@ -1,12 +1,16 @@ package luahost_test import ( + "net/mail" "strings" "testing" + "time" "github.com/inbucket/inbucket/pkg/extension" + "github.com/inbucket/inbucket/pkg/extension/event" "github.com/inbucket/inbucket/pkg/extension/luahost" "github.com/stretchr/testify/require" + lua "github.com/yuin/gopher-lua" ) func TestEmptyScript(t *testing.T) { @@ -16,3 +20,96 @@ func TestEmptyScript(t *testing.T) { _, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua") require.NoError(t, err) } + +func TestAfterMessageStored(t *testing.T) { + // Register lua event listener, setup notify channel. + script := ` + local test_ok = true + + function assert_eq(got, want) + if got ~= want then + -- Incorrect value, schedule test to fail. + print("got '" .. got .. "', wanted '" .. want .. "'") + test_ok = false + end + end + + function after_message_stored(msg) + assert_eq(msg:mailbox(), "mb1") + assert_eq(msg:id(), "id1") + assert_eq(msg:subject(), "subj1") + assert_eq(msg:size(), 42) + + assert_eq(msg:from():name(), "name1") + assert_eq(msg:from():address(), "addr1") + + assert_eq(table.getn(msg:to()), 1) + assert_eq(msg:to()[1]:name(), "name2") + assert_eq(msg:to()[1]:address(), "addr2") + + assert_eq(msg:date(), 981173106) + + notify:send(test_ok) + end + ` + extHost := extension.NewHost() + luaHost, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua") + require.NoError(t, err) + notify := luaHost.CreateChannel("notify") + + // Send event, check channel response is true. + msg := &event.MessageMetadata{ + Mailbox: "mb1", + ID: "id1", + From: &mail.Address{Name: "name1", Address: "addr1"}, + To: []*mail.Address{{Name: "name2", Address: "addr2"}}, + Date: time.Date(2001, time.February, 3, 4, 5, 6, 0, time.UTC), + Subject: "subj1", + Size: 42, + } + go extHost.Events.AfterMessageStored.Emit(msg) + assertNotified(t, notify) +} + +func TestBeforeMailAccepted(t *testing.T) { + // Register lua event listener. + script := ` + function before_mail_accepted(localpart, domain) + return localpart == "from" and domain == "test" + end + ` + extHost := extension.NewHost() + _, err := luahost.NewFromReader(extHost, strings.NewReader(script), "test.lua") + require.NoError(t, err) + + // Send event to be accepted. + addr := &event.AddressParts{Local: "from", Domain: "test"} + got := extHost.Events.BeforeMailAccepted.Emit(addr) + want := true + require.NotNil(t, got) + if *got != want { + t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr) + } + + // Send event to be denied. + addr = &event.AddressParts{Local: "reject", Domain: "me"} + got = extHost.Events.BeforeMailAccepted.Emit(addr) + want = false + require.NotNil(t, got) + if *got != want { + t.Errorf("Got %v, wanted %v for addr %v", *got, want, addr) + } +} + +func assertNotified(t *testing.T, notify chan lua.LValue) { + t.Helper() + select { + case reslv := <-notify: + // Lua function received event. + if lua.LVIsFalse(reslv) { + t.Error("Lua responsed with false, wanted true") + } + case <-time.After(2 * time.Second): + t.Fatal("Lua did not respond to event within timeout") + } +} diff --git a/pkg/extension/luahost/pool.go b/pkg/extension/luahost/pool.go index f6f5d8d..4b3dc92 100644 --- a/pkg/extension/luahost/pool.go +++ b/pkg/extension/luahost/pool.go @@ -29,6 +29,11 @@ func (lp *statePool) newState() (*lua.LState, error) { ls.SetGlobal(name, lua.LChannel(ch)) } + // Register custom types. + registerMessageMetadataType(ls) + registerMailAddressType(ls) + registerPolicyType(ls) + // Run compiled script. ls.Push(ls.NewFunctionFromProto(lp.funcProto)) if err := ls.PCall(0, lua.MultRet, nil); err != nil {