diff --git a/pkg/config/config.go b/pkg/config/config.go index a8452f1..3dc7549 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -3,6 +3,7 @@ package config import ( "log" "os" + "strings" "text/tabwriter" "time" @@ -81,6 +82,7 @@ type Storage struct { func Process() (*Root, error) { c := &Root{} err := envconfig.Process(prefix, c) + c.SMTP.DomainNoStore = strings.ToLower(c.SMTP.DomainNoStore) return c, err } diff --git a/pkg/server/smtp/handler.go b/pkg/server/smtp/handler.go index a86be5e..9d4f238 100644 --- a/pkg/server/smtp/handler.go +++ b/pkg/server/smtp/handler.go @@ -123,7 +123,7 @@ func (s *Server) startSession(id int, conn net.Conn) { if err := conn.Close(); err != nil { logger.Warn().Err(err).Msg("Closing connection") } - s.waitgroup.Done() + s.wg.Done() expConnectsCurrent.Add(-1) }() @@ -244,7 +244,7 @@ func (s *Session) greetHandler(cmd string, arg string) { s.remoteDomain = domain s.send("250-Great, let's get this show on the road") s.send("250-8BITMIME") - s.send(fmt.Sprintf("250 SIZE %v", s.server.maxMessageBytes)) + s.send(fmt.Sprintf("250 SIZE %v", s.server.config.MaxMessageBytes)) s.enterState(READY) default: s.ooSeq(cmd) @@ -296,7 +296,7 @@ func (s *Session) readyHandler(cmd string, arg string) { s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"]) return } - if int(size) > s.server.maxMessageBytes { + if int(size) > s.server.config.MaxMessageBytes { s.send("552 Max message size exceeded") s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"]) return @@ -323,15 +323,17 @@ func (s *Session) mailHandler(cmd string, arg string) { } // This trim is probably too forgiving addr := strings.Trim(arg[3:], "<> ") - recip, err := s.server.apolicy.NewRecipient(addr) + recip, err := s.server.addrPolicy.NewRecipient(addr) if err != nil { s.send("501 Bad recipient address syntax") s.logger.Warn().Msgf("Bad address as RCPT arg: %q, %s", addr, err) return } - if len(s.recipients) >= s.server.maxRecips { - s.logger.Warn().Msgf("Maximum limit of %v recipients reached", s.server.maxRecips) - s.send(fmt.Sprintf("552 Maximum limit of %v recipients reached", s.server.maxRecips)) + if len(s.recipients) >= s.server.config.MaxRecipients { + s.logger.Warn().Msgf("Maximum limit of %v recipients reached", + s.server.config.MaxRecipients) + s.send(fmt.Sprintf("552 Maximum limit of %v recipients reached", + s.server.config.MaxRecipients)) return } s.recipients = append(s.recipients, recip) @@ -379,7 +381,7 @@ func (s *Session) dataHandler() { if recip.ShouldStore() { // Generate Received header. prefix := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n", - s.remoteDomain, s.remoteHost, s.server.domain, recip.Address.Address, + s.remoteDomain, s.remoteHost, s.server.config.Domain, recip.Address.Address, tstamp) // Deliver message. _, err := s.server.manager.Deliver( @@ -403,7 +405,7 @@ func (s *Session) dataHandler() { lineBuf = lineBuf[1:] } msgBuf.Write(lineBuf) - if msgBuf.Len() > s.server.maxMessageBytes { + if msgBuf.Len() > s.server.config.MaxMessageBytes { s.send("552 Maximum message size exceeded") s.logger.Warn().Msgf("Max message size exceeded while in DATA") s.reset() @@ -418,12 +420,12 @@ func (s *Session) enterState(state State) { } func (s *Session) greet() { - s.send(fmt.Sprintf("220 %v Inbucket SMTP ready", s.server.domain)) + s.send(fmt.Sprintf("220 %v Inbucket SMTP ready", s.server.config.Domain)) } -// Calculate the next read or write deadline based on maxIdle +// nextDeadline calculates the next read or write deadline based on configured timeout. func (s *Session) nextDeadline() time.Time { - return time.Now().Add(s.server.timeout) + return time.Now().Add(s.server.config.Timeout) } // Send requested message, store errors in Session.sendError diff --git a/pkg/server/smtp/handler_test.go b/pkg/server/smtp/handler_test.go index 27fa663..e1b75d8 100644 --- a/pkg/server/smtp/handler_test.go +++ b/pkg/server/smtp/handler_test.go @@ -392,7 +392,7 @@ func setupSMTPSession(server *Server) net.Conn { // Pair of pipes to communicate serverConn, clientConn := net.Pipe() // Start the session - server.waitgroup.Add(1) + server.wg.Add(1) sessionNum++ go server.startSession(sessionNum, &mockConn{serverConn}) diff --git a/pkg/server/smtp/listener.go b/pkg/server/smtp/listener.go index e3a63e2..3741eef 100644 --- a/pkg/server/smtp/listener.go +++ b/pkg/server/smtp/listener.go @@ -5,7 +5,6 @@ import ( "context" "expvar" "net" - "strings" "sync" "time" @@ -16,26 +15,6 @@ import ( "github.com/rs/zerolog/log" ) -func init() { - m := expvar.NewMap("smtp") - m.Set("ConnectsTotal", expConnectsTotal) - m.Set("ConnectsHist", expConnectsHist) - m.Set("ConnectsCurrent", expConnectsCurrent) - m.Set("ReceivedTotal", expReceivedTotal) - m.Set("ReceivedHist", expReceivedHist) - m.Set("ErrorsTotal", expErrorsTotal) - m.Set("ErrorsHist", expErrorsHist) - m.Set("WarnsTotal", expWarnsTotal) - m.Set("WarnsHist", expWarnsHist) - - metric.AddTickerFunc(func() { - expReceivedHist.Set(metric.Push(deliveredHist, expReceivedTotal)) - expConnectsHist.Set(metric.Push(connectsHist, expConnectsTotal)) - expErrorsHist.Set(metric.Push(errorsHist, expErrorsTotal)) - expWarnsHist.Set(metric.Push(warnsHist, expWarnsTotal)) - }) -} - var ( // Raw stat collectors expConnectsTotal = new(expvar.Int) @@ -57,56 +36,55 @@ var ( expWarnsHist = new(expvar.String) ) -// Server holds the configuration and state of our SMTP server -type Server struct { - // TODO(#91) Refactor config items out of this struct - config config.SMTP - // Configuration - host string - domain string - domainNoStore string - maxRecips int - maxMessageBytes int - storeMessages bool - timeout time.Duration - - // Dependencies - apolicy *policy.Addressing // Address policy. - globalShutdown chan bool // Shuts down Inbucket. - manager message.Manager // Used to deliver messages. - - // State - listener net.Listener // Incoming network connections - waitgroup *sync.WaitGroup // Waitgroup tracks individual sessions +func init() { + m := expvar.NewMap("smtp") + m.Set("ConnectsTotal", expConnectsTotal) + m.Set("ConnectsHist", expConnectsHist) + m.Set("ConnectsCurrent", expConnectsCurrent) + m.Set("ReceivedTotal", expReceivedTotal) + m.Set("ReceivedHist", expReceivedHist) + m.Set("ErrorsTotal", expErrorsTotal) + m.Set("ErrorsHist", expErrorsHist) + m.Set("WarnsTotal", expWarnsTotal) + m.Set("WarnsHist", expWarnsHist) + metric.AddTickerFunc(func() { + expReceivedHist.Set(metric.Push(deliveredHist, expReceivedTotal)) + expConnectsHist.Set(metric.Push(connectsHist, expConnectsTotal)) + expErrorsHist.Set(metric.Push(errorsHist, expErrorsTotal)) + expWarnsHist.Set(metric.Push(warnsHist, expWarnsTotal)) + }) } -// NewServer creates a new Server instance with the specificed config +// Server holds the configuration and state of our SMTP server. +type Server struct { + config config.SMTP // SMTP configuration. + addrPolicy *policy.Addressing // Address policy. + globalShutdown chan bool // Shuts down Inbucket. + manager message.Manager // Used to deliver messages. + listener net.Listener // Incoming network connections. + wg *sync.WaitGroup // Waitgroup tracks individual sessions. +} + +// NewServer creates a new Server instance with the specificed config. func NewServer( - cfg config.SMTP, + smtpConfig config.SMTP, globalShutdown chan bool, manager message.Manager, apolicy *policy.Addressing, ) *Server { return &Server{ - config: cfg, - host: cfg.Addr, - domain: cfg.Domain, - domainNoStore: strings.ToLower(cfg.DomainNoStore), - maxRecips: cfg.MaxRecipients, - timeout: cfg.Timeout, - maxMessageBytes: cfg.MaxMessageBytes, - storeMessages: cfg.StoreMessages, - globalShutdown: globalShutdown, - manager: manager, - apolicy: apolicy, - waitgroup: new(sync.WaitGroup), + config: smtpConfig, + globalShutdown: globalShutdown, + manager: manager, + addrPolicy: apolicy, + wg: new(sync.WaitGroup), } } // Start the listener and handle incoming connections. func (s *Server) Start(ctx context.Context) { slog := log.With().Str("module", "smtp").Str("phase", "startup").Logger() - addr, err := net.ResolveTCPAddr("tcp4", s.host) + addr, err := net.ResolveTCPAddr("tcp4", s.config.Addr) if err != nil { slog.Error().Err(err).Msg("Failed to build tcp4 address") s.emergencyShutdown() @@ -119,10 +97,10 @@ func (s *Server) Start(ctx context.Context) { s.emergencyShutdown() return } - if !s.storeMessages { + if !s.config.StoreMessages { slog.Info().Msg("Load test mode active, messages will not be stored") - } else if s.domainNoStore != "" { - slog.Info().Msgf("Messages sent to domain '%v' will be discarded", s.domainNoStore) + } else if s.config.DomainNoStore != "" { + slog.Info().Msgf("Messages sent to domain '%v' will be discarded", s.config.DomainNoStore) } // Listener go routine. go s.serve(ctx) @@ -172,7 +150,7 @@ func (s *Server) serve(ctx context.Context) { } else { tempDelay = 0 expConnectsTotal.Add(1) - s.waitgroup.Add(1) + s.wg.Add(1) go s.startSession(sessionID, conn) } } @@ -190,6 +168,6 @@ func (s *Server) emergencyShutdown() { // Drain causes the caller to block until all active SMTP sessions have finished func (s *Server) Drain() { // Wait for sessions to close. - s.waitgroup.Wait() + s.wg.Wait() log.Debug().Str("module", "smtp").Str("phase", "shutdown").Msg("SMTP connections have drained") }