diff --git a/pkg/extension/async_broker.go b/pkg/extension/async_broker.go new file mode 100644 index 0000000..d13362d --- /dev/null +++ b/pkg/extension/async_broker.go @@ -0,0 +1,89 @@ +package extension + +import ( + "errors" + "sync" + "time" +) + +// AsyncEventBroker maintains a list of listeners interested in a specific type +// of event. Events are sent in parallel to all listeners, and no result is +// returned. +type AsyncEventBroker[E any] struct { + sync.RWMutex + listenerNames []string // Ordered listener names. + listenerFuncs []func(E) // Ordered listener functions. +} + +// Emit sends the provided event to each registered listener in parallel. +func (eb *AsyncEventBroker[E]) Emit(event *E) { + eb.RLock() + defer eb.RUnlock() + + for _, l := range eb.listenerFuncs { + // Events are copied to minimize the risk of mutation. + go l(*event) + } +} + +// AddListener registers the named listener, replacing one with a duplicate +// name if present. Listeners should be added in order of priority, most +// significant first. +func (eb *AsyncEventBroker[E]) AddListener(name string, listener func(E)) { + eb.Lock() + defer eb.Unlock() + + eb.lockedRemoveListener(name) + eb.listenerNames = append(eb.listenerNames, name) + eb.listenerFuncs = append(eb.listenerFuncs, listener) +} + +// RemoveListener unregisters the named listener. +func (eb *AsyncEventBroker[E]) RemoveListener(name string) { + eb.Lock() + defer eb.Unlock() + + eb.lockedRemoveListener(name) +} + +func (eb *AsyncEventBroker[E]) lockedRemoveListener(name string) { + for i, entry := range eb.listenerNames { + if entry == name { + eb.listenerNames = append(eb.listenerNames[:i], eb.listenerNames[i+1:]...) + eb.listenerFuncs = append(eb.listenerFuncs[:i], eb.listenerFuncs[i+1:]...) + break + } + } +} + +// AsyncTestListener returns a func that will wait for an event and return it, or timeout +// with an error. +func (eb *AsyncEventBroker[E]) AsyncTestListener(name string, capacity int) func() (*E, error) { + // Send event down channel. + events := make(chan E, capacity) + eb.AddListener(name, + func(msg E) { + events <- msg + }) + + count := 0 + + return func() (*E, error) { + count++ + + defer func() { + if count >= capacity { + eb.RemoveListener(name) + close(events) + } + }() + + select { + case event := <-events: + return &event, nil + + case <-time.After(time.Second * 2): + return nil, errors.New("Timeout waiting for event") + } + } +} diff --git a/pkg/extension/async_broker_test.go b/pkg/extension/async_broker_test.go new file mode 100644 index 0000000..253394d --- /dev/null +++ b/pkg/extension/async_broker_test.go @@ -0,0 +1,101 @@ +package extension_test + +import ( + "testing" + "time" + + "github.com/inbucket/inbucket/pkg/extension" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Simple smoke test without using AsyncTestListener. +func TestAsyncBrokerEmitCallsOneListener(t *testing.T) { + broker := &extension.AsyncEventBroker[string]{} + + // Setup listener. + events := make(chan string, 1) + listener := func(s string) { + events <- s + } + broker.AddListener("x", listener) + + want := "bacon" + broker.Emit(&want) + + var got string + select { + case event := <-events: + got = event + + case <-time.After(time.Second * 2): + t.Fatal("Timeout waiting for event") + } + + if got != want { + t.Errorf("Emit got %q, want %q", got, want) + } +} + +func TestAsyncBrokerEmitCallsMultipleListeners(t *testing.T) { + broker := &extension.AsyncEventBroker[string]{} + + // Setup listeners. + first := broker.AsyncTestListener("first", 1) + second := broker.AsyncTestListener("second", 1) + + want := "hi" + broker.Emit(&want) + + first_got, err := first() + require.NoError(t, err) + assert.Equal(t, want, *first_got) + + second_got, err := second() + require.NoError(t, err) + assert.Equal(t, want, *second_got) +} + +func TestAsyncBrokerAddingDuplicateNameReplacesPrevious(t *testing.T) { + broker := &extension.AsyncEventBroker[string]{} + + // Setup listeners. + first := broker.AsyncTestListener("dup", 1) + second := broker.AsyncTestListener("dup", 1) + + want := "hi" + broker.Emit(&want) + + first_got, err := first() + require.Error(t, err) + assert.Nil(t, first_got) + + second_got, err := second() + require.NoError(t, err) + assert.Equal(t, want, *second_got) +} + +func TestAsyncBrokerRemovingListenerSuccessful(t *testing.T) { + broker := &extension.AsyncEventBroker[string]{} + + // Setup listeners. + first := broker.AsyncTestListener("1", 1) + second := broker.AsyncTestListener("2", 1) + broker.RemoveListener("1") + + want := "hi" + broker.Emit(&want) + + first_got, err := first() + require.Error(t, err) + assert.Nil(t, first_got) + + second_got, err := second() + require.NoError(t, err) + assert.Equal(t, want, *second_got) +} + +func TestAsyncBrokerRemovingMissingListener(t *testing.T) { + broker := &extension.AsyncEventBroker[string]{} + broker.RemoveListener("doesn't crash") +} diff --git a/pkg/extension/broker.go b/pkg/extension/broker.go index fa711eb..742482e 100644 --- a/pkg/extension/broker.go +++ b/pkg/extension/broker.go @@ -1,9 +1,7 @@ package extension import ( - "errors" "sync" - "time" ) // EventBroker maintains a list of listeners interested in a specific type @@ -59,38 +57,3 @@ func (eb *EventBroker[E, R]) lockedRemoveListener(name string) { } } } - -// AsyncTestListener returns a func that will wait for an event and return it, or timeout -// with an error. -func (eb *EventBroker[E, R]) AsyncTestListener(capacity int) func() (*E, error) { - const name = "asyncTestListener" - - // Send event down channel. - events := make(chan E, capacity) - eb.AddListener(name, - func(msg E) *R { - events <- msg - return nil - }) - - count := 0 - - return func() (*E, error) { - count++ - - defer func() { - if count >= capacity { - eb.RemoveListener(name) - close(events) - } - }() - - select { - case event := <-events: - return &event, nil - - case <-time.After(time.Second * 2): - return nil, errors.New("Timeout waiting for event") - } - } -} diff --git a/pkg/extension/host.go b/pkg/extension/host.go index bef9846..47e1b52 100644 --- a/pkg/extension/host.go +++ b/pkg/extension/host.go @@ -20,8 +20,8 @@ type Host struct { // processed asynchronously with respect to the rest of Inbuckets operation. However, an event // listener will not be called until the one before it complets. type Events struct { - AfterMessageDeleted EventBroker[event.MessageMetadata, Void] - AfterMessageStored EventBroker[event.MessageMetadata, Void] + AfterMessageDeleted AsyncEventBroker[event.MessageMetadata] + AfterMessageStored AsyncEventBroker[event.MessageMetadata] BeforeMailAccepted EventBroker[event.AddressParts, bool] } diff --git a/pkg/extension/luahost/lua.go b/pkg/extension/luahost/lua.go index 121a7c3..f399ccb 100644 --- a/pkg/extension/luahost/lua.go +++ b/pkg/extension/luahost/lua.go @@ -125,10 +125,10 @@ func (h *Host) wireFunctions(logger zerolog.Logger, ls *lua.LState) { } } -func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) *extension.Void { +func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) { logger, ls, lfunc, ok := h.prepareFuncCall(afterMessageDeletedFnName) if !ok { - return nil + return } defer h.pool.putState(ls) @@ -140,14 +140,12 @@ func (h *Host) handleAfterMessageDeleted(msg event.MessageMetadata) *extension.V ); err != nil { logger.Error().Err(err).Msg("Failed to call Lua function") } - - return nil } -func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) *extension.Void { +func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) { logger, ls, lfunc, ok := h.prepareFuncCall(afterMessageStoredFnName) if !ok { - return nil + return } defer h.pool.putState(ls) @@ -159,8 +157,6 @@ func (h *Host) handleAfterMessageStored(msg event.MessageMetadata) *extension.Vo ); err != nil { logger.Error().Err(err).Msg("Failed to call Lua function") } - - return nil } func (h *Host) handleBeforeMailAccepted(addr event.AddressParts) *bool { diff --git a/pkg/message/manager_test.go b/pkg/message/manager_test.go index 9826953..2111e2d 100644 --- a/pkg/message/manager_test.go +++ b/pkg/message/manager_test.go @@ -2,14 +2,13 @@ package message_test import ( "testing" - "time" "github.com/inbucket/inbucket/pkg/extension" - "github.com/inbucket/inbucket/pkg/extension/event" "github.com/inbucket/inbucket/pkg/message" "github.com/inbucket/inbucket/pkg/policy" "github.com/inbucket/inbucket/pkg/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestManagerEmitsMessageStoredEvent(t *testing.T) { @@ -20,16 +19,7 @@ func TestManagerEmitsMessageStoredEvent(t *testing.T) { ExtHost: extHost, } - // Capture message event. - gotc := make(chan *event.MessageMetadata) - defer close(gotc) - - extHost.Events.AfterMessageStored.AddListener( - "test", - func(msg event.MessageMetadata) *extension.Void { - gotc <- &msg - return nil - }) + listener := extHost.Events.AfterMessageStored.AsyncTestListener("manager", 1) // Attempt to deliver a message to generate event. if _, err := sm.Deliver( @@ -42,10 +32,7 @@ func TestManagerEmitsMessageStoredEvent(t *testing.T) { t.Fatal(err) } - select { - case got := <-gotc: - assert.NotNil(t, got, "No event received, or it was nil") - case <-time.After(time.Second * 2): - t.Fatal("Timeout waiting for message event") - } + got, err := listener() + require.NoError(t, err) + assert.NotNil(t, got, "No event received, or it was nil") } diff --git a/pkg/msghub/hub.go b/pkg/msghub/hub.go index 714f493..ba87df2 100644 --- a/pkg/msghub/hub.go +++ b/pkg/msghub/hub.go @@ -38,15 +38,13 @@ func New(historyLen int, extHost *extension.Host) *Hub { // Register an extension event listener for MessageStored. extHost.Events.AfterMessageStored.AddListener("msghub", - func(msg event.MessageMetadata) *extension.Void { + func(msg event.MessageMetadata) { hub.Dispatch(msg) - return nil }) extHost.Events.AfterMessageDeleted.AddListener("msghub", - func(msg event.MessageMetadata) *extension.Void { + func(msg event.MessageMetadata) { hub.Delete(msg.Mailbox, msg.ID) - return nil }) return hub diff --git a/pkg/test/storage_suite.go b/pkg/test/storage_suite.go index f22abb2..f8a7ca9 100644 --- a/pkg/test/storage_suite.go +++ b/pkg/test/storage_suite.go @@ -298,7 +298,7 @@ func testDelete(t *testing.T, store storage.Store, extHost *extension.Host) { msgs := GetAndCountMessages(t, store, mailbox, len(subjects)) // Subscribe to events. - eventListener := extHost.Events.AfterMessageDeleted.AsyncTestListener(2) + eventListener := extHost.Events.AfterMessageDeleted.AsyncTestListener("test", 2) // Delete a couple messages. deleteIDs := []string{msgs[1].ID(), msgs[3].ID()} @@ -345,7 +345,7 @@ func testPurge(t *testing.T, store storage.Store, extHost *extension.Host) { subjects := []string{"alpha", "bravo", "charlie", "delta", "echo"} // Subscribe to events. - eventListener := extHost.Events.AfterMessageDeleted.AsyncTestListener(len(subjects)) + eventListener := extHost.Events.AfterMessageDeleted.AsyncTestListener("test", len(subjects)) // Populate mailbox. for _, subj := range subjects {