1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-17 09:37:02 +00:00

Opportunistic TLS Support (#98)

* STARTTLS Support, disabled by default.
* Added documentation
This commit is contained in:
kingforaday
2018-05-06 14:56:38 -04:00
committed by James Hillyerd
parent 58c3e17be7
commit 894db04d70
4 changed files with 150 additions and 72 deletions

View File

@@ -3,9 +3,11 @@ package smtp
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/textproto"
"regexp"
"strconv"
"strings"
@@ -58,21 +60,22 @@ func (s State) String() string {
}
var commands = map[string]bool{
"HELO": true,
"EHLO": true,
"MAIL": true,
"RCPT": true,
"DATA": true,
"RSET": true,
"SEND": true,
"SOML": true,
"SAML": true,
"VRFY": true,
"EXPN": true,
"HELP": true,
"NOOP": true,
"QUIT": true,
"TURN": true,
"HELO": true,
"EHLO": true,
"MAIL": true,
"RCPT": true,
"DATA": true,
"RSET": true,
"SEND": true,
"SOML": true,
"SAML": true,
"VRFY": true,
"EXPN": true,
"HELP": true,
"NOOP": true,
"QUIT": true,
"TURN": true,
"STARTTLS": true,
}
// Session holds the state of an SMTP session
@@ -89,12 +92,15 @@ type Session struct {
recipients []*policy.Recipient // Recipients from RCPT commands.
logger zerolog.Logger // Session specific logger.
debug bool // Print network traffic to stdout.
tlsState *tls.ConnectionState
text *textproto.Conn
}
// NewSession creates a new Session for the given connection
func NewSession(server *Server, id int, conn net.Conn, logger zerolog.Logger) *Session {
reader := bufio.NewReader(conn)
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
return &Session{
Server: server,
id: id,
@@ -105,6 +111,7 @@ func NewSession(server *Server, id int, conn net.Conn, logger zerolog.Logger) *S
recipients: make([]*policy.Recipient, 0),
logger: logger,
debug: server.config.Debug,
text: textproto.NewConn(conn),
}
}
@@ -135,6 +142,7 @@ func (s *Server) startSession(id int, conn net.Conn) {
}()
ssn := NewSession(s, id, conn, logger)
defer ssn.text.Close()
ssn.greet()
// This is our command reading loop
@@ -232,6 +240,7 @@ func (s *Server) startSession(id int, conn net.Conn) {
// GREET state -> waiting for HELO
func (s *Session) greetHandler(cmd string, arg string) {
const readyBanner = "Great, let's get this show on the road"
switch cmd {
case "HELO":
domain, err := parseHelloArgument(arg)
@@ -240,7 +249,7 @@ func (s *Session) greetHandler(cmd string, arg string) {
return
}
s.remoteDomain = domain
s.send("250 Great, let's get this show on the road")
s.send("250 " + readyBanner)
s.enterState(READY)
case "EHLO":
domain, err := parseHelloArgument(arg)
@@ -249,8 +258,12 @@ func (s *Session) greetHandler(cmd string, arg string) {
return
}
s.remoteDomain = domain
s.send("250-Great, let's get this show on the road")
// features before SIZE per RFC
s.send("250-" + readyBanner)
s.send("250-8BITMIME")
if s.Server.config.TLSEnabled && s.Server.tlsConfig != nil && s.tlsState == nil {
s.send("250-STARTTLS")
}
s.send(fmt.Sprintf("250 SIZE %v", s.config.MaxMessageBytes))
s.enterState(READY)
default:
@@ -271,7 +284,29 @@ func parseHelloArgument(arg string) (string, error) {
// READY state -> waiting for MAIL
func (s *Session) readyHandler(cmd string, arg string) {
if cmd == "MAIL" {
if cmd == "STARTTLS" {
if !s.Server.config.TLSEnabled {
// invalid command since unconfigured
s.logger.Debug().Msgf("454 TLS unavailable on the server")
s.send("454 TLS unavailable on the server")
return
}
if s.tlsState != nil {
// tls state previously valid
s.logger.Debug().Msg("454 A TLS session already agreed upon.")
s.send("454 A TLS session already agreed upon.")
return
}
s.logger.Debug().Msg("Initiating TLS context.")
s.send("220 STARTTLS")
// start tls connection handshake
tlsConn := tls.Server(s.conn, s.Server.tlsConfig)
s.conn = tlsConn
s.text = textproto.NewConn(s.conn)
s.tlsState = new(tls.ConnectionState)
*s.tlsState = tlsConn.ConnectionState()
s.enterState(GREET)
} else if cmd == "MAIL" {
// Capture group 1: from address. 2: optional params.
m := fromRegex.FindStringSubmatch(arg)
if m == nil {
@@ -367,57 +402,43 @@ func (s *Session) mailHandler(cmd string, arg string) {
// DATA
func (s *Session) dataHandler() {
s.send("354 Start mail input; end with <CRLF>.<CRLF>")
msgBuf := &bytes.Buffer{}
for {
lineBuf, err := s.readByteLine()
if err != nil {
if netErr, ok := err.(net.Error); ok {
if netErr.Timeout() {
s.send("221 Idle timeout, bye bye")
}
msgBuf, err := s.readByteLine()
if err != nil {
if netErr, ok := err.(net.Error); ok {
if netErr.Timeout() {
s.send("221 Idle timeout, bye bye")
}
s.logger.Warn().Msgf("Error: %v while reading", err)
s.enterState(QUIT)
return
}
if bytes.Equal(lineBuf, []byte(".\r\n")) || bytes.Equal(lineBuf, []byte(".\n")) {
// Mail data complete.
tstamp := time.Now().Format(timeStampFormat)
for _, recip := range s.recipients {
if recip.ShouldStore() {
// Generate Received header.
prefix := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n",
s.remoteDomain, s.remoteHost, s.config.Domain, recip.Address.Address,
tstamp)
// Deliver message.
_, err := s.manager.Deliver(
recip, s.from, s.recipients, prefix, msgBuf.Bytes())
if err != nil {
s.logger.Error().Msgf("delivery for %v: %v", recip.LocalPart, err)
s.send(fmt.Sprintf("451 Failed to store message for %v", recip.LocalPart))
s.reset()
return
}
}
expReceivedTotal.Add(1)
}
s.send("250 Mail accepted for delivery")
s.logger.Info().Msgf("Message size %v bytes", msgBuf.Len())
s.reset()
return
}
// RFC: remove leading periods from DATA.
if len(lineBuf) > 0 && lineBuf[0] == '.' {
lineBuf = lineBuf[1:]
}
msgBuf.Write(lineBuf)
if msgBuf.Len() > s.config.MaxMessageBytes {
s.send("552 Maximum message size exceeded")
s.logger.Warn().Msgf("Max message size exceeded while in DATA")
s.reset()
return
}
s.logger.Warn().Msgf("Error: %v while reading", err)
s.enterState(QUIT)
return
}
mailData := bytes.NewBuffer(msgBuf)
// Mail data complete.
tstamp := time.Now().Format(timeStampFormat)
for _, recip := range s.recipients {
if recip.ShouldStore() {
// Generate Received header.
prefix := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n",
s.remoteDomain, s.remoteHost, s.config.Domain, recip.Address.Address,
tstamp)
// Deliver message.
_, err := s.manager.Deliver(
recip, s.from, s.recipients, prefix, mailData.Bytes())
if err != nil {
s.logger.Error().Msgf("delivery for %v: %v", recip.LocalPart, err)
s.send(fmt.Sprintf("451 Failed to store message for %v", recip.LocalPart))
s.reset()
return
}
}
expReceivedTotal.Add(1)
}
s.send("250 Mail accepted for delivery")
s.logger.Info().Msgf("Message size %v bytes", mailData.Len())
s.reset()
return
}
func (s *Session) enterState(state State) {
@@ -440,7 +461,7 @@ func (s *Session) send(msg string) {
s.sendError = err
return
}
if _, err := fmt.Fprint(s.conn, msg+"\r\n"); err != nil {
if err := s.text.PrintfLine("%s", msg); err != nil {
s.sendError = err
s.logger.Warn().Msgf("Failed to send: %q", msg)
return
@@ -455,9 +476,12 @@ func (s *Session) readByteLine() ([]byte, error) {
if err := s.conn.SetReadDeadline(s.nextDeadline()); err != nil {
return nil, err
}
b, err := s.reader.ReadBytes('\n')
if err == nil && s.debug {
fmt.Printf("%04d %s\n", s.id, bytes.TrimRight(b, "\r\n"))
b, err := s.text.ReadDotBytes()
if err != nil {
return nil, err
}
if s.debug {
fmt.Printf("%04d Received %d bytes\n", s.id, len(b))
}
return b, err
}
@@ -467,7 +491,7 @@ func (s *Session) readLine() (line string, err error) {
if err = s.conn.SetReadDeadline(s.nextDeadline()); err != nil {
return "", err
}
line, err = s.reader.ReadString('\n')
line, err = s.text.ReadLine()
if err != nil {
return "", err
}
@@ -486,7 +510,7 @@ func (s *Session) parseCmd(line string) (cmd string, arg string, ok bool) {
case l < 4:
s.logger.Warn().Msgf("Command too short: %q", line)
return "", "", false
case l == 4:
case l == 4 || l == 8:
return strings.ToUpper(line), "", true
case l == 5:
// Too long to be only command, too short to have args