1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-18 18:17:03 +00:00

chore: refactor smtp/handler if-else chain (#504)

* chore: convert smtp/handler if-else chain to switch-case

Signed-off-by: James Hillyerd <james@hillyerd.com>

* chore: extract long case into parseMailCmd func

Signed-off-by: James Hillyerd <james@hillyerd.com>

* chore: remove extraneous braces in cases

Signed-off-by: James Hillyerd <james@hillyerd.com>

---------

Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
James Hillyerd
2024-02-19 17:30:37 -08:00
committed by GitHub
parent 40ec108daf
commit c59e793775

View File

@@ -330,7 +330,8 @@ func (s *Session) passwordHandler() {
// READY state -> waiting for MAIL // READY state -> waiting for MAIL
// AUTH can change // AUTH can change
func (s *Session) readyHandler(cmd string, arg string) { func (s *Session) readyHandler(cmd string, arg string) {
if cmd == "STARTTLS" { switch cmd {
case "STARTTLS":
if !s.Server.config.TLSEnabled { if !s.Server.config.TLSEnabled {
// Invalid command since TLS unconfigured. // Invalid command since TLS unconfigured.
s.logger.Debug().Msgf("454 TLS unavailable on the server") s.logger.Debug().Msgf("454 TLS unavailable on the server")
@@ -353,12 +354,12 @@ func (s *Session) readyHandler(cmd string, arg string) {
s.tlsState = new(tls.ConnectionState) s.tlsState = new(tls.ConnectionState)
*s.tlsState = tlsConn.ConnectionState() *s.tlsState = tlsConn.ConnectionState()
s.enterState(GREET) s.enterState(GREET)
} else if cmd == "AUTH" {
case "AUTH":
args := strings.SplitN(arg, " ", 3) args := strings.SplitN(arg, " ", 3)
authMethod := args[0] authMethod := args[0]
switch authMethod { switch authMethod {
case "PLAIN": case "PLAIN":
{
if len(args) != 2 { if len(args) != 2 {
s.send("500 Bad auth arguments") s.send("500 Bad auth arguments")
s.logger.Warn().Msgf("Bad auth attempt: %q", arg) s.logger.Warn().Msgf("Bad auth attempt: %q", arg)
@@ -367,20 +368,33 @@ func (s *Session) readyHandler(cmd string, arg string) {
s.logger.Info().Msgf("Accepting credentials: %q", args[1]) s.logger.Info().Msgf("Accepting credentials: %q", args[1])
s.send("235 2.7.0 Authentication successful") s.send("235 2.7.0 Authentication successful")
return return
}
case "LOGIN": case "LOGIN":
{
s.send(fmt.Sprintf("334 %v", usernameChallenge)) s.send(fmt.Sprintf("334 %v", usernameChallenge))
s.enterState(LOGIN) s.enterState(LOGIN)
return return
}
default: default:
{
s.send(fmt.Sprintf("500 Unsupported AUTH method: %v", authMethod)) s.send(fmt.Sprintf("500 Unsupported AUTH method: %v", authMethod))
return return
} }
case "MAIL":
s.parseMailFromCmd(arg)
case "EHLO":
// Reset session
s.logger.Debug().Msgf("Resetting session state on EHLO request")
s.reset()
s.send("250 Session reset")
default:
s.ooSeq(cmd)
} }
} else if cmd == "MAIL" { }
// Parses `MAIL FROM` command.
func (s *Session) parseMailFromCmd(arg string) {
// Capture group 1: from address. 2: optional params. // Capture group 1: from address. 2: optional params.
m := fromRegex.FindStringSubmatch(arg) m := fromRegex.FindStringSubmatch(arg)
if m == nil { if m == nil {
@@ -390,9 +404,10 @@ func (s *Session) readyHandler(cmd string, arg string) {
} }
from := m[1] from := m[1]
s.logger.Debug().Msgf("Mail sender is %v", from) s.logger.Debug().Msgf("Mail sender is %v", from)
// Parse from address.
localpart, domain, err := policy.ParseEmailAddress(from) localpart, domain, err := policy.ParseEmailAddress(from)
s.logger.Debug().Msgf("Origin domain is %v", domain) s.logger.Debug().Msgf("Origin domain is %v", domain)
if from != "" && err != nil { if from != "" && err != nil {
s.send("501 Bad sender address syntax") s.send("501 Bad sender address syntax")
s.logger.Warn().Msgf("Bad address as MAIL arg: %q, %s", from, err) s.logger.Warn().Msgf("Bad address as MAIL arg: %q, %s", from, err)
@@ -402,15 +417,18 @@ func (s *Session) readyHandler(cmd string, arg string) {
from = "unspecified" from = "unspecified"
} }
// This is where the client may put BODY=8BITMIME, but we already // Parse ESMTP parameters.
// read the DATA as bytes, so it does not effect our processing.
if m[2] != "" { if m[2] != "" {
// Here the client may put BODY=8BITMIME, but Inbucket already
// reads the DATA as bytes, so it does not effect mail processing.
args, ok := s.parseArgs(m[2]) args, ok := s.parseArgs(m[2])
if !ok { if !ok {
s.send("501 Unable to parse MAIL ESMTP parameters") s.send("501 Unable to parse MAIL ESMTP parameters")
s.logger.Warn().Msgf("Bad MAIL argument: %q", arg) s.logger.Warn().Msgf("Bad MAIL argument: %q", arg)
return return
} }
// Reject oversized messages.
if args["SIZE"] != "" { if args["SIZE"] != "" {
size, err := strconv.ParseInt(args["SIZE"], 10, 32) size, err := strconv.ParseInt(args["SIZE"], 10, 32)
if err != nil { if err != nil {
@@ -429,9 +447,13 @@ func (s *Session) readyHandler(cmd string, arg string) {
// Process through extensions. // Process through extensions.
extResult := s.extHost.Events.BeforeMailAccepted.Emit( extResult := s.extHost.Events.BeforeMailAccepted.Emit(
&event.AddressParts{Local: localpart, Domain: domain}) &event.AddressParts{Local: localpart, Domain: domain})
if extResult != nil && !*extResult {
s.send("550 Mail denied by policy")
s.logger.Warn().Msgf("Extension denied mail from <%v>", from)
return
}
if extResult == nil || *extResult { // Sender was permitted by an extension, or no extension rejected it.
// Permitted by extension, or none had an opinion.
origin, err := s.addrPolicy.ParseOrigin(from) origin, err := s.addrPolicy.ParseOrigin(from)
if err != nil { if err != nil {
s.send("501 Bad origin address syntax") s.send("501 Bad origin address syntax")
@@ -445,22 +467,10 @@ func (s *Session) readyHandler(cmd string, arg string) {
return return
} }
// Ok to transition to MAIL state.
s.logger.Info().Msgf("Mail from: %v", from) s.logger.Info().Msgf("Mail from: %v", from)
s.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from)) s.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from))
s.enterState(MAIL) s.enterState(MAIL)
} else {
s.send("550 Mail denied by policy")
s.logger.Warn().Msgf("Extension denied mail from <%v>", from)
return
}
} else if cmd == "EHLO" {
// Reset session
s.logger.Debug().Msgf("Resetting session state on EHLO request")
s.reset()
s.send("250 Session reset")
} else {
s.ooSeq(cmd)
}
} }
// MAIL state -> waiting for RCPTs followed by DATA // MAIL state -> waiting for RCPTs followed by DATA