1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-17 09:37:02 +00:00

smtp: Use config.SMTP directly in Server #91

This commit is contained in:
James Hillyerd
2018-03-31 16:49:52 -07:00
parent acd48773da
commit 2c813081eb
4 changed files with 57 additions and 75 deletions

View File

@@ -3,6 +3,7 @@ package config
import ( import (
"log" "log"
"os" "os"
"strings"
"text/tabwriter" "text/tabwriter"
"time" "time"
@@ -81,6 +82,7 @@ type Storage struct {
func Process() (*Root, error) { func Process() (*Root, error) {
c := &Root{} c := &Root{}
err := envconfig.Process(prefix, c) err := envconfig.Process(prefix, c)
c.SMTP.DomainNoStore = strings.ToLower(c.SMTP.DomainNoStore)
return c, err return c, err
} }

View File

@@ -123,7 +123,7 @@ func (s *Server) startSession(id int, conn net.Conn) {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
logger.Warn().Err(err).Msg("Closing connection") logger.Warn().Err(err).Msg("Closing connection")
} }
s.waitgroup.Done() s.wg.Done()
expConnectsCurrent.Add(-1) expConnectsCurrent.Add(-1)
}() }()
@@ -244,7 +244,7 @@ func (s *Session) greetHandler(cmd string, arg string) {
s.remoteDomain = domain s.remoteDomain = domain
s.send("250-Great, let's get this show on the road") s.send("250-Great, let's get this show on the road")
s.send("250-8BITMIME") s.send("250-8BITMIME")
s.send(fmt.Sprintf("250 SIZE %v", s.server.maxMessageBytes)) s.send(fmt.Sprintf("250 SIZE %v", s.server.config.MaxMessageBytes))
s.enterState(READY) s.enterState(READY)
default: default:
s.ooSeq(cmd) s.ooSeq(cmd)
@@ -296,7 +296,7 @@ func (s *Session) readyHandler(cmd string, arg string) {
s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"]) s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"])
return return
} }
if int(size) > s.server.maxMessageBytes { if int(size) > s.server.config.MaxMessageBytes {
s.send("552 Max message size exceeded") s.send("552 Max message size exceeded")
s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"]) s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"])
return return
@@ -323,15 +323,17 @@ func (s *Session) mailHandler(cmd string, arg string) {
} }
// This trim is probably too forgiving // This trim is probably too forgiving
addr := strings.Trim(arg[3:], "<> ") addr := strings.Trim(arg[3:], "<> ")
recip, err := s.server.apolicy.NewRecipient(addr) recip, err := s.server.addrPolicy.NewRecipient(addr)
if err != nil { if err != nil {
s.send("501 Bad recipient address syntax") s.send("501 Bad recipient address syntax")
s.logger.Warn().Msgf("Bad address as RCPT arg: %q, %s", addr, err) s.logger.Warn().Msgf("Bad address as RCPT arg: %q, %s", addr, err)
return return
} }
if len(s.recipients) >= s.server.maxRecips { if len(s.recipients) >= s.server.config.MaxRecipients {
s.logger.Warn().Msgf("Maximum limit of %v recipients reached", s.server.maxRecips) s.logger.Warn().Msgf("Maximum limit of %v recipients reached",
s.send(fmt.Sprintf("552 Maximum limit of %v recipients reached", s.server.maxRecips)) s.server.config.MaxRecipients)
s.send(fmt.Sprintf("552 Maximum limit of %v recipients reached",
s.server.config.MaxRecipients))
return return
} }
s.recipients = append(s.recipients, recip) s.recipients = append(s.recipients, recip)
@@ -379,7 +381,7 @@ func (s *Session) dataHandler() {
if recip.ShouldStore() { if recip.ShouldStore() {
// Generate Received header. // Generate Received header.
prefix := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n", prefix := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n",
s.remoteDomain, s.remoteHost, s.server.domain, recip.Address.Address, s.remoteDomain, s.remoteHost, s.server.config.Domain, recip.Address.Address,
tstamp) tstamp)
// Deliver message. // Deliver message.
_, err := s.server.manager.Deliver( _, err := s.server.manager.Deliver(
@@ -403,7 +405,7 @@ func (s *Session) dataHandler() {
lineBuf = lineBuf[1:] lineBuf = lineBuf[1:]
} }
msgBuf.Write(lineBuf) msgBuf.Write(lineBuf)
if msgBuf.Len() > s.server.maxMessageBytes { if msgBuf.Len() > s.server.config.MaxMessageBytes {
s.send("552 Maximum message size exceeded") s.send("552 Maximum message size exceeded")
s.logger.Warn().Msgf("Max message size exceeded while in DATA") s.logger.Warn().Msgf("Max message size exceeded while in DATA")
s.reset() s.reset()
@@ -418,12 +420,12 @@ func (s *Session) enterState(state State) {
} }
func (s *Session) greet() { func (s *Session) greet() {
s.send(fmt.Sprintf("220 %v Inbucket SMTP ready", s.server.domain)) s.send(fmt.Sprintf("220 %v Inbucket SMTP ready", s.server.config.Domain))
} }
// Calculate the next read or write deadline based on maxIdle // nextDeadline calculates the next read or write deadline based on configured timeout.
func (s *Session) nextDeadline() time.Time { func (s *Session) nextDeadline() time.Time {
return time.Now().Add(s.server.timeout) return time.Now().Add(s.server.config.Timeout)
} }
// Send requested message, store errors in Session.sendError // Send requested message, store errors in Session.sendError

View File

@@ -392,7 +392,7 @@ func setupSMTPSession(server *Server) net.Conn {
// Pair of pipes to communicate // Pair of pipes to communicate
serverConn, clientConn := net.Pipe() serverConn, clientConn := net.Pipe()
// Start the session // Start the session
server.waitgroup.Add(1) server.wg.Add(1)
sessionNum++ sessionNum++
go server.startSession(sessionNum, &mockConn{serverConn}) go server.startSession(sessionNum, &mockConn{serverConn})

View File

@@ -5,7 +5,6 @@ import (
"context" "context"
"expvar" "expvar"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@@ -16,26 +15,6 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func init() {
m := expvar.NewMap("smtp")
m.Set("ConnectsTotal", expConnectsTotal)
m.Set("ConnectsHist", expConnectsHist)
m.Set("ConnectsCurrent", expConnectsCurrent)
m.Set("ReceivedTotal", expReceivedTotal)
m.Set("ReceivedHist", expReceivedHist)
m.Set("ErrorsTotal", expErrorsTotal)
m.Set("ErrorsHist", expErrorsHist)
m.Set("WarnsTotal", expWarnsTotal)
m.Set("WarnsHist", expWarnsHist)
metric.AddTickerFunc(func() {
expReceivedHist.Set(metric.Push(deliveredHist, expReceivedTotal))
expConnectsHist.Set(metric.Push(connectsHist, expConnectsTotal))
expErrorsHist.Set(metric.Push(errorsHist, expErrorsTotal))
expWarnsHist.Set(metric.Push(warnsHist, expWarnsTotal))
})
}
var ( var (
// Raw stat collectors // Raw stat collectors
expConnectsTotal = new(expvar.Int) expConnectsTotal = new(expvar.Int)
@@ -57,56 +36,55 @@ var (
expWarnsHist = new(expvar.String) expWarnsHist = new(expvar.String)
) )
// Server holds the configuration and state of our SMTP server func init() {
type Server struct { m := expvar.NewMap("smtp")
// TODO(#91) Refactor config items out of this struct m.Set("ConnectsTotal", expConnectsTotal)
config config.SMTP m.Set("ConnectsHist", expConnectsHist)
// Configuration m.Set("ConnectsCurrent", expConnectsCurrent)
host string m.Set("ReceivedTotal", expReceivedTotal)
domain string m.Set("ReceivedHist", expReceivedHist)
domainNoStore string m.Set("ErrorsTotal", expErrorsTotal)
maxRecips int m.Set("ErrorsHist", expErrorsHist)
maxMessageBytes int m.Set("WarnsTotal", expWarnsTotal)
storeMessages bool m.Set("WarnsHist", expWarnsHist)
timeout time.Duration metric.AddTickerFunc(func() {
expReceivedHist.Set(metric.Push(deliveredHist, expReceivedTotal))
// Dependencies expConnectsHist.Set(metric.Push(connectsHist, expConnectsTotal))
apolicy *policy.Addressing // Address policy. expErrorsHist.Set(metric.Push(errorsHist, expErrorsTotal))
globalShutdown chan bool // Shuts down Inbucket. expWarnsHist.Set(metric.Push(warnsHist, expWarnsTotal))
manager message.Manager // Used to deliver messages. })
// State
listener net.Listener // Incoming network connections
waitgroup *sync.WaitGroup // Waitgroup tracks individual sessions
} }
// NewServer creates a new Server instance with the specificed config // Server holds the configuration and state of our SMTP server.
type Server struct {
config config.SMTP // SMTP configuration.
addrPolicy *policy.Addressing // Address policy.
globalShutdown chan bool // Shuts down Inbucket.
manager message.Manager // Used to deliver messages.
listener net.Listener // Incoming network connections.
wg *sync.WaitGroup // Waitgroup tracks individual sessions.
}
// NewServer creates a new Server instance with the specificed config.
func NewServer( func NewServer(
cfg config.SMTP, smtpConfig config.SMTP,
globalShutdown chan bool, globalShutdown chan bool,
manager message.Manager, manager message.Manager,
apolicy *policy.Addressing, apolicy *policy.Addressing,
) *Server { ) *Server {
return &Server{ return &Server{
config: cfg, config: smtpConfig,
host: cfg.Addr, globalShutdown: globalShutdown,
domain: cfg.Domain, manager: manager,
domainNoStore: strings.ToLower(cfg.DomainNoStore), addrPolicy: apolicy,
maxRecips: cfg.MaxRecipients, wg: new(sync.WaitGroup),
timeout: cfg.Timeout,
maxMessageBytes: cfg.MaxMessageBytes,
storeMessages: cfg.StoreMessages,
globalShutdown: globalShutdown,
manager: manager,
apolicy: apolicy,
waitgroup: new(sync.WaitGroup),
} }
} }
// Start the listener and handle incoming connections. // Start the listener and handle incoming connections.
func (s *Server) Start(ctx context.Context) { func (s *Server) Start(ctx context.Context) {
slog := log.With().Str("module", "smtp").Str("phase", "startup").Logger() slog := log.With().Str("module", "smtp").Str("phase", "startup").Logger()
addr, err := net.ResolveTCPAddr("tcp4", s.host) addr, err := net.ResolveTCPAddr("tcp4", s.config.Addr)
if err != nil { if err != nil {
slog.Error().Err(err).Msg("Failed to build tcp4 address") slog.Error().Err(err).Msg("Failed to build tcp4 address")
s.emergencyShutdown() s.emergencyShutdown()
@@ -119,10 +97,10 @@ func (s *Server) Start(ctx context.Context) {
s.emergencyShutdown() s.emergencyShutdown()
return return
} }
if !s.storeMessages { if !s.config.StoreMessages {
slog.Info().Msg("Load test mode active, messages will not be stored") slog.Info().Msg("Load test mode active, messages will not be stored")
} else if s.domainNoStore != "" { } else if s.config.DomainNoStore != "" {
slog.Info().Msgf("Messages sent to domain '%v' will be discarded", s.domainNoStore) slog.Info().Msgf("Messages sent to domain '%v' will be discarded", s.config.DomainNoStore)
} }
// Listener go routine. // Listener go routine.
go s.serve(ctx) go s.serve(ctx)
@@ -172,7 +150,7 @@ func (s *Server) serve(ctx context.Context) {
} else { } else {
tempDelay = 0 tempDelay = 0
expConnectsTotal.Add(1) expConnectsTotal.Add(1)
s.waitgroup.Add(1) s.wg.Add(1)
go s.startSession(sessionID, conn) go s.startSession(sessionID, conn)
} }
} }
@@ -190,6 +168,6 @@ func (s *Server) emergencyShutdown() {
// Drain causes the caller to block until all active SMTP sessions have finished // Drain causes the caller to block until all active SMTP sessions have finished
func (s *Server) Drain() { func (s *Server) Drain() {
// Wait for sessions to close. // Wait for sessions to close.
s.waitgroup.Wait() s.wg.Wait()
log.Debug().Str("module", "smtp").Str("phase", "shutdown").Msg("SMTP connections have drained") log.Debug().Str("module", "smtp").Str("phase", "shutdown").Msg("SMTP connections have drained")
} }