mirror of
https://github.com/jhillyerd/inbucket.git
synced 2025-12-17 09:37:02 +00:00
smtp: Use config.SMTP directly in Server #91
This commit is contained in:
@@ -3,6 +3,7 @@ package config
|
|||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"text/tabwriter"
|
"text/tabwriter"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -81,6 +82,7 @@ type Storage struct {
|
|||||||
func Process() (*Root, error) {
|
func Process() (*Root, error) {
|
||||||
c := &Root{}
|
c := &Root{}
|
||||||
err := envconfig.Process(prefix, c)
|
err := envconfig.Process(prefix, c)
|
||||||
|
c.SMTP.DomainNoStore = strings.ToLower(c.SMTP.DomainNoStore)
|
||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ func (s *Server) startSession(id int, conn net.Conn) {
|
|||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
logger.Warn().Err(err).Msg("Closing connection")
|
logger.Warn().Err(err).Msg("Closing connection")
|
||||||
}
|
}
|
||||||
s.waitgroup.Done()
|
s.wg.Done()
|
||||||
expConnectsCurrent.Add(-1)
|
expConnectsCurrent.Add(-1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -244,7 +244,7 @@ func (s *Session) greetHandler(cmd string, arg string) {
|
|||||||
s.remoteDomain = domain
|
s.remoteDomain = domain
|
||||||
s.send("250-Great, let's get this show on the road")
|
s.send("250-Great, let's get this show on the road")
|
||||||
s.send("250-8BITMIME")
|
s.send("250-8BITMIME")
|
||||||
s.send(fmt.Sprintf("250 SIZE %v", s.server.maxMessageBytes))
|
s.send(fmt.Sprintf("250 SIZE %v", s.server.config.MaxMessageBytes))
|
||||||
s.enterState(READY)
|
s.enterState(READY)
|
||||||
default:
|
default:
|
||||||
s.ooSeq(cmd)
|
s.ooSeq(cmd)
|
||||||
@@ -296,7 +296,7 @@ func (s *Session) readyHandler(cmd string, arg string) {
|
|||||||
s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"])
|
s.logger.Warn().Msgf("Unable to parse SIZE %q as an integer", args["SIZE"])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if int(size) > s.server.maxMessageBytes {
|
if int(size) > s.server.config.MaxMessageBytes {
|
||||||
s.send("552 Max message size exceeded")
|
s.send("552 Max message size exceeded")
|
||||||
s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"])
|
s.logger.Warn().Msgf("Client wanted to send oversized message: %v", args["SIZE"])
|
||||||
return
|
return
|
||||||
@@ -323,15 +323,17 @@ func (s *Session) mailHandler(cmd string, arg string) {
|
|||||||
}
|
}
|
||||||
// This trim is probably too forgiving
|
// This trim is probably too forgiving
|
||||||
addr := strings.Trim(arg[3:], "<> ")
|
addr := strings.Trim(arg[3:], "<> ")
|
||||||
recip, err := s.server.apolicy.NewRecipient(addr)
|
recip, err := s.server.addrPolicy.NewRecipient(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.send("501 Bad recipient address syntax")
|
s.send("501 Bad recipient address syntax")
|
||||||
s.logger.Warn().Msgf("Bad address as RCPT arg: %q, %s", addr, err)
|
s.logger.Warn().Msgf("Bad address as RCPT arg: %q, %s", addr, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(s.recipients) >= s.server.maxRecips {
|
if len(s.recipients) >= s.server.config.MaxRecipients {
|
||||||
s.logger.Warn().Msgf("Maximum limit of %v recipients reached", s.server.maxRecips)
|
s.logger.Warn().Msgf("Maximum limit of %v recipients reached",
|
||||||
s.send(fmt.Sprintf("552 Maximum limit of %v recipients reached", s.server.maxRecips))
|
s.server.config.MaxRecipients)
|
||||||
|
s.send(fmt.Sprintf("552 Maximum limit of %v recipients reached",
|
||||||
|
s.server.config.MaxRecipients))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.recipients = append(s.recipients, recip)
|
s.recipients = append(s.recipients, recip)
|
||||||
@@ -379,7 +381,7 @@ func (s *Session) dataHandler() {
|
|||||||
if recip.ShouldStore() {
|
if recip.ShouldStore() {
|
||||||
// Generate Received header.
|
// Generate Received header.
|
||||||
prefix := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n",
|
prefix := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n",
|
||||||
s.remoteDomain, s.remoteHost, s.server.domain, recip.Address.Address,
|
s.remoteDomain, s.remoteHost, s.server.config.Domain, recip.Address.Address,
|
||||||
tstamp)
|
tstamp)
|
||||||
// Deliver message.
|
// Deliver message.
|
||||||
_, err := s.server.manager.Deliver(
|
_, err := s.server.manager.Deliver(
|
||||||
@@ -403,7 +405,7 @@ func (s *Session) dataHandler() {
|
|||||||
lineBuf = lineBuf[1:]
|
lineBuf = lineBuf[1:]
|
||||||
}
|
}
|
||||||
msgBuf.Write(lineBuf)
|
msgBuf.Write(lineBuf)
|
||||||
if msgBuf.Len() > s.server.maxMessageBytes {
|
if msgBuf.Len() > s.server.config.MaxMessageBytes {
|
||||||
s.send("552 Maximum message size exceeded")
|
s.send("552 Maximum message size exceeded")
|
||||||
s.logger.Warn().Msgf("Max message size exceeded while in DATA")
|
s.logger.Warn().Msgf("Max message size exceeded while in DATA")
|
||||||
s.reset()
|
s.reset()
|
||||||
@@ -418,12 +420,12 @@ func (s *Session) enterState(state State) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) greet() {
|
func (s *Session) greet() {
|
||||||
s.send(fmt.Sprintf("220 %v Inbucket SMTP ready", s.server.domain))
|
s.send(fmt.Sprintf("220 %v Inbucket SMTP ready", s.server.config.Domain))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the next read or write deadline based on maxIdle
|
// nextDeadline calculates the next read or write deadline based on configured timeout.
|
||||||
func (s *Session) nextDeadline() time.Time {
|
func (s *Session) nextDeadline() time.Time {
|
||||||
return time.Now().Add(s.server.timeout)
|
return time.Now().Add(s.server.config.Timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send requested message, store errors in Session.sendError
|
// Send requested message, store errors in Session.sendError
|
||||||
|
|||||||
@@ -392,7 +392,7 @@ func setupSMTPSession(server *Server) net.Conn {
|
|||||||
// Pair of pipes to communicate
|
// Pair of pipes to communicate
|
||||||
serverConn, clientConn := net.Pipe()
|
serverConn, clientConn := net.Pipe()
|
||||||
// Start the session
|
// Start the session
|
||||||
server.waitgroup.Add(1)
|
server.wg.Add(1)
|
||||||
sessionNum++
|
sessionNum++
|
||||||
go server.startSession(sessionNum, &mockConn{serverConn})
|
go server.startSession(sessionNum, &mockConn{serverConn})
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"expvar"
|
"expvar"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,26 +15,6 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
m := expvar.NewMap("smtp")
|
|
||||||
m.Set("ConnectsTotal", expConnectsTotal)
|
|
||||||
m.Set("ConnectsHist", expConnectsHist)
|
|
||||||
m.Set("ConnectsCurrent", expConnectsCurrent)
|
|
||||||
m.Set("ReceivedTotal", expReceivedTotal)
|
|
||||||
m.Set("ReceivedHist", expReceivedHist)
|
|
||||||
m.Set("ErrorsTotal", expErrorsTotal)
|
|
||||||
m.Set("ErrorsHist", expErrorsHist)
|
|
||||||
m.Set("WarnsTotal", expWarnsTotal)
|
|
||||||
m.Set("WarnsHist", expWarnsHist)
|
|
||||||
|
|
||||||
metric.AddTickerFunc(func() {
|
|
||||||
expReceivedHist.Set(metric.Push(deliveredHist, expReceivedTotal))
|
|
||||||
expConnectsHist.Set(metric.Push(connectsHist, expConnectsTotal))
|
|
||||||
expErrorsHist.Set(metric.Push(errorsHist, expErrorsTotal))
|
|
||||||
expWarnsHist.Set(metric.Push(warnsHist, expWarnsTotal))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Raw stat collectors
|
// Raw stat collectors
|
||||||
expConnectsTotal = new(expvar.Int)
|
expConnectsTotal = new(expvar.Int)
|
||||||
@@ -57,56 +36,55 @@ var (
|
|||||||
expWarnsHist = new(expvar.String)
|
expWarnsHist = new(expvar.String)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server holds the configuration and state of our SMTP server
|
func init() {
|
||||||
type Server struct {
|
m := expvar.NewMap("smtp")
|
||||||
// TODO(#91) Refactor config items out of this struct
|
m.Set("ConnectsTotal", expConnectsTotal)
|
||||||
config config.SMTP
|
m.Set("ConnectsHist", expConnectsHist)
|
||||||
// Configuration
|
m.Set("ConnectsCurrent", expConnectsCurrent)
|
||||||
host string
|
m.Set("ReceivedTotal", expReceivedTotal)
|
||||||
domain string
|
m.Set("ReceivedHist", expReceivedHist)
|
||||||
domainNoStore string
|
m.Set("ErrorsTotal", expErrorsTotal)
|
||||||
maxRecips int
|
m.Set("ErrorsHist", expErrorsHist)
|
||||||
maxMessageBytes int
|
m.Set("WarnsTotal", expWarnsTotal)
|
||||||
storeMessages bool
|
m.Set("WarnsHist", expWarnsHist)
|
||||||
timeout time.Duration
|
metric.AddTickerFunc(func() {
|
||||||
|
expReceivedHist.Set(metric.Push(deliveredHist, expReceivedTotal))
|
||||||
// Dependencies
|
expConnectsHist.Set(metric.Push(connectsHist, expConnectsTotal))
|
||||||
apolicy *policy.Addressing // Address policy.
|
expErrorsHist.Set(metric.Push(errorsHist, expErrorsTotal))
|
||||||
globalShutdown chan bool // Shuts down Inbucket.
|
expWarnsHist.Set(metric.Push(warnsHist, expWarnsTotal))
|
||||||
manager message.Manager // Used to deliver messages.
|
})
|
||||||
|
|
||||||
// State
|
|
||||||
listener net.Listener // Incoming network connections
|
|
||||||
waitgroup *sync.WaitGroup // Waitgroup tracks individual sessions
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Server instance with the specificed config
|
// Server holds the configuration and state of our SMTP server.
|
||||||
|
type Server struct {
|
||||||
|
config config.SMTP // SMTP configuration.
|
||||||
|
addrPolicy *policy.Addressing // Address policy.
|
||||||
|
globalShutdown chan bool // Shuts down Inbucket.
|
||||||
|
manager message.Manager // Used to deliver messages.
|
||||||
|
listener net.Listener // Incoming network connections.
|
||||||
|
wg *sync.WaitGroup // Waitgroup tracks individual sessions.
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new Server instance with the specificed config.
|
||||||
func NewServer(
|
func NewServer(
|
||||||
cfg config.SMTP,
|
smtpConfig config.SMTP,
|
||||||
globalShutdown chan bool,
|
globalShutdown chan bool,
|
||||||
manager message.Manager,
|
manager message.Manager,
|
||||||
apolicy *policy.Addressing,
|
apolicy *policy.Addressing,
|
||||||
) *Server {
|
) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
config: cfg,
|
config: smtpConfig,
|
||||||
host: cfg.Addr,
|
globalShutdown: globalShutdown,
|
||||||
domain: cfg.Domain,
|
manager: manager,
|
||||||
domainNoStore: strings.ToLower(cfg.DomainNoStore),
|
addrPolicy: apolicy,
|
||||||
maxRecips: cfg.MaxRecipients,
|
wg: new(sync.WaitGroup),
|
||||||
timeout: cfg.Timeout,
|
|
||||||
maxMessageBytes: cfg.MaxMessageBytes,
|
|
||||||
storeMessages: cfg.StoreMessages,
|
|
||||||
globalShutdown: globalShutdown,
|
|
||||||
manager: manager,
|
|
||||||
apolicy: apolicy,
|
|
||||||
waitgroup: new(sync.WaitGroup),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start the listener and handle incoming connections.
|
// Start the listener and handle incoming connections.
|
||||||
func (s *Server) Start(ctx context.Context) {
|
func (s *Server) Start(ctx context.Context) {
|
||||||
slog := log.With().Str("module", "smtp").Str("phase", "startup").Logger()
|
slog := log.With().Str("module", "smtp").Str("phase", "startup").Logger()
|
||||||
addr, err := net.ResolveTCPAddr("tcp4", s.host)
|
addr, err := net.ResolveTCPAddr("tcp4", s.config.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error().Err(err).Msg("Failed to build tcp4 address")
|
slog.Error().Err(err).Msg("Failed to build tcp4 address")
|
||||||
s.emergencyShutdown()
|
s.emergencyShutdown()
|
||||||
@@ -119,10 +97,10 @@ func (s *Server) Start(ctx context.Context) {
|
|||||||
s.emergencyShutdown()
|
s.emergencyShutdown()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !s.storeMessages {
|
if !s.config.StoreMessages {
|
||||||
slog.Info().Msg("Load test mode active, messages will not be stored")
|
slog.Info().Msg("Load test mode active, messages will not be stored")
|
||||||
} else if s.domainNoStore != "" {
|
} else if s.config.DomainNoStore != "" {
|
||||||
slog.Info().Msgf("Messages sent to domain '%v' will be discarded", s.domainNoStore)
|
slog.Info().Msgf("Messages sent to domain '%v' will be discarded", s.config.DomainNoStore)
|
||||||
}
|
}
|
||||||
// Listener go routine.
|
// Listener go routine.
|
||||||
go s.serve(ctx)
|
go s.serve(ctx)
|
||||||
@@ -172,7 +150,7 @@ func (s *Server) serve(ctx context.Context) {
|
|||||||
} else {
|
} else {
|
||||||
tempDelay = 0
|
tempDelay = 0
|
||||||
expConnectsTotal.Add(1)
|
expConnectsTotal.Add(1)
|
||||||
s.waitgroup.Add(1)
|
s.wg.Add(1)
|
||||||
go s.startSession(sessionID, conn)
|
go s.startSession(sessionID, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -190,6 +168,6 @@ func (s *Server) emergencyShutdown() {
|
|||||||
// Drain causes the caller to block until all active SMTP sessions have finished
|
// Drain causes the caller to block until all active SMTP sessions have finished
|
||||||
func (s *Server) Drain() {
|
func (s *Server) Drain() {
|
||||||
// Wait for sessions to close.
|
// Wait for sessions to close.
|
||||||
s.waitgroup.Wait()
|
s.wg.Wait()
|
||||||
log.Debug().Str("module", "smtp").Str("phase", "shutdown").Msg("SMTP connections have drained")
|
log.Debug().Str("module", "smtp").Str("phase", "shutdown").Msg("SMTP connections have drained")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user