From c59e7937756c687e5107e93e7160c8a9a32b43f0 Mon Sep 17 00:00:00 2001 From: James Hillyerd Date: Mon, 19 Feb 2024 17:30:37 -0800 Subject: [PATCH] chore: refactor smtp/handler if-else chain (#504) * chore: convert smtp/handler if-else chain to switch-case Signed-off-by: James Hillyerd * chore: extract long case into parseMailCmd func Signed-off-by: James Hillyerd * chore: remove extraneous braces in cases Signed-off-by: James Hillyerd --------- Signed-off-by: James Hillyerd --- pkg/server/smtp/handler.go | 192 +++++++++++++++++++------------------ 1 file changed, 101 insertions(+), 91 deletions(-) diff --git a/pkg/server/smtp/handler.go b/pkg/server/smtp/handler.go index 1281b5a..cdc1057 100644 --- a/pkg/server/smtp/handler.go +++ b/pkg/server/smtp/handler.go @@ -330,7 +330,8 @@ func (s *Session) passwordHandler() { // READY state -> waiting for MAIL // AUTH can change func (s *Session) readyHandler(cmd string, arg string) { - if cmd == "STARTTLS" { + switch cmd { + case "STARTTLS": if !s.Server.config.TLSEnabled { // Invalid command since TLS unconfigured. s.logger.Debug().Msgf("454 TLS unavailable on the server") @@ -353,116 +354,125 @@ func (s *Session) readyHandler(cmd string, arg string) { s.tlsState = new(tls.ConnectionState) *s.tlsState = tlsConn.ConnectionState() s.enterState(GREET) - } else if cmd == "AUTH" { + + case "AUTH": args := strings.SplitN(arg, " ", 3) authMethod := args[0] switch authMethod { case "PLAIN": - { - if len(args) != 2 { - s.send("500 Bad auth arguments") - s.logger.Warn().Msgf("Bad auth attempt: %q", arg) - return - } - s.logger.Info().Msgf("Accepting credentials: %q", args[1]) - s.send("235 2.7.0 Authentication successful") + if len(args) != 2 { + s.send("500 Bad auth arguments") + s.logger.Warn().Msgf("Bad auth attempt: %q", arg) return } + s.logger.Info().Msgf("Accepting credentials: %q", args[1]) + s.send("235 2.7.0 Authentication successful") + return + case "LOGIN": - { - s.send(fmt.Sprintf("334 %v", usernameChallenge)) - s.enterState(LOGIN) - return - } + s.send(fmt.Sprintf("334 %v", usernameChallenge)) + s.enterState(LOGIN) + return + default: - { - s.send(fmt.Sprintf("500 Unsupported AUTH method: %v", authMethod)) - return - } - } - } else if cmd == "MAIL" { - // Capture group 1: from address. 2: optional params. - m := fromRegex.FindStringSubmatch(arg) - if m == nil { - s.send("501 Was expecting MAIL arg syntax of FROM:
") - s.logger.Warn().Msgf("Bad MAIL argument: %q", arg) + s.send(fmt.Sprintf("500 Unsupported AUTH method: %v", authMethod)) return } - from := m[1] - s.logger.Debug().Msgf("Mail sender is %v", from) - localpart, domain, err := policy.ParseEmailAddress(from) - s.logger.Debug().Msgf("Origin domain is %v", domain) - if from != "" && err != nil { - s.send("501 Bad sender address syntax") - s.logger.Warn().Msgf("Bad address as MAIL arg: %q, %s", from, err) - return - } - if from == "" { - from = "unspecified" - } + case "MAIL": + s.parseMailFromCmd(arg) - // This is where the client may put BODY=8BITMIME, but we already - // read the DATA as bytes, so it does not effect our processing. - if m[2] != "" { - args, ok := s.parseArgs(m[2]) - if !ok { - s.send("501 Unable to parse MAIL ESMTP parameters") - s.logger.Warn().Msgf("Bad MAIL argument: %q", arg) - return - } - if args["SIZE"] != "" { - size, err := strconv.ParseInt(args["SIZE"], 10, 32) - if err != nil { - s.send("501 Unable to parse SIZE as an integer") - s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"]) - return - } - if int(size) > s.config.MaxMessageBytes { - s.send("552 Max message size exceeded") - s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"]) - return - } - } - } - - // Process through extensions. - extResult := s.extHost.Events.BeforeMailAccepted.Emit( - &event.AddressParts{Local: localpart, Domain: domain}) - - if extResult == nil || *extResult { - // Permitted by extension, or none had an opinion. - origin, err := s.addrPolicy.ParseOrigin(from) - if err != nil { - s.send("501 Bad origin address syntax") - s.logger.Warn().Str("from", from).Err(err).Msg("Bad address as MAIL arg") - return - } - s.from = origin - if !s.from.ShouldAccept() { - s.send("501 Unauthorized domain") - s.logger.Warn().Msgf("Bad domain sender %s", domain) - return - } - - s.logger.Info().Msgf("Mail from: %v", from) - s.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from)) - s.enterState(MAIL) - } else { - s.send("550 Mail denied by policy") - s.logger.Warn().Msgf("Extension denied mail from <%v>", from) - return - } - } else if cmd == "EHLO" { + case "EHLO": // Reset session s.logger.Debug().Msgf("Resetting session state on EHLO request") s.reset() s.send("250 Session reset") - } else { + + default: s.ooSeq(cmd) } } +// Parses `MAIL FROM` command. +func (s *Session) parseMailFromCmd(arg string) { + // Capture group 1: from address. 2: optional params. + m := fromRegex.FindStringSubmatch(arg) + if m == nil { + s.send("501 Was expecting MAIL arg syntax of FROM:
") + s.logger.Warn().Msgf("Bad MAIL argument: %q", arg) + return + } + from := m[1] + s.logger.Debug().Msgf("Mail sender is %v", from) + + // Parse from address. + localpart, domain, err := policy.ParseEmailAddress(from) + s.logger.Debug().Msgf("Origin domain is %v", domain) + if from != "" && err != nil { + s.send("501 Bad sender address syntax") + s.logger.Warn().Msgf("Bad address as MAIL arg: %q, %s", from, err) + return + } + if from == "" { + from = "unspecified" + } + + // Parse ESMTP parameters. + if m[2] != "" { + // Here the client may put BODY=8BITMIME, but Inbucket already + // reads the DATA as bytes, so it does not effect mail processing. + args, ok := s.parseArgs(m[2]) + if !ok { + s.send("501 Unable to parse MAIL ESMTP parameters") + s.logger.Warn().Msgf("Bad MAIL argument: %q", arg) + return + } + + // Reject oversized messages. + if args["SIZE"] != "" { + size, err := strconv.ParseInt(args["SIZE"], 10, 32) + if err != nil { + s.send("501 Unable to parse SIZE as an integer") + s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"]) + return + } + if int(size) > s.config.MaxMessageBytes { + s.send("552 Max message size exceeded") + s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"]) + return + } + } + } + + // Process through extensions. + extResult := s.extHost.Events.BeforeMailAccepted.Emit( + &event.AddressParts{Local: localpart, Domain: domain}) + if extResult != nil && !*extResult { + s.send("550 Mail denied by policy") + s.logger.Warn().Msgf("Extension denied mail from <%v>", from) + return + } + + // Sender was permitted by an extension, or no extension rejected it. + origin, err := s.addrPolicy.ParseOrigin(from) + if err != nil { + s.send("501 Bad origin address syntax") + s.logger.Warn().Str("from", from).Err(err).Msg("Bad address as MAIL arg") + return + } + s.from = origin + if !s.from.ShouldAccept() { + s.send("501 Unauthorized domain") + s.logger.Warn().Msgf("Bad domain sender %s", domain) + return + } + + // Ok to transition to MAIL state. + s.logger.Info().Msgf("Mail from: %v", from) + s.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from)) + s.enterState(MAIL) +} + // MAIL state -> waiting for RCPTs followed by DATA func (s *Session) mailHandler(cmd string, arg string) { switch cmd {