From 88ccb193600fdcb6e4cbdf19661c0aae7b4491ff Mon Sep 17 00:00:00 2001 From: James Hillyerd Date: Sun, 15 Jan 2017 20:01:20 -0800 Subject: [PATCH] Implement pub/sub message hub as a base for #44 --- msghub/hub.go | 116 ++++++++++++++++++++++ msghub/hub_test.go | 243 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 359 insertions(+) create mode 100644 msghub/hub.go create mode 100644 msghub/hub_test.go diff --git a/msghub/hub.go b/msghub/hub.go new file mode 100644 index 0000000..0f43077 --- /dev/null +++ b/msghub/hub.go @@ -0,0 +1,116 @@ +package msghub + +import ( + "container/ring" + "context" + "sync" + "time" +) + +// Message contains the basic header data for a message +type Message struct { + Mailbox string + ID string + From string + To []string + Subject string + Date time.Time + Size int64 +} + +// Listener receives the contents of the log, followed by new messages +type Listener interface { + Receive(msg Message) error +} + +// Hub relays messages on to its listeners +type Hub struct { + // log stores history, points next spot to write. First non-nil entry is oldest Message + log *ring.Ring + logMx sync.RWMutex + + // listeners interested in new messages + listeners map[Listener]struct{} + listenersMx sync.RWMutex + + // broadcast receives new messages + broadcast chan Message +} + +// New constructs a new Hub which will cache logSize messages in memory for playback to future +// listeners. A goroutine is created to handle incoming messages; it will run until the provided +// context is canceled. +func New(ctx context.Context, logSize int) *Hub { + h := &Hub{ + log: ring.New(logSize), + listeners: make(map[Listener]struct{}), + broadcast: make(chan Message, 100), + } + + go func() { + for { + select { + case <-ctx.Done(): + // Shutdown + close(h.broadcast) + h.broadcast = nil + return + case msg := <-h.broadcast: + // Log message + h.logMx.Lock() + h.log.Value = msg + h.log = h.log.Next() + h.logMx.Unlock() + // Deliver message to listeners + h.deliver(msg) + } + } + }() + + return h +} + +// Broadcast queues a message for processing by the hub. The message will be placed into the +// in-memory log and relayed to all registered listeners. +func (h *Hub) Broadcast(msg Message) { + if h.broadcast != nil { + h.broadcast <- msg + } +} + +// AddListener registers a listener to receive broadcasted messages. +func (h *Hub) AddListener(l Listener) { + // Playback log + h.logMx.RLock() + h.log.Do(func(v interface{}) { + if v != nil { + l.Receive(v.(Message)) + } + }) + h.logMx.RUnlock() + + // Add to listeners + h.listenersMx.Lock() + h.listeners[l] = struct{}{} + h.listenersMx.Unlock() +} + +// RemoveListener deletes a listener registration, it will cease to receive messages. +func (h *Hub) RemoveListener(l Listener) { + h.listenersMx.Lock() + defer h.listenersMx.Unlock() + if _, ok := h.listeners[l]; ok { + delete(h.listeners, l) + } +} + +// deliver message to all listeners, removing listeners if they return an error +func (h *Hub) deliver(msg Message) { + h.listenersMx.RLock() + defer h.listenersMx.RUnlock() + for l := range h.listeners { + if err := l.Receive(msg); err != nil { + h.RemoveListener(l) + } + } +} diff --git a/msghub/hub_test.go b/msghub/hub_test.go new file mode 100644 index 0000000..f6c131e --- /dev/null +++ b/msghub/hub_test.go @@ -0,0 +1,243 @@ +package msghub + +import ( + "context" + "fmt" + "testing" + "time" +) + +// testListener implements the Listener interface, mock for unit tests +type testListener struct { + messages []*Message // received messages + wantMessages int // how many messages this listener wants to receive + errorAfter int // when != 0, messages until Receive() begins returning error + + done chan struct{} // closed once we have received wantMessages + overflow chan struct{} // closed if we receive wantMessages+1 +} + +func newTestListener(want int) *testListener { + l := &testListener{ + messages: make([]*Message, 0, want*2), + wantMessages: want, + done: make(chan struct{}), + overflow: make(chan struct{}), + } + if want == 0 { + close(l.done) + } + return l +} + +// Receive a Message, store it in the messages slice, close applicable channels, and return an error +// if instructed +func (l *testListener) Receive(msg Message) error { + l.messages = append(l.messages, &msg) + if len(l.messages) == l.wantMessages { + close(l.done) + } + if len(l.messages) == l.wantMessages+1 { + close(l.overflow) + } + if l.errorAfter > 0 && len(l.messages) > l.errorAfter { + return fmt.Errorf("Too many messages") + } + return nil +} + +// String formats the got vs wanted message counts +func (l *testListener) String() string { + return fmt.Sprintf("got %v messages, wanted %v", len(l.messages), l.wantMessages) +} + +func TestHubNew(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hub := New(ctx, 5) + if hub == nil { + t.Fatal("New() == nil, expected a new Hub") + } +} + +func TestHubZeroListeners(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hub := New(ctx, 5) + m := Message{} + for i := 0; i < 100; i++ { + hub.Broadcast(m) + } + // Just making sure Hub doesn't panic +} + +func TestHubOneListener(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hub := New(ctx, 5) + m := Message{} + l := newTestListener(1) + + hub.AddListener(l) + hub.Broadcast(m) + + // Wait for messages + select { + case <-l.done: + case <-time.After(time.Second): + t.Error("Timeout:", l) + } +} + +func TestHubRemoveListener(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hub := New(ctx, 5) + m := Message{} + l := newTestListener(1) + + hub.AddListener(l) + hub.Broadcast(m) + hub.RemoveListener(l) + hub.Broadcast(m) + + // Wait for messages + select { + case <-l.overflow: + t.Error(l) + case <-time.After(250 * time.Millisecond): + // Expected result, no overflow + } +} + +func TestHubRemoveListenerOnError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hub := New(ctx, 5) + m := Message{} + + // error after 1 means listener should receive 2 messages before being removed + l := newTestListener(2) + l.errorAfter = 1 + + hub.AddListener(l) + hub.Broadcast(m) + hub.Broadcast(m) + hub.Broadcast(m) + hub.Broadcast(m) + + // Wait for messages + select { + case <-l.overflow: + t.Error(l) + case <-time.After(250 * time.Millisecond): + // Expected result, no overflow + } +} + +func TestHubLogReplay(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hub := New(ctx, 100) + l1 := newTestListener(3) + hub.AddListener(l1) + + // Broadcast 3 messages with no listeners + msgs := make([]Message, 3) + for i := 0; i < len(msgs); i++ { + msgs[i] = Message{ + Subject: fmt.Sprintf("subj %v", i), + } + hub.Broadcast(msgs[i]) + } + + // Wait for messages (live) + select { + case <-l1.done: + case <-time.After(time.Second): + t.Fatal("Timeout:", l1) + } + + // Add a new listener + l2 := newTestListener(3) + hub.AddListener(l2) + + // Wait for messages (log) + select { + case <-l2.done: + case <-time.After(time.Second): + t.Fatal("Timeout:", l2) + } + + for i := 0; i < len(msgs); i++ { + got := l2.messages[i].Subject + want := msgs[i].Subject + if got != want { + t.Errorf("msg[%v].Subject == %q, want %q", i, got, want) + } + } +} + +func TestHubLogReplayWrap(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hub := New(ctx, 5) + l1 := newTestListener(20) + hub.AddListener(l1) + + // Broadcast more messages than the hub can hold + msgs := make([]Message, 20) + for i := 0; i < len(msgs); i++ { + msgs[i] = Message{ + Subject: fmt.Sprintf("subj %v", i), + } + hub.Broadcast(msgs[i]) + } + + // Wait for messages (live) + select { + case <-l1.done: + case <-time.After(time.Second): + t.Fatal("Timeout:", l1) + } + + // Add a new listener + l2 := newTestListener(5) + hub.AddListener(l2) + + // Wait for messages (log) + select { + case <-l2.done: + case <-time.After(time.Second): + t.Fatal("Timeout:", l2) + } + + for i := 0; i < 5; i++ { + got := l2.messages[i].Subject + want := msgs[i+15].Subject + if got != want { + t.Errorf("msg[%v].Subject == %q, want %q", i, got, want) + } + } +} + +func TestHubContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + hub := New(ctx, 5) + m := Message{} + l := newTestListener(1) + + hub.AddListener(l) + hub.Broadcast(m) + cancel() + time.Sleep(50 * time.Millisecond) + hub.Broadcast(m) + + // Wait for messages + select { + case <-l.overflow: + t.Error(l) + case <-time.After(250 * time.Millisecond): + // Expected result, no overflow + } +}