1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-17 17:47:03 +00:00

chore: refactor smtp/handler if-else chain (#504)

* chore: convert smtp/handler if-else chain to switch-case

Signed-off-by: James Hillyerd <james@hillyerd.com>

* chore: extract long case into parseMailCmd func

Signed-off-by: James Hillyerd <james@hillyerd.com>

* chore: remove extraneous braces in cases

Signed-off-by: James Hillyerd <james@hillyerd.com>

---------

Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
James Hillyerd
2024-02-19 17:30:37 -08:00
committed by GitHub
parent 40ec108daf
commit c59e793775

View File

@@ -330,7 +330,8 @@ func (s *Session) passwordHandler() {
// READY state -> waiting for MAIL
// AUTH can change
func (s *Session) readyHandler(cmd string, arg string) {
if cmd == "STARTTLS" {
switch cmd {
case "STARTTLS":
if !s.Server.config.TLSEnabled {
// Invalid command since TLS unconfigured.
s.logger.Debug().Msgf("454 TLS unavailable on the server")
@@ -353,116 +354,125 @@ func (s *Session) readyHandler(cmd string, arg string) {
s.tlsState = new(tls.ConnectionState)
*s.tlsState = tlsConn.ConnectionState()
s.enterState(GREET)
} else if cmd == "AUTH" {
case "AUTH":
args := strings.SplitN(arg, " ", 3)
authMethod := args[0]
switch authMethod {
case "PLAIN":
{
if len(args) != 2 {
s.send("500 Bad auth arguments")
s.logger.Warn().Msgf("Bad auth attempt: %q", arg)
return
}
s.logger.Info().Msgf("Accepting credentials: %q", args[1])
s.send("235 2.7.0 Authentication successful")
if len(args) != 2 {
s.send("500 Bad auth arguments")
s.logger.Warn().Msgf("Bad auth attempt: %q", arg)
return
}
s.logger.Info().Msgf("Accepting credentials: %q", args[1])
s.send("235 2.7.0 Authentication successful")
return
case "LOGIN":
{
s.send(fmt.Sprintf("334 %v", usernameChallenge))
s.enterState(LOGIN)
return
}
s.send(fmt.Sprintf("334 %v", usernameChallenge))
s.enterState(LOGIN)
return
default:
{
s.send(fmt.Sprintf("500 Unsupported AUTH method: %v", authMethod))
return
}
}
} else if cmd == "MAIL" {
// Capture group 1: from address. 2: optional params.
m := fromRegex.FindStringSubmatch(arg)
if m == nil {
s.send("501 Was expecting MAIL arg syntax of FROM:<address>")
s.logger.Warn().Msgf("Bad MAIL argument: %q", arg)
s.send(fmt.Sprintf("500 Unsupported AUTH method: %v", authMethod))
return
}
from := m[1]
s.logger.Debug().Msgf("Mail sender is %v", from)
localpart, domain, err := policy.ParseEmailAddress(from)
s.logger.Debug().Msgf("Origin domain is %v", domain)
if from != "" && err != nil {
s.send("501 Bad sender address syntax")
s.logger.Warn().Msgf("Bad address as MAIL arg: %q, %s", from, err)
return
}
if from == "" {
from = "unspecified"
}
case "MAIL":
s.parseMailFromCmd(arg)
// This is where the client may put BODY=8BITMIME, but we already
// read the DATA as bytes, so it does not effect our processing.
if m[2] != "" {
args, ok := s.parseArgs(m[2])
if !ok {
s.send("501 Unable to parse MAIL ESMTP parameters")
s.logger.Warn().Msgf("Bad MAIL argument: %q", arg)
return
}
if args["SIZE"] != "" {
size, err := strconv.ParseInt(args["SIZE"], 10, 32)
if err != nil {
s.send("501 Unable to parse SIZE as an integer")
s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"])
return
}
if int(size) > s.config.MaxMessageBytes {
s.send("552 Max message size exceeded")
s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"])
return
}
}
}
// Process through extensions.
extResult := s.extHost.Events.BeforeMailAccepted.Emit(
&event.AddressParts{Local: localpart, Domain: domain})
if extResult == nil || *extResult {
// Permitted by extension, or none had an opinion.
origin, err := s.addrPolicy.ParseOrigin(from)
if err != nil {
s.send("501 Bad origin address syntax")
s.logger.Warn().Str("from", from).Err(err).Msg("Bad address as MAIL arg")
return
}
s.from = origin
if !s.from.ShouldAccept() {
s.send("501 Unauthorized domain")
s.logger.Warn().Msgf("Bad domain sender %s", domain)
return
}
s.logger.Info().Msgf("Mail from: %v", from)
s.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from))
s.enterState(MAIL)
} else {
s.send("550 Mail denied by policy")
s.logger.Warn().Msgf("Extension denied mail from <%v>", from)
return
}
} else if cmd == "EHLO" {
case "EHLO":
// Reset session
s.logger.Debug().Msgf("Resetting session state on EHLO request")
s.reset()
s.send("250 Session reset")
} else {
default:
s.ooSeq(cmd)
}
}
// Parses `MAIL FROM` command.
func (s *Session) parseMailFromCmd(arg string) {
// Capture group 1: from address. 2: optional params.
m := fromRegex.FindStringSubmatch(arg)
if m == nil {
s.send("501 Was expecting MAIL arg syntax of FROM:<address>")
s.logger.Warn().Msgf("Bad MAIL argument: %q", arg)
return
}
from := m[1]
s.logger.Debug().Msgf("Mail sender is %v", from)
// Parse from address.
localpart, domain, err := policy.ParseEmailAddress(from)
s.logger.Debug().Msgf("Origin domain is %v", domain)
if from != "" && err != nil {
s.send("501 Bad sender address syntax")
s.logger.Warn().Msgf("Bad address as MAIL arg: %q, %s", from, err)
return
}
if from == "" {
from = "unspecified"
}
// Parse ESMTP parameters.
if m[2] != "" {
// Here the client may put BODY=8BITMIME, but Inbucket already
// reads the DATA as bytes, so it does not effect mail processing.
args, ok := s.parseArgs(m[2])
if !ok {
s.send("501 Unable to parse MAIL ESMTP parameters")
s.logger.Warn().Msgf("Bad MAIL argument: %q", arg)
return
}
// Reject oversized messages.
if args["SIZE"] != "" {
size, err := strconv.ParseInt(args["SIZE"], 10, 32)
if err != nil {
s.send("501 Unable to parse SIZE as an integer")
s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"])
return
}
if int(size) > s.config.MaxMessageBytes {
s.send("552 Max message size exceeded")
s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"])
return
}
}
}
// Process through extensions.
extResult := s.extHost.Events.BeforeMailAccepted.Emit(
&event.AddressParts{Local: localpart, Domain: domain})
if extResult != nil && !*extResult {
s.send("550 Mail denied by policy")
s.logger.Warn().Msgf("Extension denied mail from <%v>", from)
return
}
// Sender was permitted by an extension, or no extension rejected it.
origin, err := s.addrPolicy.ParseOrigin(from)
if err != nil {
s.send("501 Bad origin address syntax")
s.logger.Warn().Str("from", from).Err(err).Msg("Bad address as MAIL arg")
return
}
s.from = origin
if !s.from.ShouldAccept() {
s.send("501 Unauthorized domain")
s.logger.Warn().Msgf("Bad domain sender %s", domain)
return
}
// Ok to transition to MAIL state.
s.logger.Info().Msgf("Mail from: %v", from)
s.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from))
s.enterState(MAIL)
}
// MAIL state -> waiting for RCPTs followed by DATA
func (s *Session) mailHandler(cmd string, arg string) {
switch cmd {