mirror of
https://github.com/jhillyerd/inbucket.git
synced 2025-12-17 17:47:03 +00:00
fix: prevent smtp/handler test from freezing on panic (#503)
* chore: colocate SMTP session WaitGroup incr/decr Signed-off-by: James Hillyerd <james@hillyerd.com> * fix: smtp tests that hang on panic/t.Fatal Signed-off-by: James Hillyerd <james@hillyerd.com> * chore: reorder smtp/handler test helpers Signed-off-by: James Hillyerd <james@hillyerd.com> --------- Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
@@ -139,20 +139,23 @@ func (s *Session) String() string {
|
|||||||
return fmt.Sprintf("Session{id: %v, state: %v}", s.id, s.state)
|
return fmt.Sprintf("Session{id: %v, state: %v}", s.id, s.state)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Session flow:
|
// Session flow:
|
||||||
* 1. Send initial greeting
|
// 1. Send initial greeting
|
||||||
* 2. Receive cmd
|
// 2. Receive cmd
|
||||||
* 3. If good cmd, respond, optionally change state
|
// 3. If good cmd, respond, optionally change state
|
||||||
* 4. If bad cmd, respond error
|
// 4. If bad cmd, respond error
|
||||||
* 5. Goto 2
|
// 5. Goto 2
|
||||||
*/
|
|
||||||
func (s *Server) startSession(id int, conn net.Conn, logger zerolog.Logger) {
|
func (s *Server) startSession(id int, conn net.Conn, logger zerolog.Logger) {
|
||||||
logger = logger.Hook(logHook{}).With().
|
logger = logger.Hook(logHook{}).With().
|
||||||
Str("module", "smtp").
|
Str("module", "smtp").
|
||||||
Str("remote", conn.RemoteAddr().String()).
|
Str("remote", conn.RemoteAddr().String()).
|
||||||
Int("session", id).Logger()
|
Int("session", id).Logger()
|
||||||
logger.Info().Msg("Starting SMTP session")
|
logger.Info().Msg("Starting SMTP session")
|
||||||
|
|
||||||
|
// Update WaitGroup and counters.
|
||||||
|
s.wg.Add(1)
|
||||||
expConnectsCurrent.Add(1)
|
expConnectsCurrent.Add(1)
|
||||||
|
expConnectsTotal.Add(1)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
logger.Warn().Err(err).Msg("Closing connection")
|
logger.Warn().Err(err).Msg("Closing connection")
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ func TestGreetStateValidCommands(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.send, func(t *testing.T) {
|
t.Run(tc.send, func(t *testing.T) {
|
||||||
defer server.Drain() // Required to prevent test logging data race.
|
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
tc,
|
tc,
|
||||||
{"QUIT", 221}}
|
{"QUIT", 221}}
|
||||||
@@ -58,7 +57,6 @@ func TestGreetStateValidCommands(t *testing.T) {
|
|||||||
func TestGreetState(t *testing.T) {
|
func TestGreetState(t *testing.T) {
|
||||||
ds := test.NewStore()
|
ds := test.NewStore()
|
||||||
server := setupSMTPServer(ds, extension.NewHost())
|
server := setupSMTPServer(ds, extension.NewHost())
|
||||||
defer server.Drain() // Required to prevent test logging data race.
|
|
||||||
|
|
||||||
tests := []scriptStep{
|
tests := []scriptStep{
|
||||||
{"HELO", 501},
|
{"HELO", 501},
|
||||||
@@ -71,7 +69,6 @@ func TestGreetState(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.send, func(t *testing.T) {
|
t.Run(tc.send, func(t *testing.T) {
|
||||||
defer server.Drain() // Required to prevent test logging data race.
|
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
tc,
|
tc,
|
||||||
{"QUIT", 221}}
|
{"QUIT", 221}}
|
||||||
@@ -83,7 +80,6 @@ func TestGreetState(t *testing.T) {
|
|||||||
func TestEmptyEnvelope(t *testing.T) {
|
func TestEmptyEnvelope(t *testing.T) {
|
||||||
ds := test.NewStore()
|
ds := test.NewStore()
|
||||||
server := setupSMTPServer(ds, extension.NewHost())
|
server := setupSMTPServer(ds, extension.NewHost())
|
||||||
defer server.Drain()
|
|
||||||
|
|
||||||
// Test out some empty envelope without blanks
|
// Test out some empty envelope without blanks
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
@@ -104,7 +100,6 @@ func TestEmptyEnvelope(t *testing.T) {
|
|||||||
func TestAuth(t *testing.T) {
|
func TestAuth(t *testing.T) {
|
||||||
ds := test.NewStore()
|
ds := test.NewStore()
|
||||||
server := setupSMTPServer(ds, extension.NewHost())
|
server := setupSMTPServer(ds, extension.NewHost())
|
||||||
defer server.Drain()
|
|
||||||
|
|
||||||
// PLAIN AUTH
|
// PLAIN AUTH
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
@@ -137,7 +132,6 @@ func TestAuth(t *testing.T) {
|
|||||||
func TestTLS(t *testing.T) {
|
func TestTLS(t *testing.T) {
|
||||||
ds := test.NewStore()
|
ds := test.NewStore()
|
||||||
server := setupSMTPServer(ds, extension.NewHost())
|
server := setupSMTPServer(ds, extension.NewHost())
|
||||||
defer server.Drain()
|
|
||||||
|
|
||||||
// Test Start TLS parsing.
|
// Test Start TLS parsing.
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
@@ -172,7 +166,6 @@ func TestReadyStateValidCommands(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.send, func(t *testing.T) {
|
t.Run(tc.send, func(t *testing.T) {
|
||||||
defer server.Drain()
|
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
{"HELO localhost", 250},
|
{"HELO localhost", 250},
|
||||||
tc,
|
tc,
|
||||||
@@ -196,7 +189,6 @@ func TestReadyStateRejectedDomains(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.send, func(t *testing.T) {
|
t.Run(tc.send, func(t *testing.T) {
|
||||||
defer server.Drain()
|
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
{"HELO localhost", 250},
|
{"HELO localhost", 250},
|
||||||
tc,
|
tc,
|
||||||
@@ -226,7 +218,6 @@ func TestReadyStateInvalidCommands(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.send, func(t *testing.T) {
|
t.Run(tc.send, func(t *testing.T) {
|
||||||
defer server.Drain()
|
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
{"HELO localhost", 250},
|
{"HELO localhost", 250},
|
||||||
tc,
|
tc,
|
||||||
@@ -240,7 +231,6 @@ func TestReadyStateInvalidCommands(t *testing.T) {
|
|||||||
func TestMailState(t *testing.T) {
|
func TestMailState(t *testing.T) {
|
||||||
mds := test.NewStore()
|
mds := test.NewStore()
|
||||||
server := setupSMTPServer(mds, extension.NewHost())
|
server := setupSMTPServer(mds, extension.NewHost())
|
||||||
defer server.Drain()
|
|
||||||
|
|
||||||
// Test out some mangled READY commands
|
// Test out some mangled READY commands
|
||||||
script := []scriptStep{
|
script := []scriptStep{
|
||||||
@@ -333,7 +323,6 @@ func TestMailState(t *testing.T) {
|
|||||||
func TestDataState(t *testing.T) {
|
func TestDataState(t *testing.T) {
|
||||||
mds := test.NewStore()
|
mds := test.NewStore()
|
||||||
server := setupSMTPServer(mds, extension.NewHost())
|
server := setupSMTPServer(mds, extension.NewHost())
|
||||||
defer server.Drain()
|
|
||||||
|
|
||||||
var script []scriptStep
|
var script []scriptStep
|
||||||
pipe := setupSMTPSession(t, server)
|
pipe := setupSMTPSession(t, server)
|
||||||
@@ -395,54 +384,11 @@ Hi!
|
|||||||
_, _, _ = c.ReadCodeLine(221)
|
_, _, _ = c.ReadCodeLine(221)
|
||||||
}
|
}
|
||||||
|
|
||||||
// playSession creates a new session, reads the greeting and then plays the script
|
|
||||||
func playSession(t *testing.T, server *Server, script []scriptStep) {
|
|
||||||
t.Helper()
|
|
||||||
pipe := setupSMTPSession(t, server)
|
|
||||||
c := textproto.NewConn(pipe)
|
|
||||||
|
|
||||||
if code, _, err := c.ReadCodeLine(220); err != nil {
|
|
||||||
t.Errorf("expected a 220 greeting, got %v", code)
|
|
||||||
}
|
|
||||||
|
|
||||||
playScriptAgainst(t, c, script)
|
|
||||||
|
|
||||||
// Not all tests leave the session in a clean state, so the following two calls can fail
|
|
||||||
_, _ = c.Cmd("QUIT")
|
|
||||||
_, _, _ = c.ReadCodeLine(221)
|
|
||||||
}
|
|
||||||
|
|
||||||
// playScriptAgainst an existing connection, does not handle server greeting
|
|
||||||
func playScriptAgainst(t *testing.T, c *textproto.Conn, script []scriptStep) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
for i, step := range script {
|
|
||||||
id, err := c.Cmd(step.send)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Step %d, failed to send %q: %v", i, step.send, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.StartResponse(id)
|
|
||||||
code, msg, err := c.ReadResponse(step.expect)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("Step %d, sent %q, expected %v, got %v: %q",
|
|
||||||
i, step.send, step.expect, code, msg)
|
|
||||||
}
|
|
||||||
c.EndResponse(id)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
// Fail after c.EndResponse so we don't hang the connection
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tests "MAIL FROM" emits BeforeMailAccepted event.
|
// Tests "MAIL FROM" emits BeforeMailAccepted event.
|
||||||
func TestBeforeMailAcceptedEventEmitted(t *testing.T) {
|
func TestBeforeMailAcceptedEventEmitted(t *testing.T) {
|
||||||
ds := test.NewStore()
|
ds := test.NewStore()
|
||||||
extHost := extension.NewHost()
|
extHost := extension.NewHost()
|
||||||
server := setupSMTPServer(ds, extHost)
|
server := setupSMTPServer(ds, extHost)
|
||||||
defer server.Drain()
|
|
||||||
|
|
||||||
var got *event.AddressParts
|
var got *event.AddressParts
|
||||||
extHost.Events.BeforeMailAccepted.AddListener(
|
extHost.Events.BeforeMailAccepted.AddListener(
|
||||||
@@ -469,7 +415,6 @@ func TestBeforeMailAcceptedEventResponse(t *testing.T) {
|
|||||||
ds := test.NewStore()
|
ds := test.NewStore()
|
||||||
extHost := extension.NewHost()
|
extHost := extension.NewHost()
|
||||||
server := setupSMTPServer(ds, extHost)
|
server := setupSMTPServer(ds, extHost)
|
||||||
defer server.Drain()
|
|
||||||
|
|
||||||
var shouldReturn *bool
|
var shouldReturn *bool
|
||||||
var gotEvent *event.AddressParts
|
var gotEvent *event.AddressParts
|
||||||
@@ -519,6 +464,48 @@ func TestBeforeMailAcceptedEventResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// playSession creates a new session, reads the greeting and then plays the script
|
||||||
|
func playSession(t *testing.T, server *Server, script []scriptStep) {
|
||||||
|
t.Helper()
|
||||||
|
pipe := setupSMTPSession(t, server)
|
||||||
|
c := textproto.NewConn(pipe)
|
||||||
|
|
||||||
|
if code, _, err := c.ReadCodeLine(220); err != nil {
|
||||||
|
t.Errorf("expected a 220 greeting, got %v", code)
|
||||||
|
}
|
||||||
|
|
||||||
|
playScriptAgainst(t, c, script)
|
||||||
|
|
||||||
|
// Not all tests leave the session in a clean state, so the following two calls can fail
|
||||||
|
_, _ = c.Cmd("QUIT")
|
||||||
|
_, _, _ = c.ReadCodeLine(221)
|
||||||
|
}
|
||||||
|
|
||||||
|
// playScriptAgainst an existing connection, does not handle server greeting
|
||||||
|
func playScriptAgainst(t *testing.T, c *textproto.Conn, script []scriptStep) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for i, step := range script {
|
||||||
|
id, err := c.Cmd(step.send)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Step %d, failed to send %q: %v", i, step.send, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.StartResponse(id)
|
||||||
|
code, msg, err := c.ReadResponse(step.expect)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Step %d, sent %q, expected %v, got %v: %q",
|
||||||
|
i, step.send, step.expect, code, msg)
|
||||||
|
}
|
||||||
|
c.EndResponse(id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Fail after c.EndResponse so we don't hang the connection
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// net.Pipe does not implement deadlines
|
// net.Pipe does not implement deadlines
|
||||||
type mockConn struct {
|
type mockConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
@@ -528,6 +515,7 @@ func (m *mockConn) SetDeadline(t time.Time) error { return nil }
|
|||||||
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
|
// Creates an unstarted smtp.Server.
|
||||||
func setupSMTPServer(ds storage.Store, extHost *extension.Host) *Server {
|
func setupSMTPServer(ds storage.Store, extHost *extension.Host) *Server {
|
||||||
cfg := &config.Root{
|
cfg := &config.Root{
|
||||||
MailboxNaming: config.FullNaming,
|
MailboxNaming: config.FullNaming,
|
||||||
@@ -543,7 +531,7 @@ func setupSMTPServer(ds storage.Store, extHost *extension.Host) *Server {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a server, don't start it.
|
// Create a server, but don't start it.
|
||||||
addrPolicy := &policy.Addressing{Config: cfg}
|
addrPolicy := &policy.Addressing{Config: cfg}
|
||||||
manager := &message.StoreManager{Store: ds, ExtHost: extHost}
|
manager := &message.StoreManager{Store: ds, ExtHost: extHost}
|
||||||
|
|
||||||
@@ -556,9 +544,15 @@ func setupSMTPSession(t *testing.T, server *Server) net.Conn {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
logger := zerolog.New(zerolog.NewTestWriter(t))
|
logger := zerolog.New(zerolog.NewTestWriter(t))
|
||||||
serverConn, clientConn := net.Pipe()
|
serverConn, clientConn := net.Pipe()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = clientConn.Close()
|
||||||
|
|
||||||
|
// Drain is required to prevent a test-logging data race. If a (failing) test run is
|
||||||
|
// hanging, this may be the culprit.
|
||||||
|
server.Drain()
|
||||||
|
})
|
||||||
|
|
||||||
// Start the session.
|
// Start the session.
|
||||||
server.wg.Add(1)
|
|
||||||
sessionNum++
|
sessionNum++
|
||||||
go server.startSession(sessionNum, &mockConn{serverConn}, logger)
|
go server.startSession(sessionNum, &mockConn{serverConn}, logger)
|
||||||
|
|
||||||
|
|||||||
@@ -176,8 +176,6 @@ func (s *Server) serve(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tempDelay = 0
|
tempDelay = 0
|
||||||
expConnectsTotal.Add(1)
|
|
||||||
s.wg.Add(1)
|
|
||||||
go s.startSession(sessionID, conn, log.Logger)
|
go s.startSession(sessionID, conn, log.Logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user