From 5a886813c3b67a7b2ee12a012d6b43828857a056 Mon Sep 17 00:00:00 2001 From: James Hillyerd Date: Mon, 27 Feb 2023 20:22:10 -0800 Subject: [PATCH] Provide `inbucket` object in Lua (#351) * fix delve fortify thingy * Expose inbucket.after.message_stored in lua * Expose inbucket.after.message_deleted in lua * Expose inbucket.before.mail_accepted in lua --- pkg/extension/luahost/bind_inbucket.go | 218 ++++++++++++++++++++ pkg/extension/luahost/bind_inbucket_test.go | 102 +++++++++ pkg/extension/luahost/lua.go | 51 ++--- pkg/extension/luahost/lua_test.go | 6 +- pkg/extension/luahost/pool.go | 1 + shell.nix | 3 + 6 files changed, 344 insertions(+), 37 deletions(-) create mode 100644 pkg/extension/luahost/bind_inbucket.go create mode 100644 pkg/extension/luahost/bind_inbucket_test.go diff --git a/pkg/extension/luahost/bind_inbucket.go b/pkg/extension/luahost/bind_inbucket.go new file mode 100644 index 0000000..e2273d7 --- /dev/null +++ b/pkg/extension/luahost/bind_inbucket.go @@ -0,0 +1,218 @@ +package luahost + +import ( + "errors" + "fmt" + + lua "github.com/yuin/gopher-lua" +) + +const ( + inbucketName = "inbucket" + inbucketBeforeName = "inbucket_before" + inbucketAfterName = "inbucket_after" +) + +// Inbucket is the primary Lua interface data structure. +type Inbucket struct { + After InbucketAfterFuncs + Before InbucketBeforeFuncs +} + +// InbucketAfterFuncs holds references to Lua extension functions to be called async +// after Inbucket handles an event. +type InbucketAfterFuncs struct { + MessageDeleted *lua.LFunction + MessageStored *lua.LFunction +} + +// InbucketBeforeFuncs holds references to Lua extension functions to be called +// before Inbucket handles an event. +type InbucketBeforeFuncs struct { + MailAccepted *lua.LFunction +} + +func registerInbucketTypes(ls *lua.LState) { + // inbucket type. + mt := ls.NewTypeMetatable(inbucketName) + ls.SetField(mt, "__index", ls.NewFunction(inbucketIndex)) + + // inbucket global var. + ud := wrapInbucket(ls, &Inbucket{}) + ls.SetGlobal(inbucketName, ud) + + // inbucket.after type. + mt = ls.NewTypeMetatable(inbucketAfterName) + ls.SetField(mt, "__index", ls.NewFunction(inbucketAfterIndex)) + ls.SetField(mt, "__newindex", ls.NewFunction(inbucketAfterNewIndex)) + + // inbucket.before type. + mt = ls.NewTypeMetatable(inbucketBeforeName) + ls.SetField(mt, "__index", ls.NewFunction(inbucketBeforeIndex)) + ls.SetField(mt, "__newindex", ls.NewFunction(inbucketBeforeNewIndex)) +} + +func wrapInbucket(ls *lua.LState, val *Inbucket) *lua.LUserData { + ud := ls.NewUserData() + ud.Value = val + ls.SetMetatable(ud, ls.GetTypeMetatable(inbucketName)) + + return ud +} + +func wrapInbucketAfter(ls *lua.LState, val *InbucketAfterFuncs) *lua.LUserData { + ud := ls.NewUserData() + ud.Value = val + ls.SetMetatable(ud, ls.GetTypeMetatable(inbucketAfterName)) + + return ud +} + +func wrapInbucketBefore(ls *lua.LState, val *InbucketBeforeFuncs) *lua.LUserData { + ud := ls.NewUserData() + ud.Value = val + ls.SetMetatable(ud, ls.GetTypeMetatable(inbucketBeforeName)) + + return ud +} + +func getInbucket(ls *lua.LState) (*Inbucket, error) { + lv := ls.GetGlobal(inbucketName) + if lv == nil { + return nil, errors.New("inbucket object was nil") + } + + ud, ok := lv.(*lua.LUserData) + if !ok { + return nil, fmt.Errorf("inbucket object was type %s instead of UserData", lv.Type()) + } + + val, ok := ud.Value.(*Inbucket) + if !ok { + return nil, fmt.Errorf("inbucket object (%v) could not be cast", ud.Value) + } + + return val, nil +} + +func checkInbucket(ls *lua.LState, pos int) *Inbucket { + ud := ls.CheckUserData(pos) + if val, ok := ud.Value.(*Inbucket); ok { + return val + } + ls.ArgError(1, inbucketName+" expected") + return nil +} + +func checkInbucketAfter(ls *lua.LState, pos int) *InbucketAfterFuncs { + ud := ls.CheckUserData(pos) + if val, ok := ud.Value.(*InbucketAfterFuncs); ok { + return val + } + ls.ArgError(1, inbucketAfterName+" expected") + return nil +} + +func checkInbucketBefore(ls *lua.LState, pos int) *InbucketBeforeFuncs { + ud := ls.CheckUserData(pos) + if val, ok := ud.Value.(*InbucketBeforeFuncs); ok { + return val + } + ls.ArgError(1, inbucketBeforeName+" expected") + return nil +} + +// inbucket getter. +func inbucketIndex(ls *lua.LState) int { + ib := checkInbucket(ls, 1) + field := ls.CheckString(2) + + // Push the requested field's value onto the stack. + switch field { + case "after": + ls.Push(wrapInbucketAfter(ls, &ib.After)) + case "before": + ls.Push(wrapInbucketBefore(ls, &ib.Before)) + default: + // Unknown field. + ls.Push(lua.LNil) + } + + return 1 +} + +// inbucket.after getter. +func inbucketAfterIndex(ls *lua.LState) int { + after := checkInbucketAfter(ls, 1) + field := ls.CheckString(2) + + // Push the requested field's value onto the stack. + switch field { + case "message_deleted": + ls.Push(funcOrNil(after.MessageDeleted)) + case "message_stored": + ls.Push(funcOrNil(after.MessageStored)) + default: + // Unknown field. + ls.Push(lua.LNil) + } + + return 1 +} + +// inbucket.after setter. +func inbucketAfterNewIndex(ls *lua.LState) int { + m := checkInbucketAfter(ls, 1) + index := ls.CheckString(2) + + switch index { + case "message_deleted": + m.MessageDeleted = ls.CheckFunction(3) + case "message_stored": + m.MessageStored = ls.CheckFunction(3) + default: + ls.RaiseError("invalid inbucket.after index %q", index) + } + + return 0 +} + +// inbucket.before getter. +func inbucketBeforeIndex(ls *lua.LState) int { + before := checkInbucketBefore(ls, 1) + field := ls.CheckString(2) + + // Push the requested field's value onto the stack. + switch field { + case "mail_accepted": + ls.Push(funcOrNil(before.MailAccepted)) + default: + // Unknown field. + ls.Push(lua.LNil) + } + + return 1 +} + +// inbucket.before setter. +func inbucketBeforeNewIndex(ls *lua.LState) int { + m := checkInbucketBefore(ls, 1) + index := ls.CheckString(2) + + switch index { + case "mail_accepted": + m.MailAccepted = ls.CheckFunction(3) + default: + ls.RaiseError("invalid inbucket.before index %q", index) + } + + return 0 +} + +func funcOrNil(f *lua.LFunction) lua.LValue { + if f == nil { + return lua.LNil + } + + return f +} diff --git a/pkg/extension/luahost/bind_inbucket_test.go b/pkg/extension/luahost/bind_inbucket_test.go new file mode 100644 index 0000000..6f7f007 --- /dev/null +++ b/pkg/extension/luahost/bind_inbucket_test.go @@ -0,0 +1,102 @@ +package luahost + +import ( + "testing" + + "github.com/stretchr/testify/require" + lua "github.com/yuin/gopher-lua" +) + +func TestInbucketAfterFuncs(t *testing.T) { + // This Script registers each function and calls it. No effort is made to use the arguments + // that Inbucket expects, this is only to validate the inbucket.after data structure getters + // and setters. + script := ` + assert(inbucket, "inbucket should not be nil") + assert(inbucket.after, "inbucket.after should not be nil") + + local fns = { "message_deleted", "message_stored" } + + -- Verify functions start off nil. + for i, name in ipairs(fns) do + assert(inbucket.after[name] == nil, "after." .. name .. " should be nil") + end + + -- Test function to track func calls made, ensures no crossed wires. + local calls = {} + function makeTestFunc(create_name) + return function(call_name) + calls[create_name] = call_name + end + end + + -- Set after functions, verify not nil, and call them. + for i, name in ipairs(fns) do + inbucket.after[name] = makeTestFunc(name) + assert(inbucket.after[name], "after." .. name .. " should not be nil") + end + + -- Call each function. Separate loop to verify final state in 'calls'. + for i, name in ipairs(fns) do + inbucket.after[name](name) + end + + -- Verify functions were called. + for i, name in ipairs(fns) do + assert(calls[name], "after." .. name .. " should have been called") + assert(calls[name] == name, + string.format("after.%s was called with incorrect argument %s", name, calls[name])) + end + ` + + ls := lua.NewState() + registerInbucketTypes(ls) + require.NoError(t, ls.DoString(script)) +} + +func TestInbucketBeforeFuncs(t *testing.T) { + // This Script registers each function and calls it. No effort is made to use the arguments + // that Inbucket expects, this is only to validate the inbucket.before data structure getters + // and setters. + script := ` + assert(inbucket, "inbucket should not be nil") + assert(inbucket.before, "inbucket.before should not be nil") + + local fns = { "mail_accepted" } + + -- Verify functions start off nil. + for i, name in ipairs(fns) do + assert(inbucket.before[name] == nil, "before." .. name .. " should be nil") + end + + -- Test function to track func calls made, ensures no crossed wires. + local calls = {} + function makeTestFunc(create_name) + return function(call_name) + calls[create_name] = call_name + end + end + + -- Set before functions, verify not nil, and call them. + for i, name in ipairs(fns) do + inbucket.before[name] = makeTestFunc(name) + assert(inbucket.before[name], "before." .. name .. " should not be nil") + end + + -- Call each function. Separate loop to verify final state in 'calls'. + for i, name in ipairs(fns) do + inbucket.before[name](name) + end + + -- Verify functions were called. + for i, name in ipairs(fns) do + assert(calls[name], "before." .. name .. " should have been called") + assert(calls[name] == name, + string.format("before.%s was called with incorrect argument %s", name, calls[name])) + end + ` + + ls := lua.NewState() + registerInbucketTypes(ls) + require.NoError(t, ls.DoString(script)) +} diff --git a/pkg/extension/luahost/lua.go b/pkg/extension/luahost/lua.go index f399ccb..d91ce41 100644 --- a/pkg/extension/luahost/lua.go +++ b/pkg/extension/luahost/lua.go @@ -17,7 +17,6 @@ import ( // Host of Lua extensions. type Host struct { - Functions []string // Functions detected in lua script. extHost *extension.Host pool *statePool logContext zerolog.Context @@ -88,45 +87,29 @@ func (h *Host) CreateChannel(name string) chan lua.LValue { return h.pool.createChannel(name) } -const afterMessageDeletedFnName string = "after_message_deleted" -const afterMessageStoredFnName string = "after_message_stored" -const beforeMailAcceptedFnName string = "before_mail_accepted" - // Detects global lua event listener functions and wires them up. func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) { - detectFn := func(name string) bool { - lval := ls.GetGlobal(name) - switch lval.Type() { - case lua.LTFunction: - logger.Debug().Msgf("Detected %q function", name) - h.Functions = append(h.Functions, name) - return true - case lua.LTNil: - logger.Debug().Msgf("Did not detect %q function", name) - default: - logger.Fatal().Msgf("Found global named %q, but was a %v instead of a function", - name, lval.Type().String()) - } - - return false + ib, err := getInbucket(ls) + if err != nil { + logger.Fatal().Err(err).Msg("Failed to get inbucket global") } events := h.extHost.Events const listenerName string = "lua" - if detectFn(afterMessageDeletedFnName) { + if ib.After.MessageDeleted != nil { events.AfterMessageDeleted.AddListener(listenerName, h.handleAfterMessageDeleted) } - if detectFn(afterMessageStoredFnName) { + if ib.After.MessageStored != nil { events.AfterMessageStored.AddListener(listenerName, h.handleAfterMessageStored) } - if detectFn(beforeMailAcceptedFnName) { + if ib.Before.MailAccepted != nil { events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted) } } func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) { - logger, ls, lfunc, ok := h.prepareFuncCall(afterMessageDeletedFnName) + logger, ls, ib, ok := h.prepareInbucketFuncCall("after.message_deleted") if !ok { return } @@ -135,7 +118,7 @@ func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) { // Call lua function. logger.Debug().Msgf("Calling Lua function with %+v", msg) if err := ls.CallByParam( - lua.P{Fn: lfunc, NRet: 0, Protect: true}, + lua.P{Fn: ib.After.MessageDeleted, NRet: 0, Protect: true}, wrapMessageMetadata(ls, &msg), ); err != nil { logger.Error().Err(err).Msg("Failed to call Lua function") @@ -143,7 +126,7 @@ func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) { } func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) { - logger, ls, lfunc, ok := h.prepareFuncCall(afterMessageStoredFnName) + logger, ls, ib, ok := h.prepareInbucketFuncCall("after.message_stored") if !ok { return } @@ -152,7 +135,7 @@ func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) { // Call lua function. logger.Debug().Msgf("Calling Lua function with %+v", msg) if err := ls.CallByParam( - lua.P{Fn: lfunc, NRet: 0, Protect: true}, + lua.P{Fn: ib.After.MessageStored, NRet: 0, Protect: true}, wrapMessageMetadata(ls, &msg), ); err != nil { logger.Error().Err(err).Msg("Failed to call Lua function") @@ -160,7 +143,7 @@ func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) { } func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool { - logger, ls, lfunc, ok := h.prepareFuncCall(beforeMailAcceptedFnName) + logger, ls, ib, ok := h.prepareInbucketFuncCall("after.message_stored") if !ok { return nil } @@ -168,7 +151,7 @@ func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool { logger.Debug().Msgf("Calling Lua function with %+v", addr) if err := ls.CallByParam( - lua.P{Fn: lfunc, NRet: 1, Protect: true}, + lua.P{Fn: ib.Before.MailAccepted, NRet: 1, Protect: true}, lua.LString(addr.Local), lua.LString(addr.Domain), ); err != nil { @@ -193,7 +176,7 @@ func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool { } // Common preparation for calling Lua functions. -func (h *Host) prepareFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, lfunc lua.LValue, ok bool) { +func (h *Host) prepareInbucketFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, ib *Inbucket, ok bool) { logger = h.logContext.Str("event", funcName).Logger() ls, err := h.pool.getState() @@ -202,11 +185,11 @@ func (h *Host) prepareFuncCall(funcName string) (logger zerolog.Logger, ls *lua. return logger, nil, nil, false } - lfunc = ls.GetGlobal(funcName) - if lfunc.Type() != lua.LTFunction { - logger.Error().Msgf("global %q is no longer a function", funcName) + ib, err = getInbucket(ls) + if err != nil { + logger.Error().Err(err).Msg("Failed to obtain Lua inbucket object") return logger, nil, nil, false } - return logger, ls, lfunc, true + return logger, ls, ib, true } diff --git a/pkg/extension/luahost/lua_test.go b/pkg/extension/luahost/lua_test.go index 953c17f..376def2 100644 --- a/pkg/extension/luahost/lua_test.go +++ b/pkg/extension/luahost/lua_test.go @@ -65,7 +65,7 @@ func TestAfterMessageDeleted(t *testing.T) { script := ` async = true - function after_message_deleted(msg) + function inbucket.after.message_deleted(msg) -- Full message bindings tested elsewhere. assert_eq(msg.mailbox, "mb1") assert_eq(msg.id, "id1") @@ -96,7 +96,7 @@ func TestAfterMessageStored(t *testing.T) { script := ` async = true - function after_message_stored(msg) + function inbucket.after.message_stored(msg) -- Full message bindings tested elsewhere. assert_eq(msg.mailbox, "mb1") assert_eq(msg.id, "id1") @@ -125,7 +125,7 @@ func TestAfterMessageStored(t *testing.T) { func TestBeforeMailAccepted(t *testing.T) { // Register lua event listener. script := ` - function before_mail_accepted(localpart, domain) + function inbucket.before.mail_accepted(localpart, domain) return localpart == "from" and domain == "test" end ` diff --git a/pkg/extension/luahost/pool.go b/pkg/extension/luahost/pool.go index 9ed1aa5..9e14fa7 100644 --- a/pkg/extension/luahost/pool.go +++ b/pkg/extension/luahost/pool.go @@ -37,6 +37,7 @@ func (lp *statePool) newState() (*lua.LState, error) { } // Register custom types. + registerInbucketTypes(ls) registerMessageMetadataType(ls) registerMailAddressType(ls) registerPolicyType(ls) diff --git a/shell.nix b/shell.nix index 263b912..e163f13 100644 --- a/shell.nix +++ b/shell.nix @@ -32,4 +32,7 @@ pkgs.mkShell { scripts.qt ]; + + # Prevents launch errors with delve debugger. + hardeningDisable = [ "fortify" ]; }