diff --git a/msghub/hub.go b/msghub/hub.go index 0f43077..afc32b5 100644 --- a/msghub/hub.go +++ b/msghub/hub.go @@ -3,10 +3,12 @@ package msghub import ( "container/ring" "context" - "sync" "time" ) +// Length of msghub operation queue +const opChanLen = 100 + // Message contains the basic header data for a message type Message struct { Mailbox string @@ -18,33 +20,27 @@ type Message struct { Size int64 } -// Listener receives the contents of the log, followed by new messages +// Listener receives the contents of the history buffer, followed by new messages type Listener interface { Receive(msg Message) error } // Hub relays messages on to its listeners type Hub struct { - // log stores history, points next spot to write. First non-nil entry is oldest Message - log *ring.Ring - logMx sync.RWMutex - - // listeners interested in new messages - listeners map[Listener]struct{} - listenersMx sync.RWMutex - - // broadcast receives new messages - broadcast chan Message + // history buffer, points next Message to write. Proceeding non-nil entry is oldest Message + history *ring.Ring + listeners map[Listener]struct{} // listeners interested in new messages + opChan chan func(h *Hub) // operations queued for this actor } -// New constructs a new Hub which will cache logSize messages in memory for playback to future +// New constructs a new Hub which will cache historyLen messages in memory for playback to future // listeners. A goroutine is created to handle incoming messages; it will run until the provided // context is canceled. -func New(ctx context.Context, logSize int) *Hub { +func New(ctx context.Context, historyLen int) *Hub { h := &Hub{ - log: ring.New(logSize), + history: ring.New(historyLen), listeners: make(map[Listener]struct{}), - broadcast: make(chan Message, 100), + opChan: make(chan func(h *Hub), opChanLen), } go func() { @@ -52,17 +48,10 @@ func New(ctx context.Context, logSize int) *Hub { select { case <-ctx.Done(): // Shutdown - close(h.broadcast) - h.broadcast = nil + close(h.opChan) return - case msg := <-h.broadcast: - // Log message - h.logMx.Lock() - h.log.Value = msg - h.log = h.log.Next() - h.logMx.Unlock() - // Deliver message to listeners - h.deliver(msg) + case op := <-h.opChan: + op(h) } } }() @@ -70,47 +59,50 @@ func New(ctx context.Context, logSize int) *Hub { return h } -// Broadcast queues a message for processing by the hub. The message will be placed into the -// in-memory log and relayed to all registered listeners. -func (h *Hub) Broadcast(msg Message) { - if h.broadcast != nil { - h.broadcast <- msg +// Dispatch queues a message for broadcast by the hub. The message will be placed into the +// history buffer and then relayed to all registered listeners. +func (hub *Hub) Dispatch(msg Message) { + hub.opChan <- func(h *Hub) { + // Add to history buffer + h.history.Value = msg + h.history = h.history.Next() + // Deliver message to all listeners, removing listeners if they return an error + for l := range h.listeners { + if err := l.Receive(msg); err != nil { + delete(h.listeners, l) + } + } } } // AddListener registers a listener to receive broadcasted messages. -func (h *Hub) AddListener(l Listener) { - // Playback log - h.logMx.RLock() - h.log.Do(func(v interface{}) { - if v != nil { - l.Receive(v.(Message)) - } - }) - h.logMx.RUnlock() +func (hub *Hub) AddListener(l Listener) { + hub.opChan <- func(h *Hub) { + // Playback log + h.history.Do(func(v interface{}) { + if v != nil { + l.Receive(v.(Message)) + } + }) - // Add to listeners - h.listenersMx.Lock() - h.listeners[l] = struct{}{} - h.listenersMx.Unlock() + // Add to listeners + h.listeners[l] = struct{}{} + } } // RemoveListener deletes a listener registration, it will cease to receive messages. -func (h *Hub) RemoveListener(l Listener) { - h.listenersMx.Lock() - defer h.listenersMx.Unlock() - if _, ok := h.listeners[l]; ok { +func (hub *Hub) RemoveListener(l Listener) { + hub.opChan <- func(h *Hub) { delete(h.listeners, l) } } -// deliver message to all listeners, removing listeners if they return an error -func (h *Hub) deliver(msg Message) { - h.listenersMx.RLock() - defer h.listenersMx.RUnlock() - for l := range h.listeners { - if err := l.Receive(msg); err != nil { - h.RemoveListener(l) - } +// Sync blocks until the msghub has processed its queue up to this point, useful +// for unit tests. +func (hub *Hub) Sync() { + done := make(chan struct{}) + hub.opChan <- func(h *Hub) { + close(done) } + <-done } diff --git a/msghub/hub_test.go b/msghub/hub_test.go index f6c131e..f5da3ad 100644 --- a/msghub/hub_test.go +++ b/msghub/hub_test.go @@ -66,7 +66,7 @@ func TestHubZeroListeners(t *testing.T) { hub := New(ctx, 5) m := Message{} for i := 0; i < 100; i++ { - hub.Broadcast(m) + hub.Dispatch(m) } // Just making sure Hub doesn't panic } @@ -79,7 +79,7 @@ func TestHubOneListener(t *testing.T) { l := newTestListener(1) hub.AddListener(l) - hub.Broadcast(m) + hub.Dispatch(m) // Wait for messages select { @@ -97,15 +97,16 @@ func TestHubRemoveListener(t *testing.T) { l := newTestListener(1) hub.AddListener(l) - hub.Broadcast(m) + hub.Dispatch(m) hub.RemoveListener(l) - hub.Broadcast(m) + hub.Dispatch(m) + hub.Sync() // Wait for messages select { case <-l.overflow: t.Error(l) - case <-time.After(250 * time.Millisecond): + case <-time.After(50 * time.Millisecond): // Expected result, no overflow } } @@ -121,21 +122,22 @@ func TestHubRemoveListenerOnError(t *testing.T) { l.errorAfter = 1 hub.AddListener(l) - hub.Broadcast(m) - hub.Broadcast(m) - hub.Broadcast(m) - hub.Broadcast(m) + hub.Dispatch(m) + hub.Dispatch(m) + hub.Dispatch(m) + hub.Dispatch(m) + hub.Sync() // Wait for messages select { case <-l.overflow: t.Error(l) - case <-time.After(250 * time.Millisecond): + case <-time.After(50 * time.Millisecond): // Expected result, no overflow } } -func TestHubLogReplay(t *testing.T) { +func TestHubHistoryReplay(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() hub := New(ctx, 100) @@ -148,7 +150,7 @@ func TestHubLogReplay(t *testing.T) { msgs[i] = Message{ Subject: fmt.Sprintf("subj %v", i), } - hub.Broadcast(msgs[i]) + hub.Dispatch(msgs[i]) } // Wait for messages (live) @@ -162,7 +164,7 @@ func TestHubLogReplay(t *testing.T) { l2 := newTestListener(3) hub.AddListener(l2) - // Wait for messages (log) + // Wait for messages (history) select { case <-l2.done: case <-time.After(time.Second): @@ -178,7 +180,7 @@ func TestHubLogReplay(t *testing.T) { } } -func TestHubLogReplayWrap(t *testing.T) { +func TestHubHistoryReplayWrap(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() hub := New(ctx, 5) @@ -191,7 +193,7 @@ func TestHubLogReplayWrap(t *testing.T) { msgs[i] = Message{ Subject: fmt.Sprintf("subj %v", i), } - hub.Broadcast(msgs[i]) + hub.Dispatch(msgs[i]) } // Wait for messages (live) @@ -205,7 +207,7 @@ func TestHubLogReplayWrap(t *testing.T) { l2 := newTestListener(5) hub.AddListener(l2) - // Wait for messages (log) + // Wait for messages (history) select { case <-l2.done: case <-time.After(time.Second): @@ -228,16 +230,15 @@ func TestHubContextCancel(t *testing.T) { l := newTestListener(1) hub.AddListener(l) - hub.Broadcast(m) + hub.Dispatch(m) + hub.Sync() cancel() - time.Sleep(50 * time.Millisecond) - hub.Broadcast(m) // Wait for messages select { case <-l.overflow: t.Error(l) - case <-time.After(250 * time.Millisecond): + case <-time.After(50 * time.Millisecond): // Expected result, no overflow } } diff --git a/smtpd/handler.go b/smtpd/handler.go index e267d11..7353194 100644 --- a/smtpd/handler.go +++ b/smtpd/handler.go @@ -475,7 +475,7 @@ func (ss *Session) deliverMessage(r recipientDetails, msgBuf [][]byte) (ok bool) Date: msg.Date(), Size: msg.Size(), } - ss.server.msgHub.Broadcast(broadcast) + ss.server.msgHub.Dispatch(broadcast) return true } diff --git a/smtpd/handler_test.go b/smtpd/handler_test.go index 3f4dc82..61d4f69 100644 --- a/smtpd/handler_test.go +++ b/smtpd/handler_test.go @@ -2,6 +2,7 @@ package smtpd import ( "bytes" + "context" "fmt" "io" @@ -28,8 +29,8 @@ func TestGreetState(t *testing.T) { mb1 := &MockMailbox{} mds.On("MailboxFor").Return(mb1, nil) - server, logbuf := setupSMTPServer(mds) - defer teardownSMTPServer(server) + server, logbuf, teardown := setupSMTPServer(mds) + defer teardown() var script []scriptStep @@ -89,8 +90,8 @@ func TestReadyState(t *testing.T) { mb1 := &MockMailbox{} mds.On("MailboxFor").Return(mb1, nil) - server, logbuf := setupSMTPServer(mds) - defer teardownSMTPServer(server) + server, logbuf, teardown := setupSMTPServer(mds) + defer teardown() var script []scriptStep @@ -164,8 +165,8 @@ func TestMailState(t *testing.T) { msg1.On("Size").Return(0) msg1.On("Close").Return(nil) - server, logbuf := setupSMTPServer(mds) - defer teardownSMTPServer(server) + server, logbuf, teardown := setupSMTPServer(mds) + defer teardown() var script []scriptStep @@ -281,8 +282,8 @@ func TestDataState(t *testing.T) { msg1.On("Size").Return(0) msg1.On("Close").Return(nil) - server, logbuf := setupSMTPServer(mds) - defer teardownSMTPServer(server) + server, logbuf, teardown := setupSMTPServer(mds) + defer teardown() var script []scriptStep pipe := setupSMTPSession(server) @@ -375,7 +376,7 @@ func (m *mockConn) SetDeadline(t time.Time) error { return nil } func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } -func setupSMTPServer(ds DataStore) (*Server, *bytes.Buffer) { +func setupSMTPServer(ds DataStore) (s *Server, buf *bytes.Buffer, teardown func()) { // Test Server Config cfg := config.SMTPConfig{ IP4address: net.IPv4(127, 0, 0, 1), @@ -389,12 +390,18 @@ func setupSMTPServer(ds DataStore) (*Server, *bytes.Buffer) { } // Capture log output - buf := new(bytes.Buffer) + buf = new(bytes.Buffer) log.SetOutput(buf) // Create a server, don't start it shutdownChan := make(chan bool) - return NewServer(cfg, shutdownChan, ds, &msghub.Hub{}), buf + ctx, cancel := context.WithCancel(context.Background()) + teardown = func() { + close(shutdownChan) + cancel() + } + s = NewServer(cfg, shutdownChan, ds, msghub.New(ctx, 100)) + return s, buf, teardown } var sessionNum int @@ -409,7 +416,3 @@ func setupSMTPSession(server *Server) net.Conn { return clientConn } - -func teardownSMTPServer(server *Server) { - //log.SetOutput(os.Stderr) -}