1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-17 17:47:03 +00:00

lua: Bind after_message_stored and before_mail_accepted (#322)

Signed-off-by: James Hillyerd <james@hillyerd.com>

Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
James Hillyerd
2023-01-24 16:37:26 -08:00
committed by GitHub
parent 55addbb556
commit 7f91c3e9cb
6 changed files with 472 additions and 0 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/inbucket/inbucket/pkg/config"
"github.com/inbucket/inbucket/pkg/extension"
"github.com/inbucket/inbucket/pkg/extension/event"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
lua "github.com/yuin/gopher-lua"
@@ -54,6 +55,7 @@ func New(conf config.Lua, extHost *extension.Host) (*Host, error) {
// The provided path is used in logging and error messages.
func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, error) {
logContext := log.With().Str("module", "lua")
logger := logContext.Str("phase", "startup").Str("path", path).Logger()
// Pre-parse, and compile script.
chunk, err := parse.Parse(r, path)
@@ -69,6 +71,8 @@ func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, er
pool := newStatePool(proto)
h := &Host{extHost: extHost, pool: pool, logContext: logContext}
if ls, err := pool.getState(); err == nil {
h.wireFunctions(logger, ls)
// State creation works, put it back.
pool.putState(ls)
} else {
@@ -83,3 +87,107 @@ func NewFromReader(extHost *extension.Host, r io.Reader, path string) (*Host, er
func (h *Host) CreateChannel(name string) chan lua.LValue {
return h.pool.createChannel(name)
}
const afterMessageStoredFnName string = "after_message_stored"
const beforeMailAcceptedFnName string = "before_mail_accepted"
// Detects global lua event listener functions and wires them up.
func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) {
detectFn := func(name string) bool {
lval := ls.GetGlobal(name)
switch lval.Type() {
case lua.LTFunction:
logger.Debug().Msgf("Detected %q function", name)
h.Functions = append(h.Functions, name)
return true
case lua.LTNil:
logger.Debug().Msgf("Did not detect %q function", name)
default:
logger.Fatal().Msgf("Found global named %q, but was a %v instead of a function",
name, lval.Type().String())
}
return false
}
events := h.extHost.Events
const listenerName string = "lua"
if detectFn(afterMessageStoredFnName) {
events.AfterMessageStored.AddListener(listenerName, h.handleAfterMessageStored)
}
if detectFn(beforeMailAcceptedFnName) {
events.BeforeMailAccepted.AddListener(listenerName, h.handleBeforeMailAccepted)
}
}
func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) *extension.Void {
logger, ls, lfunc, ok := h.prepareFuncCall(afterMessageStoredFnName)
if !ok {
return nil
}
defer h.pool.putState(ls)
// Call lua function.
logger.Debug().Msgf("Calling Lua function with %+v", msg)
if err := ls.CallByParam(
lua.P{Fn: lfunc, NRet: 0, Protect: true},
wrapMessageMetadata(ls, &msg),
); err != nil {
logger.Error().Err(err).Msg("Failed to call Lua function")
}
return nil
}
func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool {
logger, ls, lfunc, ok := h.prepareFuncCall(beforeMailAcceptedFnName)
if !ok {
return nil
}
defer h.pool.putState(ls)
logger.Debug().Msgf("Calling Lua function with %+v", addr)
if err := ls.CallByParam(
lua.P{Fn: lfunc, NRet: 1, Protect: true},
lua.LString(addr.Local),
lua.LString(addr.Domain),
); err != nil {
logger.Error().Err(err).Msg("Failed to call Lua function")
return nil
}
lval := ls.Get(1)
ls.Pop(1)
logger.Debug().Msgf("Lua function returned %q (%v)", lval, lval.Type().String())
if lval.Type() == lua.LTNil {
return nil
}
result := true
if lua.LVIsFalse(lval) {
result = false
}
return &result
}
// Common preparation for calling Lua functions.
func (h *Host) prepareFuncCall(funcName string) (logger zerolog.Logger, ls *lua.LState, lfunc lua.LValue, ok bool) {
logger = h.logContext.Str("event", funcName).Logger()
ls, err := h.pool.getState()
if err != nil {
logger.Error().Err(err).Msg("Failed to get Lua state instance from pool")
return logger, nil, nil, false
}
lfunc = ls.GetGlobal(funcName)
if lfunc.Type() != lua.LTFunction {
logger.Error().Msgf("global %q is no longer a function", funcName)
return logger, nil, nil, false
}
return logger, ls, lfunc, true
}