From d62a0fede9f4c127e97f2b29e499b3477770f4d0 Mon Sep 17 00:00:00 2001 From: James Hillyerd Date: Mon, 19 Feb 2024 16:46:33 -0800 Subject: [PATCH] fix: prevent smtp/handler test from freezing on panic (#503) * chore: colocate SMTP session WaitGroup incr/decr Signed-off-by: James Hillyerd * fix: smtp tests that hang on panic/t.Fatal Signed-off-by: James Hillyerd * chore: reorder smtp/handler test helpers Signed-off-by: James Hillyerd --------- Signed-off-by: James Hillyerd --- pkg/server/smtp/handler.go | 17 ++--- pkg/server/smtp/handler_test.go | 108 +++++++++++++++----------------- pkg/server/smtp/listener.go | 2 - 3 files changed, 61 insertions(+), 66 deletions(-) diff --git a/pkg/server/smtp/handler.go b/pkg/server/smtp/handler.go index 33c3866..1281b5a 100644 --- a/pkg/server/smtp/handler.go +++ b/pkg/server/smtp/handler.go @@ -139,20 +139,23 @@ func (s *Session) String() string { return fmt.Sprintf("Session{id: %v, state: %v}", s.id, s.state) } -/* Session flow: - * 1. Send initial greeting - * 2. Receive cmd - * 3. If good cmd, respond, optionally change state - * 4. If bad cmd, respond error - * 5. Goto 2 - */ +// Session flow: +// 1. Send initial greeting +// 2. Receive cmd +// 3. If good cmd, respond, optionally change state +// 4. If bad cmd, respond error +// 5. Goto 2 func (s *Server) startSession(id int, conn net.Conn, logger zerolog.Logger) { logger = logger.Hook(logHook{}).With(). Str("module", "smtp"). Str("remote", conn.RemoteAddr().String()). Int("session", id).Logger() logger.Info().Msg("Starting SMTP session") + + // Update WaitGroup and counters. + s.wg.Add(1) expConnectsCurrent.Add(1) + expConnectsTotal.Add(1) defer func() { if err := conn.Close(); err != nil { logger.Warn().Err(err).Msg("Closing connection") diff --git a/pkg/server/smtp/handler_test.go b/pkg/server/smtp/handler_test.go index f86d113..ac3eb78 100644 --- a/pkg/server/smtp/handler_test.go +++ b/pkg/server/smtp/handler_test.go @@ -45,7 +45,6 @@ func TestGreetStateValidCommands(t *testing.T) { for _, tc := range tests { t.Run(tc.send, func(t *testing.T) { - defer server.Drain() // Required to prevent test logging data race. script := []scriptStep{ tc, {"QUIT", 221}} @@ -58,7 +57,6 @@ func TestGreetStateValidCommands(t *testing.T) { func TestGreetState(t *testing.T) { ds := test.NewStore() server := setupSMTPServer(ds, extension.NewHost()) - defer server.Drain() // Required to prevent test logging data race. tests := []scriptStep{ {"HELO", 501}, @@ -71,7 +69,6 @@ func TestGreetState(t *testing.T) { for _, tc := range tests { t.Run(tc.send, func(t *testing.T) { - defer server.Drain() // Required to prevent test logging data race. script := []scriptStep{ tc, {"QUIT", 221}} @@ -83,7 +80,6 @@ func TestGreetState(t *testing.T) { func TestEmptyEnvelope(t *testing.T) { ds := test.NewStore() server := setupSMTPServer(ds, extension.NewHost()) - defer server.Drain() // Test out some empty envelope without blanks script := []scriptStep{ @@ -104,7 +100,6 @@ func TestEmptyEnvelope(t *testing.T) { func TestAuth(t *testing.T) { ds := test.NewStore() server := setupSMTPServer(ds, extension.NewHost()) - defer server.Drain() // PLAIN AUTH script := []scriptStep{ @@ -137,7 +132,6 @@ func TestAuth(t *testing.T) { func TestTLS(t *testing.T) { ds := test.NewStore() server := setupSMTPServer(ds, extension.NewHost()) - defer server.Drain() // Test Start TLS parsing. script := []scriptStep{ @@ -172,7 +166,6 @@ func TestReadyStateValidCommands(t *testing.T) { for _, tc := range tests { t.Run(tc.send, func(t *testing.T) { - defer server.Drain() script := []scriptStep{ {"HELO localhost", 250}, tc, @@ -196,7 +189,6 @@ func TestReadyStateRejectedDomains(t *testing.T) { for _, tc := range tests { t.Run(tc.send, func(t *testing.T) { - defer server.Drain() script := []scriptStep{ {"HELO localhost", 250}, tc, @@ -226,7 +218,6 @@ func TestReadyStateInvalidCommands(t *testing.T) { for _, tc := range tests { t.Run(tc.send, func(t *testing.T) { - defer server.Drain() script := []scriptStep{ {"HELO localhost", 250}, tc, @@ -240,7 +231,6 @@ func TestReadyStateInvalidCommands(t *testing.T) { func TestMailState(t *testing.T) { mds := test.NewStore() server := setupSMTPServer(mds, extension.NewHost()) - defer server.Drain() // Test out some mangled READY commands script := []scriptStep{ @@ -333,7 +323,6 @@ func TestMailState(t *testing.T) { func TestDataState(t *testing.T) { mds := test.NewStore() server := setupSMTPServer(mds, extension.NewHost()) - defer server.Drain() var script []scriptStep pipe := setupSMTPSession(t, server) @@ -395,54 +384,11 @@ Hi! _, _, _ = c.ReadCodeLine(221) } -// playSession creates a new session, reads the greeting and then plays the script -func playSession(t *testing.T, server *Server, script []scriptStep) { - t.Helper() - pipe := setupSMTPSession(t, server) - c := textproto.NewConn(pipe) - - if code, _, err := c.ReadCodeLine(220); err != nil { - t.Errorf("expected a 220 greeting, got %v", code) - } - - playScriptAgainst(t, c, script) - - // Not all tests leave the session in a clean state, so the following two calls can fail - _, _ = c.Cmd("QUIT") - _, _, _ = c.ReadCodeLine(221) -} - -// playScriptAgainst an existing connection, does not handle server greeting -func playScriptAgainst(t *testing.T, c *textproto.Conn, script []scriptStep) { - t.Helper() - - for i, step := range script { - id, err := c.Cmd(step.send) - if err != nil { - t.Fatalf("Step %d, failed to send %q: %v", i, step.send, err) - } - - c.StartResponse(id) - code, msg, err := c.ReadResponse(step.expect) - if err != nil { - err = fmt.Errorf("Step %d, sent %q, expected %v, got %v: %q", - i, step.send, step.expect, code, msg) - } - c.EndResponse(id) - - if err != nil { - // Fail after c.EndResponse so we don't hang the connection - t.Fatal(err) - } - } -} - // Tests "MAIL FROM" emits BeforeMailAccepted event. func TestBeforeMailAcceptedEventEmitted(t *testing.T) { ds := test.NewStore() extHost := extension.NewHost() server := setupSMTPServer(ds, extHost) - defer server.Drain() var got *event.AddressParts extHost.Events.BeforeMailAccepted.AddListener( @@ -469,7 +415,6 @@ func TestBeforeMailAcceptedEventResponse(t *testing.T) { ds := test.NewStore() extHost := extension.NewHost() server := setupSMTPServer(ds, extHost) - defer server.Drain() var shouldReturn *bool var gotEvent *event.AddressParts @@ -519,6 +464,48 @@ func TestBeforeMailAcceptedEventResponse(t *testing.T) { } } +// playSession creates a new session, reads the greeting and then plays the script +func playSession(t *testing.T, server *Server, script []scriptStep) { + t.Helper() + pipe := setupSMTPSession(t, server) + c := textproto.NewConn(pipe) + + if code, _, err := c.ReadCodeLine(220); err != nil { + t.Errorf("expected a 220 greeting, got %v", code) + } + + playScriptAgainst(t, c, script) + + // Not all tests leave the session in a clean state, so the following two calls can fail + _, _ = c.Cmd("QUIT") + _, _, _ = c.ReadCodeLine(221) +} + +// playScriptAgainst an existing connection, does not handle server greeting +func playScriptAgainst(t *testing.T, c *textproto.Conn, script []scriptStep) { + t.Helper() + + for i, step := range script { + id, err := c.Cmd(step.send) + if err != nil { + t.Fatalf("Step %d, failed to send %q: %v", i, step.send, err) + } + + c.StartResponse(id) + code, msg, err := c.ReadResponse(step.expect) + if err != nil { + err = fmt.Errorf("Step %d, sent %q, expected %v, got %v: %q", + i, step.send, step.expect, code, msg) + } + c.EndResponse(id) + + if err != nil { + // Fail after c.EndResponse so we don't hang the connection + t.Fatal(err) + } + } +} + // net.Pipe does not implement deadlines type mockConn struct { net.Conn @@ -528,6 +515,7 @@ func (m *mockConn) SetDeadline(t time.Time) error { return nil } func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } +// Creates an unstarted smtp.Server. func setupSMTPServer(ds storage.Store, extHost *extension.Host) *Server { cfg := &config.Root{ MailboxNaming: config.FullNaming, @@ -543,7 +531,7 @@ func setupSMTPServer(ds storage.Store, extHost *extension.Host) *Server { }, } - // Create a server, don't start it. + // Create a server, but don't start it. addrPolicy := &policy.Addressing{Config: cfg} manager := &message.StoreManager{Store: ds, ExtHost: extHost} @@ -556,9 +544,15 @@ func setupSMTPSession(t *testing.T, server *Server) net.Conn { t.Helper() logger := zerolog.New(zerolog.NewTestWriter(t)) serverConn, clientConn := net.Pipe() + t.Cleanup(func() { + _ = clientConn.Close() + + // Drain is required to prevent a test-logging data race. If a (failing) test run is + // hanging, this may be the culprit. + server.Drain() + }) // Start the session. - server.wg.Add(1) sessionNum++ go server.startSession(sessionNum, &mockConn{serverConn}, logger) diff --git a/pkg/server/smtp/listener.go b/pkg/server/smtp/listener.go index b67648a..b8d7e10 100644 --- a/pkg/server/smtp/listener.go +++ b/pkg/server/smtp/listener.go @@ -176,8 +176,6 @@ func (s *Server) serve(ctx context.Context) { } } else { tempDelay = 0 - expConnectsTotal.Add(1) - s.wg.Add(1) go s.startSession(sessionID, conn, log.Logger) } }