mirror of
https://github.com/jhillyerd/inbucket.git
synced 2025-12-17 17:47:03 +00:00
Reorganize packages, closes #79
- All packages go into either cmd or pkg directories - Most packages renamed - Server packages moved into pkg/server - sanitize moved into webui, as that's the only place it's used - filestore moved into pkg/storage/file - Makefile updated, and PKG variable use fixed
This commit is contained in:
656
pkg/server/pop3/handler.go
Normal file
656
pkg/server/pop3/handler.go
Normal file
@@ -0,0 +1,656 @@
|
||||
package pop3
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jhillyerd/inbucket/pkg/log"
|
||||
"github.com/jhillyerd/inbucket/pkg/storage"
|
||||
)
|
||||
|
||||
// State tracks the current mode of our POP3 state machine
|
||||
type State int
|
||||
|
||||
const (
|
||||
// AUTHORIZATION state: the client must now identify and authenticate
|
||||
AUTHORIZATION State = iota
|
||||
// TRANSACTION state: mailbox open, client may now issue commands
|
||||
TRANSACTION
|
||||
// QUIT state: client requests us to end session
|
||||
QUIT
|
||||
)
|
||||
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case AUTHORIZATION:
|
||||
return "AUTHORIZATION"
|
||||
case TRANSACTION:
|
||||
return "TRANSACTION"
|
||||
case QUIT:
|
||||
return "QUIT"
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
var commands = map[string]bool{
|
||||
"QUIT": true,
|
||||
"STAT": true,
|
||||
"LIST": true,
|
||||
"RETR": true,
|
||||
"DELE": true,
|
||||
"NOOP": true,
|
||||
"RSET": true,
|
||||
"TOP": true,
|
||||
"UIDL": true,
|
||||
"USER": true,
|
||||
"PASS": true,
|
||||
"APOP": true,
|
||||
"CAPA": true,
|
||||
}
|
||||
|
||||
// Session defines an active POP3 session
|
||||
type Session struct {
|
||||
server *Server // Reference to the server we belong to
|
||||
id int // Session ID number
|
||||
conn net.Conn // Our network connection
|
||||
remoteHost string // IP address of client
|
||||
sendError error // Used to bail out of read loop on send error
|
||||
state State // Current session state
|
||||
reader *bufio.Reader // Buffered reader for our net conn
|
||||
user string // Mailbox name
|
||||
mailbox datastore.Mailbox // Mailbox instance
|
||||
messages []datastore.Message // Slice of messages in mailbox
|
||||
retain []bool // Messages to retain upon UPDATE (true=retain)
|
||||
msgCount int // Number of undeleted messages
|
||||
}
|
||||
|
||||
// NewSession creates a new POP3 session
|
||||
func NewSession(server *Server, id int, conn net.Conn) *Session {
|
||||
reader := bufio.NewReader(conn)
|
||||
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
return &Session{server: server, id: id, conn: conn, state: AUTHORIZATION,
|
||||
reader: reader, remoteHost: host}
|
||||
}
|
||||
|
||||
func (ses *Session) String() string {
|
||||
return fmt.Sprintf("Session{id: %v, state: %v}", ses.id, ses.state)
|
||||
}
|
||||
|
||||
/* Session flow:
|
||||
* 1. Send initial greeting
|
||||
* 2. Receive cmd
|
||||
* 3. If good cmd, respond, optionally change state
|
||||
* 4. If bad cmd, respond error
|
||||
* 5. Goto 2
|
||||
*/
|
||||
func (s *Server) startSession(id int, conn net.Conn) {
|
||||
log.Infof("POP3 connection from %v, starting session <%v>", conn.RemoteAddr(), id)
|
||||
//expConnectsCurrent.Add(1)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Error closing POP3 connection for <%v>: %v", id, err)
|
||||
}
|
||||
s.waitgroup.Done()
|
||||
//expConnectsCurrent.Add(-1)
|
||||
}()
|
||||
|
||||
ses := NewSession(s, id, conn)
|
||||
ses.send(fmt.Sprintf("+OK Inbucket POP3 server ready <%v.%v@%v>", os.Getpid(),
|
||||
time.Now().Unix(), s.domain))
|
||||
|
||||
// This is our command reading loop
|
||||
for ses.state != QUIT && ses.sendError == nil {
|
||||
line, err := ses.readLine()
|
||||
if err == nil {
|
||||
if cmd, arg, ok := ses.parseCmd(line); ok {
|
||||
// Check against valid SMTP commands
|
||||
if cmd == "" {
|
||||
ses.send("-ERR Speak up")
|
||||
continue
|
||||
}
|
||||
if !commands[cmd] {
|
||||
ses.send(fmt.Sprintf("-ERR Syntax error, %v command unrecognized", cmd))
|
||||
ses.logWarn("Unrecognized command: %v", cmd)
|
||||
continue
|
||||
}
|
||||
|
||||
// Commands we handle in any state
|
||||
switch cmd {
|
||||
case "CAPA":
|
||||
// List our capabilities per RFC2449
|
||||
ses.send("+OK Capability list follows")
|
||||
ses.send("TOP")
|
||||
ses.send("USER")
|
||||
ses.send("UIDL")
|
||||
ses.send("IMPLEMENTATION Inbucket")
|
||||
ses.send(".")
|
||||
continue
|
||||
}
|
||||
|
||||
// Send command to handler for current state
|
||||
switch ses.state {
|
||||
case AUTHORIZATION:
|
||||
ses.authorizationHandler(cmd, arg)
|
||||
continue
|
||||
case TRANSACTION:
|
||||
ses.transactionHandler(cmd, arg)
|
||||
continue
|
||||
}
|
||||
ses.logError("Session entered unexpected state %v", ses.state)
|
||||
break
|
||||
} else {
|
||||
ses.send("-ERR Syntax error, command garbled")
|
||||
}
|
||||
} else {
|
||||
// readLine() returned an error
|
||||
if err == io.EOF {
|
||||
switch ses.state {
|
||||
case AUTHORIZATION:
|
||||
// EOF is common here
|
||||
ses.logInfo("Client closed connection (state %v)", ses.state)
|
||||
default:
|
||||
ses.logWarn("Got EOF while in state %v", ses.state)
|
||||
}
|
||||
break
|
||||
}
|
||||
// not an EOF
|
||||
ses.logWarn("Connection error: %v", err)
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
if netErr.Timeout() {
|
||||
ses.send("-ERR Idle timeout, bye bye")
|
||||
break
|
||||
}
|
||||
}
|
||||
ses.send("-ERR Connection error, sorry")
|
||||
break
|
||||
}
|
||||
}
|
||||
if ses.sendError != nil {
|
||||
ses.logWarn("Network send error: %v", ses.sendError)
|
||||
}
|
||||
ses.logInfo("Closing connection")
|
||||
}
|
||||
|
||||
// AUTHORIZATION state
|
||||
func (ses *Session) authorizationHandler(cmd string, args []string) {
|
||||
switch cmd {
|
||||
case "QUIT":
|
||||
ses.send("+OK Goodnight and good luck")
|
||||
ses.enterState(QUIT)
|
||||
case "USER":
|
||||
if len(args) > 0 {
|
||||
ses.user = args[0]
|
||||
ses.send(fmt.Sprintf("+OK Hello %v, welcome to Inbucket", ses.user))
|
||||
} else {
|
||||
ses.send("-ERR Missing username argument")
|
||||
}
|
||||
case "PASS":
|
||||
if ses.user == "" {
|
||||
ses.ooSeq(cmd)
|
||||
} else {
|
||||
var err error
|
||||
ses.mailbox, err = ses.server.dataStore.MailboxFor(ses.user)
|
||||
if err != nil {
|
||||
ses.logError("Failed to open mailbox for %v", ses.user)
|
||||
ses.send(fmt.Sprintf("-ERR Failed to open mailbox for %v", ses.user))
|
||||
ses.enterState(QUIT)
|
||||
return
|
||||
}
|
||||
ses.loadMailbox()
|
||||
ses.send(fmt.Sprintf("+OK Found %v messages for %v", ses.msgCount, ses.user))
|
||||
ses.enterState(TRANSACTION)
|
||||
}
|
||||
case "APOP":
|
||||
if len(args) != 2 {
|
||||
ses.logWarn("Expected two arguments for APOP")
|
||||
ses.send("-ERR APOP requires two arguments")
|
||||
return
|
||||
}
|
||||
ses.user = args[0]
|
||||
var err error
|
||||
ses.mailbox, err = ses.server.dataStore.MailboxFor(ses.user)
|
||||
if err != nil {
|
||||
ses.logError("Failed to open mailbox for %v", ses.user)
|
||||
ses.send(fmt.Sprintf("-ERR Failed to open mailbox for %v", ses.user))
|
||||
ses.enterState(QUIT)
|
||||
return
|
||||
}
|
||||
ses.loadMailbox()
|
||||
ses.send(fmt.Sprintf("+OK Found %v messages for %v", ses.msgCount, ses.user))
|
||||
ses.enterState(TRANSACTION)
|
||||
default:
|
||||
ses.ooSeq(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// TRANSACTION state
|
||||
func (ses *Session) transactionHandler(cmd string, args []string) {
|
||||
switch cmd {
|
||||
case "STAT":
|
||||
if len(args) != 0 {
|
||||
ses.logWarn("STAT got an unexpected argument")
|
||||
ses.send("-ERR STAT command must have no arguments")
|
||||
return
|
||||
}
|
||||
var count int
|
||||
var size int64
|
||||
for i, msg := range ses.messages {
|
||||
if ses.retain[i] {
|
||||
count++
|
||||
size += msg.Size()
|
||||
}
|
||||
}
|
||||
ses.send(fmt.Sprintf("+OK %v %v", count, size))
|
||||
case "LIST":
|
||||
if len(args) > 1 {
|
||||
ses.logWarn("LIST command had more than 1 argument")
|
||||
ses.send("-ERR LIST command must have zero or one argument")
|
||||
return
|
||||
}
|
||||
if len(args) == 1 {
|
||||
msgNum, err := strconv.ParseInt(args[0], 10, 32)
|
||||
if err != nil {
|
||||
ses.logWarn("LIST command argument was not an integer")
|
||||
ses.send("-ERR LIST command requires an integer argument")
|
||||
return
|
||||
}
|
||||
if msgNum < 1 {
|
||||
ses.logWarn("LIST command argument was less than 1")
|
||||
ses.send("-ERR LIST argument must be greater than 0")
|
||||
return
|
||||
}
|
||||
if int(msgNum) > len(ses.messages) {
|
||||
ses.logWarn("LIST command argument was greater than number of messages")
|
||||
ses.send("-ERR LIST argument must not exceed the number of messages")
|
||||
return
|
||||
}
|
||||
if !ses.retain[msgNum-1] {
|
||||
ses.logWarn("Client tried to LIST a message it had deleted")
|
||||
ses.send(fmt.Sprintf("-ERR You deleted message %v", msgNum))
|
||||
return
|
||||
}
|
||||
ses.send(fmt.Sprintf("+OK %v %v", msgNum, ses.messages[msgNum-1].Size()))
|
||||
} else {
|
||||
ses.send(fmt.Sprintf("+OK Listing %v messages", ses.msgCount))
|
||||
for i, msg := range ses.messages {
|
||||
if ses.retain[i] {
|
||||
ses.send(fmt.Sprintf("%v %v", i+1, msg.Size()))
|
||||
}
|
||||
}
|
||||
ses.send(".")
|
||||
}
|
||||
case "UIDL":
|
||||
if len(args) > 1 {
|
||||
ses.logWarn("UIDL command had more than 1 argument")
|
||||
ses.send("-ERR UIDL command must have zero or one argument")
|
||||
return
|
||||
}
|
||||
if len(args) == 1 {
|
||||
msgNum, err := strconv.ParseInt(args[0], 10, 32)
|
||||
if err != nil {
|
||||
ses.logWarn("UIDL command argument was not an integer")
|
||||
ses.send("-ERR UIDL command requires an integer argument")
|
||||
return
|
||||
}
|
||||
if msgNum < 1 {
|
||||
ses.logWarn("UIDL command argument was less than 1")
|
||||
ses.send("-ERR UIDL argument must be greater than 0")
|
||||
return
|
||||
}
|
||||
if int(msgNum) > len(ses.messages) {
|
||||
ses.logWarn("UIDL command argument was greater than number of messages")
|
||||
ses.send("-ERR UIDL argument must not exceed the number of messages")
|
||||
return
|
||||
}
|
||||
if !ses.retain[msgNum-1] {
|
||||
ses.logWarn("Client tried to UIDL a message it had deleted")
|
||||
ses.send(fmt.Sprintf("-ERR You deleted message %v", msgNum))
|
||||
return
|
||||
}
|
||||
ses.send(fmt.Sprintf("+OK %v %v", msgNum, ses.messages[msgNum-1].ID()))
|
||||
} else {
|
||||
ses.send(fmt.Sprintf("+OK Listing %v messages", ses.msgCount))
|
||||
for i, msg := range ses.messages {
|
||||
if ses.retain[i] {
|
||||
ses.send(fmt.Sprintf("%v %v", i+1, msg.ID()))
|
||||
}
|
||||
}
|
||||
ses.send(".")
|
||||
}
|
||||
case "DELE":
|
||||
if len(args) != 1 {
|
||||
ses.logWarn("DELE command had invalid number of arguments")
|
||||
ses.send("-ERR DELE command requires a single argument")
|
||||
return
|
||||
}
|
||||
msgNum, err := strconv.ParseInt(args[0], 10, 32)
|
||||
if err != nil {
|
||||
ses.logWarn("DELE command argument was not an integer")
|
||||
ses.send("-ERR DELE command requires an integer argument")
|
||||
return
|
||||
}
|
||||
if msgNum < 1 {
|
||||
ses.logWarn("DELE command argument was less than 1")
|
||||
ses.send("-ERR DELE argument must be greater than 0")
|
||||
return
|
||||
}
|
||||
if int(msgNum) > len(ses.messages) {
|
||||
ses.logWarn("DELE command argument was greater than number of messages")
|
||||
ses.send("-ERR DELE argument must not exceed the number of messages")
|
||||
return
|
||||
}
|
||||
if ses.retain[msgNum-1] {
|
||||
ses.retain[msgNum-1] = false
|
||||
ses.msgCount--
|
||||
ses.send(fmt.Sprintf("+OK Deleted message %v", msgNum))
|
||||
} else {
|
||||
ses.logWarn("Client tried to DELE an already deleted message")
|
||||
ses.send(fmt.Sprintf("-ERR Message %v has already been deleted", msgNum))
|
||||
}
|
||||
case "RETR":
|
||||
if len(args) != 1 {
|
||||
ses.logWarn("RETR command had invalid number of arguments")
|
||||
ses.send("-ERR RETR command requires a single argument")
|
||||
return
|
||||
}
|
||||
msgNum, err := strconv.ParseInt(args[0], 10, 32)
|
||||
if err != nil {
|
||||
ses.logWarn("RETR command argument was not an integer")
|
||||
ses.send("-ERR RETR command requires an integer argument")
|
||||
return
|
||||
}
|
||||
if msgNum < 1 {
|
||||
ses.logWarn("RETR command argument was less than 1")
|
||||
ses.send("-ERR RETR argument must be greater than 0")
|
||||
return
|
||||
}
|
||||
if int(msgNum) > len(ses.messages) {
|
||||
ses.logWarn("RETR command argument was greater than number of messages")
|
||||
ses.send("-ERR RETR argument must not exceed the number of messages")
|
||||
return
|
||||
}
|
||||
ses.send(fmt.Sprintf("+OK %v bytes follows", ses.messages[msgNum-1].Size()))
|
||||
ses.sendMessage(ses.messages[msgNum-1])
|
||||
case "TOP":
|
||||
if len(args) != 2 {
|
||||
ses.logWarn("TOP command had invalid number of arguments")
|
||||
ses.send("-ERR TOP command requires two arguments")
|
||||
return
|
||||
}
|
||||
msgNum, err := strconv.ParseInt(args[0], 10, 32)
|
||||
if err != nil {
|
||||
ses.logWarn("TOP command first argument was not an integer")
|
||||
ses.send("-ERR TOP command requires an integer argument")
|
||||
return
|
||||
}
|
||||
if msgNum < 1 {
|
||||
ses.logWarn("TOP command first argument was less than 1")
|
||||
ses.send("-ERR TOP first argument must be greater than 0")
|
||||
return
|
||||
}
|
||||
if int(msgNum) > len(ses.messages) {
|
||||
ses.logWarn("TOP command first argument was greater than number of messages")
|
||||
ses.send("-ERR TOP first argument must not exceed the number of messages")
|
||||
return
|
||||
}
|
||||
|
||||
var lines int64
|
||||
lines, err = strconv.ParseInt(args[1], 10, 32)
|
||||
if err != nil {
|
||||
ses.logWarn("TOP command second argument was not an integer")
|
||||
ses.send("-ERR TOP command requires an integer argument")
|
||||
return
|
||||
}
|
||||
if lines < 0 {
|
||||
ses.logWarn("TOP command second argument was negative")
|
||||
ses.send("-ERR TOP second argument must be non-negative")
|
||||
return
|
||||
}
|
||||
ses.send("+OK Top of message follows")
|
||||
ses.sendMessageTop(ses.messages[msgNum-1], int(lines))
|
||||
case "QUIT":
|
||||
ses.send("+OK We will process your deletes")
|
||||
ses.processDeletes()
|
||||
ses.enterState(QUIT)
|
||||
case "NOOP":
|
||||
ses.send("+OK I have sucessfully done nothing")
|
||||
case "RSET":
|
||||
// Reset session, don't actually delete anything I told you to
|
||||
ses.logTrace("Resetting session state on RSET request")
|
||||
ses.reset()
|
||||
ses.send("+OK Session reset")
|
||||
default:
|
||||
ses.ooSeq(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// Send the contents of the message to the client
|
||||
func (ses *Session) sendMessage(msg datastore.Message) {
|
||||
reader, err := msg.RawReader()
|
||||
if err != nil {
|
||||
ses.logError("Failed to read message for RETR command")
|
||||
ses.send("-ERR Failed to RETR that message, internal error")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := reader.Close(); err != nil {
|
||||
ses.logError("Failed to close message: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
// Lines starting with . must be prefixed with another .
|
||||
if strings.HasPrefix(line, ".") {
|
||||
line = "." + line
|
||||
}
|
||||
ses.send(line)
|
||||
}
|
||||
|
||||
if err = scanner.Err(); err != nil {
|
||||
ses.logError("Failed to read message for RETR command")
|
||||
ses.send(".")
|
||||
ses.send("-ERR Failed to RETR that message, internal error")
|
||||
return
|
||||
}
|
||||
ses.send(".")
|
||||
}
|
||||
|
||||
// Send the headers plus the top N lines to the client
|
||||
func (ses *Session) sendMessageTop(msg datastore.Message, lineCount int) {
|
||||
reader, err := msg.RawReader()
|
||||
if err != nil {
|
||||
ses.logError("Failed to read message for RETR command")
|
||||
ses.send("-ERR Failed to RETR that message, internal error")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := reader.Close(); err != nil {
|
||||
ses.logError("Failed to close message: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
inBody := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
// Lines starting with . must be prefixed with another .
|
||||
if strings.HasPrefix(line, ".") {
|
||||
line = "." + line
|
||||
}
|
||||
if inBody {
|
||||
// Check if we need to send anymore lines
|
||||
if lineCount < 1 {
|
||||
break
|
||||
} else {
|
||||
lineCount--
|
||||
}
|
||||
} else {
|
||||
if line == "" {
|
||||
// We've hit the end of the header
|
||||
inBody = true
|
||||
}
|
||||
}
|
||||
ses.send(line)
|
||||
}
|
||||
|
||||
if err = scanner.Err(); err != nil {
|
||||
ses.logError("Failed to read message for RETR command")
|
||||
ses.send(".")
|
||||
ses.send("-ERR Failed to RETR that message, internal error")
|
||||
return
|
||||
}
|
||||
ses.send(".")
|
||||
}
|
||||
|
||||
// Load the users mailbox
|
||||
func (ses *Session) loadMailbox() {
|
||||
var err error
|
||||
ses.messages, err = ses.mailbox.GetMessages()
|
||||
if err != nil {
|
||||
ses.logError("Failed to load messages for %v", ses.user)
|
||||
}
|
||||
|
||||
ses.retainAll()
|
||||
}
|
||||
|
||||
// Reset retain flag to true for all messages
|
||||
func (ses *Session) retainAll() {
|
||||
ses.retain = make([]bool, len(ses.messages))
|
||||
for i := range ses.retain {
|
||||
ses.retain[i] = true
|
||||
}
|
||||
ses.msgCount = len(ses.messages)
|
||||
}
|
||||
|
||||
// This would be considered the "UPDATE" state in the RFC, but it does not fit
|
||||
// with our state-machine design here, since no commands are accepted - it just
|
||||
// indicates that the session was closed cleanly and that deletes should be
|
||||
// processed.
|
||||
func (ses *Session) processDeletes() {
|
||||
ses.logInfo("Processing deletes")
|
||||
for i, msg := range ses.messages {
|
||||
if !ses.retain[i] {
|
||||
ses.logTrace("Deleting %v", msg)
|
||||
if err := msg.Delete(); err != nil {
|
||||
ses.logWarn("Error deleting %v: %v", msg, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ses *Session) enterState(state State) {
|
||||
ses.state = state
|
||||
ses.logTrace("Entering state %v", state)
|
||||
}
|
||||
|
||||
// Calculate the next read or write deadline based on maxIdleSeconds
|
||||
func (ses *Session) nextDeadline() time.Time {
|
||||
return time.Now().Add(time.Duration(ses.server.maxIdleSeconds) * time.Second)
|
||||
}
|
||||
|
||||
// Send requested message, store errors in Session.sendError
|
||||
func (ses *Session) send(msg string) {
|
||||
if err := ses.conn.SetWriteDeadline(ses.nextDeadline()); err != nil {
|
||||
ses.sendError = err
|
||||
return
|
||||
}
|
||||
if _, err := fmt.Fprint(ses.conn, msg+"\r\n"); err != nil {
|
||||
ses.sendError = err
|
||||
ses.logWarn("Failed to send: '%v'", msg)
|
||||
return
|
||||
}
|
||||
ses.logTrace(">> %v >>", msg)
|
||||
}
|
||||
|
||||
// readByteLine reads a line of input into the provided buffer. Does
|
||||
// not reset the Buffer - please do so prior to calling.
|
||||
func (ses *Session) readByteLine(buf *bytes.Buffer) error {
|
||||
if err := ses.conn.SetReadDeadline(ses.nextDeadline()); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
line, err := ses.reader.ReadBytes('\r')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = buf.Write(line); err != nil {
|
||||
return err
|
||||
}
|
||||
// Read the next byte looking for '\n'
|
||||
c, err := ses.reader.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := buf.WriteByte(c); err != nil {
|
||||
return err
|
||||
}
|
||||
if c == '\n' {
|
||||
// We've reached the end of the line, return
|
||||
return nil
|
||||
}
|
||||
// Else, keep looking
|
||||
}
|
||||
// Should be unreachable
|
||||
}
|
||||
|
||||
// Reads a line of input
|
||||
func (ses *Session) readLine() (line string, err error) {
|
||||
if err = ses.conn.SetReadDeadline(ses.nextDeadline()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
line, err = ses.reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ses.logTrace("<< %v <<", strings.TrimRight(line, "\r\n"))
|
||||
return line, nil
|
||||
}
|
||||
|
||||
func (ses *Session) parseCmd(line string) (cmd string, args []string, ok bool) {
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
return "", nil, true
|
||||
}
|
||||
|
||||
words := strings.Split(line, " ")
|
||||
return strings.ToUpper(words[0]), words[1:], true
|
||||
}
|
||||
|
||||
func (ses *Session) reset() {
|
||||
ses.retainAll()
|
||||
}
|
||||
|
||||
func (ses *Session) ooSeq(cmd string) {
|
||||
ses.send(fmt.Sprintf("-ERR Command %v is out of sequence", cmd))
|
||||
ses.logWarn("Wasn't expecting %v here", cmd)
|
||||
}
|
||||
|
||||
// Session specific logging methods
|
||||
func (ses *Session) logTrace(msg string, args ...interface{}) {
|
||||
log.Tracef("POP3[%v]<%v> %v", ses.remoteHost, ses.id, fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (ses *Session) logInfo(msg string, args ...interface{}) {
|
||||
log.Infof("POP3[%v]<%v> %v", ses.remoteHost, ses.id, fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (ses *Session) logWarn(msg string, args ...interface{}) {
|
||||
// Update metrics
|
||||
//expWarnsTotal.Add(1)
|
||||
log.Warnf("POP3[%v]<%v> %v", ses.remoteHost, ses.id, fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (ses *Session) logError(msg string, args ...interface{}) {
|
||||
// Update metrics
|
||||
//expErrorsTotal.Add(1)
|
||||
log.Errorf("POP3[%v]<%v> %v", ses.remoteHost, ses.id, fmt.Sprintf(msg, args...))
|
||||
}
|
||||
123
pkg/server/pop3/listener.go
Normal file
123
pkg/server/pop3/listener.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package pop3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jhillyerd/inbucket/pkg/config"
|
||||
"github.com/jhillyerd/inbucket/pkg/log"
|
||||
"github.com/jhillyerd/inbucket/pkg/storage"
|
||||
)
|
||||
|
||||
// Server defines an instance of our POP3 server
|
||||
type Server struct {
|
||||
host string
|
||||
domain string
|
||||
maxIdleSeconds int
|
||||
dataStore datastore.DataStore
|
||||
listener net.Listener
|
||||
globalShutdown chan bool
|
||||
waitgroup *sync.WaitGroup
|
||||
}
|
||||
|
||||
// New creates a new Server struct
|
||||
func New(cfg config.POP3Config, shutdownChan chan bool, ds datastore.DataStore) *Server {
|
||||
return &Server{
|
||||
host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port),
|
||||
domain: cfg.Domain,
|
||||
dataStore: ds,
|
||||
maxIdleSeconds: cfg.MaxIdleSeconds,
|
||||
globalShutdown: shutdownChan,
|
||||
waitgroup: new(sync.WaitGroup),
|
||||
}
|
||||
}
|
||||
|
||||
// Start the server and listen for connections
|
||||
func (s *Server) Start(ctx context.Context) {
|
||||
addr, err := net.ResolveTCPAddr("tcp4", s.host)
|
||||
if err != nil {
|
||||
log.Errorf("POP3 Failed to build tcp4 address: %v", err)
|
||||
s.emergencyShutdown()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("POP3 listening on TCP4 %v", addr)
|
||||
s.listener, err = net.ListenTCP("tcp4", addr)
|
||||
if err != nil {
|
||||
log.Errorf("POP3 failed to start tcp4 listener: %v", err)
|
||||
s.emergencyShutdown()
|
||||
return
|
||||
}
|
||||
|
||||
// Listener go routine
|
||||
go s.serve(ctx)
|
||||
|
||||
// Wait for shutdown
|
||||
select {
|
||||
case _ = <-ctx.Done():
|
||||
}
|
||||
|
||||
log.Tracef("POP3 shutdown requested, connections will be drained")
|
||||
// Closing the listener will cause the serve() go routine to exit
|
||||
if err := s.listener.Close(); err != nil {
|
||||
log.Errorf("Error closing POP3 listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// serve is the listen/accept loop
|
||||
func (s *Server) serve(ctx context.Context) {
|
||||
// Handle incoming connections
|
||||
var tempDelay time.Duration
|
||||
for sid := 1; ; sid++ {
|
||||
if conn, err := s.listener.Accept(); err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
// Temporary error, sleep for a bit and try again
|
||||
if tempDelay == 0 {
|
||||
tempDelay = 5 * time.Millisecond
|
||||
} else {
|
||||
tempDelay *= 2
|
||||
}
|
||||
if max := 1 * time.Second; tempDelay > max {
|
||||
tempDelay = max
|
||||
}
|
||||
log.Errorf("POP3 accept error: %v; retrying in %v", err, tempDelay)
|
||||
time.Sleep(tempDelay)
|
||||
continue
|
||||
} else {
|
||||
// Permanent error
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// POP3 is shutting down
|
||||
return
|
||||
default:
|
||||
// Something went wrong
|
||||
s.emergencyShutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tempDelay = 0
|
||||
s.waitgroup.Add(1)
|
||||
go s.startSession(sid, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) emergencyShutdown() {
|
||||
// Shutdown Inbucket
|
||||
select {
|
||||
case _ = <-s.globalShutdown:
|
||||
default:
|
||||
close(s.globalShutdown)
|
||||
}
|
||||
}
|
||||
|
||||
// Drain causes the caller to block until all active POP3 sessions have finished
|
||||
func (s *Server) Drain() {
|
||||
// Wait for sessions to close
|
||||
s.waitgroup.Wait()
|
||||
log.Tracef("POP3 connections have drained")
|
||||
}
|
||||
628
pkg/server/smtp/handler.go
Normal file
628
pkg/server/smtp/handler.go
Normal file
@@ -0,0 +1,628 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"container/list"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jhillyerd/inbucket/pkg/log"
|
||||
"github.com/jhillyerd/inbucket/pkg/msghub"
|
||||
"github.com/jhillyerd/inbucket/pkg/storage"
|
||||
"github.com/jhillyerd/inbucket/pkg/stringutil"
|
||||
)
|
||||
|
||||
// State tracks the current mode of our SMTP state machine
|
||||
type State int
|
||||
|
||||
const (
|
||||
// GREET State: Waiting for HELO
|
||||
GREET State = iota
|
||||
// READY State: Got HELO, waiting for MAIL
|
||||
READY
|
||||
// MAIL State: Got MAIL, accepting RCPTs
|
||||
MAIL
|
||||
// DATA State: Got DATA, waiting for "."
|
||||
DATA
|
||||
// QUIT State: Client requested end of session
|
||||
QUIT
|
||||
)
|
||||
|
||||
const timeStampFormat = "Mon, 02 Jan 2006 15:04:05 -0700 (MST)"
|
||||
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case GREET:
|
||||
return "GREET"
|
||||
case READY:
|
||||
return "READY"
|
||||
case MAIL:
|
||||
return "MAIL"
|
||||
case DATA:
|
||||
return "DATA"
|
||||
case QUIT:
|
||||
return "QUIT"
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
var commands = map[string]bool{
|
||||
"HELO": true,
|
||||
"EHLO": true,
|
||||
"MAIL": true,
|
||||
"RCPT": true,
|
||||
"DATA": true,
|
||||
"RSET": true,
|
||||
"SEND": true,
|
||||
"SOML": true,
|
||||
"SAML": true,
|
||||
"VRFY": true,
|
||||
"EXPN": true,
|
||||
"HELP": true,
|
||||
"NOOP": true,
|
||||
"QUIT": true,
|
||||
"TURN": true,
|
||||
}
|
||||
|
||||
// recipientDetails for message delivery
|
||||
type recipientDetails struct {
|
||||
address, localPart, domainPart string
|
||||
mailbox datastore.Mailbox
|
||||
}
|
||||
|
||||
// Session holds the state of an SMTP session
|
||||
type Session struct {
|
||||
server *Server
|
||||
id int
|
||||
conn net.Conn
|
||||
remoteDomain string
|
||||
remoteHost string
|
||||
sendError error
|
||||
state State
|
||||
reader *bufio.Reader
|
||||
from string
|
||||
recipients *list.List
|
||||
}
|
||||
|
||||
// NewSession creates a new Session for the given connection
|
||||
func NewSession(server *Server, id int, conn net.Conn) *Session {
|
||||
reader := bufio.NewReader(conn)
|
||||
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
return &Session{server: server, id: id, conn: conn, state: GREET, reader: reader, remoteHost: host}
|
||||
}
|
||||
|
||||
func (ss *Session) String() string {
|
||||
return fmt.Sprintf("Session{id: %v, state: %v}", ss.id, ss.state)
|
||||
}
|
||||
|
||||
/* Session flow:
|
||||
* 1. Send initial greeting
|
||||
* 2. Receive cmd
|
||||
* 3. If good cmd, respond, optionally change state
|
||||
* 4. If bad cmd, respond error
|
||||
* 5. Goto 2
|
||||
*/
|
||||
func (s *Server) startSession(id int, conn net.Conn) {
|
||||
log.Infof("SMTP Connection from %v, starting session <%v>", conn.RemoteAddr(), id)
|
||||
expConnectsCurrent.Add(1)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Error closing connection for <%v>: %v", id, err)
|
||||
}
|
||||
s.waitgroup.Done()
|
||||
expConnectsCurrent.Add(-1)
|
||||
}()
|
||||
|
||||
ss := NewSession(s, id, conn)
|
||||
ss.greet()
|
||||
|
||||
// This is our command reading loop
|
||||
for ss.state != QUIT && ss.sendError == nil {
|
||||
if ss.state == DATA {
|
||||
// Special case, does not use SMTP command format
|
||||
ss.dataHandler()
|
||||
continue
|
||||
}
|
||||
line, err := ss.readLine()
|
||||
if err == nil {
|
||||
if cmd, arg, ok := ss.parseCmd(line); ok {
|
||||
// Check against valid SMTP commands
|
||||
if cmd == "" {
|
||||
ss.send("500 Speak up")
|
||||
continue
|
||||
}
|
||||
if !commands[cmd] {
|
||||
ss.send(fmt.Sprintf("500 Syntax error, %v command unrecognized", cmd))
|
||||
ss.logWarn("Unrecognized command: %v", cmd)
|
||||
continue
|
||||
}
|
||||
|
||||
// Commands we handle in any state
|
||||
switch cmd {
|
||||
case "SEND", "SOML", "SAML", "EXPN", "HELP", "TURN":
|
||||
// These commands are not implemented in any state
|
||||
ss.send(fmt.Sprintf("502 %v command not implemented", cmd))
|
||||
ss.logWarn("Command %v not implemented by Inbucket", cmd)
|
||||
continue
|
||||
case "VRFY":
|
||||
ss.send("252 Cannot VRFY user, but will accept message")
|
||||
continue
|
||||
case "NOOP":
|
||||
ss.send("250 I have sucessfully done nothing")
|
||||
continue
|
||||
case "RSET":
|
||||
// Reset session
|
||||
ss.logTrace("Resetting session state on RSET request")
|
||||
ss.reset()
|
||||
ss.send("250 Session reset")
|
||||
continue
|
||||
case "QUIT":
|
||||
ss.send("221 Goodnight and good luck")
|
||||
ss.enterState(QUIT)
|
||||
continue
|
||||
}
|
||||
|
||||
// Send command to handler for current state
|
||||
switch ss.state {
|
||||
case GREET:
|
||||
ss.greetHandler(cmd, arg)
|
||||
continue
|
||||
case READY:
|
||||
ss.readyHandler(cmd, arg)
|
||||
continue
|
||||
case MAIL:
|
||||
ss.mailHandler(cmd, arg)
|
||||
continue
|
||||
}
|
||||
ss.logError("Session entered unexpected state %v", ss.state)
|
||||
break
|
||||
} else {
|
||||
ss.send("500 Syntax error, command garbled")
|
||||
}
|
||||
} else {
|
||||
// readLine() returned an error
|
||||
if err == io.EOF {
|
||||
switch ss.state {
|
||||
case GREET, READY:
|
||||
// EOF is common here
|
||||
ss.logInfo("Client closed connection (state %v)", ss.state)
|
||||
default:
|
||||
ss.logWarn("Got EOF while in state %v", ss.state)
|
||||
}
|
||||
break
|
||||
}
|
||||
// not an EOF
|
||||
ss.logWarn("Connection error: %v", err)
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
if netErr.Timeout() {
|
||||
ss.send("221 Idle timeout, bye bye")
|
||||
break
|
||||
}
|
||||
}
|
||||
ss.send("221 Connection error, sorry")
|
||||
break
|
||||
}
|
||||
}
|
||||
if ss.sendError != nil {
|
||||
ss.logWarn("Network send error: %v", ss.sendError)
|
||||
}
|
||||
ss.logInfo("Closing connection")
|
||||
}
|
||||
|
||||
// GREET state -> waiting for HELO
|
||||
func (ss *Session) greetHandler(cmd string, arg string) {
|
||||
switch cmd {
|
||||
case "HELO":
|
||||
domain, err := parseHelloArgument(arg)
|
||||
if err != nil {
|
||||
ss.send("501 Domain/address argument required for HELO")
|
||||
return
|
||||
}
|
||||
ss.remoteDomain = domain
|
||||
ss.send("250 Great, let's get this show on the road")
|
||||
ss.enterState(READY)
|
||||
case "EHLO":
|
||||
domain, err := parseHelloArgument(arg)
|
||||
if err != nil {
|
||||
ss.send("501 Domain/address argument required for EHLO")
|
||||
return
|
||||
}
|
||||
ss.remoteDomain = domain
|
||||
ss.send("250-Great, let's get this show on the road")
|
||||
ss.send("250-8BITMIME")
|
||||
ss.send(fmt.Sprintf("250 SIZE %v", ss.server.maxMessageBytes))
|
||||
ss.enterState(READY)
|
||||
default:
|
||||
ss.ooSeq(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func parseHelloArgument(arg string) (string, error) {
|
||||
domain := arg
|
||||
if idx := strings.IndexRune(arg, ' '); idx >= 0 {
|
||||
domain = arg[:idx]
|
||||
}
|
||||
if domain == "" {
|
||||
return "", fmt.Errorf("Invalid domain")
|
||||
}
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
// READY state -> waiting for MAIL
|
||||
func (ss *Session) readyHandler(cmd string, arg string) {
|
||||
if cmd == "MAIL" {
|
||||
// Match FROM, while accepting '>' as quoted pair and in double quoted strings
|
||||
// (?i) makes the regex case insensitive, (?:) is non-grouping sub-match
|
||||
re := regexp.MustCompile("(?i)^FROM:\\s*<((?:\\\\>|[^>])+|\"[^\"]+\"@[^>]+)>( [\\w= ]+)?$")
|
||||
m := re.FindStringSubmatch(arg)
|
||||
if m == nil {
|
||||
ss.send("501 Was expecting MAIL arg syntax of FROM:<address>")
|
||||
ss.logWarn("Bad MAIL argument: %q", arg)
|
||||
return
|
||||
}
|
||||
from := m[1]
|
||||
if _, _, err := stringutil.ParseEmailAddress(from); err != nil {
|
||||
ss.send("501 Bad sender address syntax")
|
||||
ss.logWarn("Bad address as MAIL arg: %q, %s", from, err)
|
||||
return
|
||||
}
|
||||
// This is where the client may put BODY=8BITMIME, but we already
|
||||
// read the DATA as bytes, so it does not effect our processing.
|
||||
if m[2] != "" {
|
||||
args, ok := ss.parseArgs(m[2])
|
||||
if !ok {
|
||||
ss.send("501 Unable to parse MAIL ESMTP parameters")
|
||||
ss.logWarn("Bad MAIL argument: %q", arg)
|
||||
return
|
||||
}
|
||||
if args["SIZE"] != "" {
|
||||
size, err := strconv.ParseInt(args["SIZE"], 10, 32)
|
||||
if err != nil {
|
||||
ss.send("501 Unable to parse SIZE as an integer")
|
||||
ss.logWarn("Unable to parse SIZE %q as an integer", args["SIZE"])
|
||||
return
|
||||
}
|
||||
if int(size) > ss.server.maxMessageBytes {
|
||||
ss.send("552 Max message size exceeded")
|
||||
ss.logWarn("Client wanted to send oversized message: %v", args["SIZE"])
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
ss.from = from
|
||||
ss.recipients = list.New()
|
||||
ss.logInfo("Mail from: %v", from)
|
||||
ss.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from))
|
||||
ss.enterState(MAIL)
|
||||
} else {
|
||||
ss.ooSeq(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// MAIL state -> waiting for RCPTs followed by DATA
|
||||
func (ss *Session) mailHandler(cmd string, arg string) {
|
||||
switch cmd {
|
||||
case "RCPT":
|
||||
if (len(arg) < 4) || (strings.ToUpper(arg[0:3]) != "TO:") {
|
||||
ss.send("501 Was expecting RCPT arg syntax of TO:<address>")
|
||||
ss.logWarn("Bad RCPT argument: %q", arg)
|
||||
return
|
||||
}
|
||||
// This trim is probably too forgiving
|
||||
recip := strings.Trim(arg[3:], "<> ")
|
||||
if _, _, err := stringutil.ParseEmailAddress(recip); err != nil {
|
||||
ss.send("501 Bad recipient address syntax")
|
||||
ss.logWarn("Bad address as RCPT arg: %q, %s", recip, err)
|
||||
return
|
||||
}
|
||||
if ss.recipients.Len() >= ss.server.maxRecips {
|
||||
ss.logWarn("Maximum limit of %v recipients reached", ss.server.maxRecips)
|
||||
ss.send(fmt.Sprintf("552 Maximum limit of %v recipients reached", ss.server.maxRecips))
|
||||
return
|
||||
}
|
||||
ss.recipients.PushBack(recip)
|
||||
ss.logInfo("Recipient: %v", recip)
|
||||
ss.send(fmt.Sprintf("250 I'll make sure <%v> gets this", recip))
|
||||
return
|
||||
case "DATA":
|
||||
if arg != "" {
|
||||
ss.send("501 DATA command should not have any arguments")
|
||||
ss.logWarn("Got unexpected args on DATA: %q", arg)
|
||||
return
|
||||
}
|
||||
if ss.recipients.Len() > 0 {
|
||||
// We have recipients, go to accept data
|
||||
ss.enterState(DATA)
|
||||
return
|
||||
}
|
||||
// DATA out of sequence
|
||||
ss.ooSeq(cmd)
|
||||
return
|
||||
}
|
||||
ss.ooSeq(cmd)
|
||||
}
|
||||
|
||||
// DATA
|
||||
func (ss *Session) dataHandler() {
|
||||
recipients := make([]recipientDetails, 0, ss.recipients.Len())
|
||||
// Get a Mailbox and a new Message for each recipient
|
||||
msgSize := 0
|
||||
if ss.server.storeMessages {
|
||||
for e := ss.recipients.Front(); e != nil; e = e.Next() {
|
||||
recip := e.Value.(string)
|
||||
local, domain, err := stringutil.ParseEmailAddress(recip)
|
||||
if err != nil {
|
||||
ss.logError("Failed to parse address for %q", recip)
|
||||
ss.send(fmt.Sprintf("451 Failed to open mailbox for %v", recip))
|
||||
ss.reset()
|
||||
return
|
||||
}
|
||||
if strings.ToLower(domain) != ss.server.domainNoStore {
|
||||
// Not our "no store" domain, so store the message
|
||||
mb, err := ss.server.dataStore.MailboxFor(local)
|
||||
if err != nil {
|
||||
ss.logError("Failed to open mailbox for %q: %s", local, err)
|
||||
ss.send(fmt.Sprintf("451 Failed to open mailbox for %v", local))
|
||||
ss.reset()
|
||||
return
|
||||
}
|
||||
recipients = append(recipients, recipientDetails{recip, local, domain, mb})
|
||||
} else {
|
||||
log.Tracef("Not storing message for %q", recip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ss.send("354 Start mail input; end with <CRLF>.<CRLF>")
|
||||
var lineBuf bytes.Buffer
|
||||
msgBuf := make([][]byte, 0, 1024)
|
||||
for {
|
||||
lineBuf.Reset()
|
||||
err := ss.readByteLine(&lineBuf)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
if netErr.Timeout() {
|
||||
ss.send("221 Idle timeout, bye bye")
|
||||
}
|
||||
}
|
||||
ss.logWarn("Error: %v while reading", err)
|
||||
ss.enterState(QUIT)
|
||||
return
|
||||
}
|
||||
line := lineBuf.Bytes()
|
||||
// ss.logTrace("DATA: %q", line)
|
||||
if string(line) == ".\r\n" || string(line) == ".\n" {
|
||||
// Mail data complete
|
||||
if ss.server.storeMessages {
|
||||
// Create a message for each valid recipient
|
||||
for _, r := range recipients {
|
||||
// TODO temporary hack to fix #77 until datastore revamp
|
||||
mu, err := ss.server.dataStore.LockFor(r.localPart)
|
||||
if err != nil {
|
||||
ss.logError("Failed to get lock for %q: %s", r.localPart, err)
|
||||
// Delivery failure
|
||||
ss.send(fmt.Sprintf("451 Failed to store message for %v", r.localPart))
|
||||
ss.reset()
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
ok := ss.deliverMessage(r, msgBuf)
|
||||
mu.Unlock()
|
||||
if ok {
|
||||
expReceivedTotal.Add(1)
|
||||
} else {
|
||||
// Delivery failure
|
||||
ss.send(fmt.Sprintf("451 Failed to store message for %v", r.localPart))
|
||||
ss.reset()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
expReceivedTotal.Add(1)
|
||||
}
|
||||
ss.send("250 Mail accepted for delivery")
|
||||
ss.logInfo("Message size %v bytes", msgSize)
|
||||
ss.reset()
|
||||
return
|
||||
}
|
||||
// SMTP RFC says remove leading periods from input
|
||||
if len(line) > 0 && line[0] == '.' {
|
||||
line = line[1:]
|
||||
}
|
||||
// Second append copies line/lineBuf so we can reuse it
|
||||
msgBuf = append(msgBuf, append([]byte{}, line...))
|
||||
msgSize += len(line)
|
||||
if msgSize > ss.server.maxMessageBytes {
|
||||
// Max message size exceeded
|
||||
ss.send("552 Maximum message size exceeded")
|
||||
ss.logWarn("Max message size exceeded while in DATA")
|
||||
ss.reset()
|
||||
// Should really cleanup the crap on filesystem (after issue #23)
|
||||
return
|
||||
}
|
||||
} // end for
|
||||
}
|
||||
|
||||
// deliverMessage creates and populates a new Message for the specified recipient
|
||||
func (ss *Session) deliverMessage(r recipientDetails, msgBuf [][]byte) (ok bool) {
|
||||
msg, err := r.mailbox.NewMessage()
|
||||
if err != nil {
|
||||
ss.logError("Failed to create message for %q: %s", r.localPart, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Generate Received header
|
||||
stamp := time.Now().Format(timeStampFormat)
|
||||
recd := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n",
|
||||
ss.remoteDomain, ss.remoteHost, ss.server.domain, r.address, stamp)
|
||||
if err := msg.Append([]byte(recd)); err != nil {
|
||||
ss.logError("Failed to write received header for %q: %s", r.localPart, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Append lines from msgBuf
|
||||
for _, line := range msgBuf {
|
||||
if err := msg.Append(line); err != nil {
|
||||
ss.logError("Failed to append to mailbox %v: %v", r.mailbox, err)
|
||||
// Should really cleanup the crap on filesystem
|
||||
return false
|
||||
}
|
||||
}
|
||||
if err := msg.Close(); err != nil {
|
||||
ss.logError("Error while closing message for %v: %v", r.mailbox, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Broadcast message information
|
||||
broadcast := msghub.Message{
|
||||
Mailbox: r.mailbox.Name(),
|
||||
ID: msg.ID(),
|
||||
From: msg.From(),
|
||||
To: msg.To(),
|
||||
Subject: msg.Subject(),
|
||||
Date: msg.Date(),
|
||||
Size: msg.Size(),
|
||||
}
|
||||
ss.server.msgHub.Dispatch(broadcast)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (ss *Session) enterState(state State) {
|
||||
ss.state = state
|
||||
ss.logTrace("Entering state %v", state)
|
||||
}
|
||||
|
||||
func (ss *Session) greet() {
|
||||
ss.send(fmt.Sprintf("220 %v Inbucket SMTP ready", ss.server.domain))
|
||||
}
|
||||
|
||||
// Calculate the next read or write deadline based on maxIdleSeconds
|
||||
func (ss *Session) nextDeadline() time.Time {
|
||||
return time.Now().Add(time.Duration(ss.server.maxIdleSeconds) * time.Second)
|
||||
}
|
||||
|
||||
// Send requested message, store errors in Session.sendError
|
||||
func (ss *Session) send(msg string) {
|
||||
if err := ss.conn.SetWriteDeadline(ss.nextDeadline()); err != nil {
|
||||
ss.sendError = err
|
||||
return
|
||||
}
|
||||
if _, err := fmt.Fprint(ss.conn, msg+"\r\n"); err != nil {
|
||||
ss.sendError = err
|
||||
ss.logWarn("Failed to send: %q", msg)
|
||||
return
|
||||
}
|
||||
ss.logTrace(">> %v >>", msg)
|
||||
}
|
||||
|
||||
// readByteLine reads a line of input into the provided buffer. Does
|
||||
// not reset the Buffer - please do so prior to calling.
|
||||
func (ss *Session) readByteLine(buf io.Writer) error {
|
||||
if err := ss.conn.SetReadDeadline(ss.nextDeadline()); err != nil {
|
||||
return err
|
||||
}
|
||||
line, err := ss.reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buf.Write(line)
|
||||
return err
|
||||
}
|
||||
|
||||
// Reads a line of input
|
||||
func (ss *Session) readLine() (line string, err error) {
|
||||
if err = ss.conn.SetReadDeadline(ss.nextDeadline()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
line, err = ss.reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ss.logTrace("<< %v <<", strings.TrimRight(line, "\r\n"))
|
||||
return line, nil
|
||||
}
|
||||
|
||||
func (ss *Session) parseCmd(line string) (cmd string, arg string, ok bool) {
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
l := len(line)
|
||||
switch {
|
||||
case l == 0:
|
||||
return "", "", true
|
||||
case l < 4:
|
||||
ss.logWarn("Command too short: %q", line)
|
||||
return "", "", false
|
||||
case l == 4:
|
||||
return strings.ToUpper(line), "", true
|
||||
case l == 5:
|
||||
// Too long to be only command, too short to have args
|
||||
ss.logWarn("Mangled command: %q", line)
|
||||
return "", "", false
|
||||
}
|
||||
// If we made it here, command is long enough to have args
|
||||
if line[4] != ' ' {
|
||||
// There wasn't a space after the command?
|
||||
ss.logWarn("Mangled command: %q", line)
|
||||
return "", "", false
|
||||
}
|
||||
// I'm not sure if we should trim the args or not, but we will for now
|
||||
return strings.ToUpper(line[0:4]), strings.Trim(line[5:], " "), true
|
||||
}
|
||||
|
||||
// parseArgs takes the arguments proceeding a command and files them
|
||||
// into a map[string]string after uppercasing each key. Sample arg
|
||||
// string:
|
||||
// " BODY=8BITMIME SIZE=1024"
|
||||
// The leading space is mandatory.
|
||||
func (ss *Session) parseArgs(arg string) (args map[string]string, ok bool) {
|
||||
args = make(map[string]string)
|
||||
re := regexp.MustCompile(` (\w+)=(\w+)`)
|
||||
pm := re.FindAllStringSubmatch(arg, -1)
|
||||
if pm == nil {
|
||||
ss.logWarn("Failed to parse arg string: %q")
|
||||
return nil, false
|
||||
}
|
||||
for _, m := range pm {
|
||||
args[strings.ToUpper(m[1])] = m[2]
|
||||
}
|
||||
ss.logTrace("ESMTP params: %v", args)
|
||||
return args, true
|
||||
}
|
||||
|
||||
func (ss *Session) reset() {
|
||||
ss.enterState(READY)
|
||||
ss.from = ""
|
||||
ss.recipients = nil
|
||||
}
|
||||
|
||||
func (ss *Session) ooSeq(cmd string) {
|
||||
ss.send(fmt.Sprintf("503 Command %v is out of sequence", cmd))
|
||||
ss.logWarn("Wasn't expecting %v here", cmd)
|
||||
}
|
||||
|
||||
// Session specific logging methods
|
||||
func (ss *Session) logTrace(msg string, args ...interface{}) {
|
||||
log.Tracef("SMTP[%v]<%v> %v", ss.remoteHost, ss.id, fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (ss *Session) logInfo(msg string, args ...interface{}) {
|
||||
log.Infof("SMTP[%v]<%v> %v", ss.remoteHost, ss.id, fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (ss *Session) logWarn(msg string, args ...interface{}) {
|
||||
// Update metrics
|
||||
expWarnsTotal.Add(1)
|
||||
log.Warnf("SMTP[%v]<%v> %v", ss.remoteHost, ss.id, fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (ss *Session) logError(msg string, args ...interface{}) {
|
||||
// Update metrics
|
||||
expErrorsTotal.Add(1)
|
||||
log.Errorf("SMTP[%v]<%v> %v", ss.remoteHost, ss.id, fmt.Sprintf(msg, args...))
|
||||
}
|
||||
409
pkg/server/smtp/handler_test.go
Normal file
409
pkg/server/smtp/handler_test.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"log"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jhillyerd/inbucket/pkg/config"
|
||||
"github.com/jhillyerd/inbucket/pkg/msghub"
|
||||
"github.com/jhillyerd/inbucket/pkg/storage"
|
||||
)
|
||||
|
||||
type scriptStep struct {
|
||||
send string
|
||||
expect int
|
||||
}
|
||||
|
||||
// Test commands in GREET state
|
||||
func TestGreetState(t *testing.T) {
|
||||
// Setup mock objects
|
||||
mds := &datastore.MockDataStore{}
|
||||
|
||||
server, logbuf, teardown := setupSMTPServer(mds)
|
||||
defer teardown()
|
||||
|
||||
// Test out some mangled HELOs
|
||||
script := []scriptStep{
|
||||
{"HELO", 501},
|
||||
{"EHLO", 501},
|
||||
{"HELLO", 500},
|
||||
{"HELL", 500},
|
||||
{"hello", 500},
|
||||
{"Outlook", 500},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Valid HELOs
|
||||
if err := playSession(t, server, []scriptStep{{"HELO mydomain", 250}}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := playSession(t, server, []scriptStep{{"HELO mydom.com", 250}}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := playSession(t, server, []scriptStep{{"HelO mydom.com", 250}}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := playSession(t, server, []scriptStep{{"helo 127.0.0.1", 250}}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Valid EHLOs
|
||||
if err := playSession(t, server, []scriptStep{{"EHLO mydomain", 250}}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := playSession(t, server, []scriptStep{{"EHLO mydom.com", 250}}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := playSession(t, server, []scriptStep{{"EhlO mydom.com", 250}}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := playSession(t, server, []scriptStep{{"ehlo 127.0.0.1", 250}}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if t.Failed() {
|
||||
// Wait for handler to finish logging
|
||||
time.Sleep(2 * time.Second)
|
||||
// Dump buffered log data if there was a failure
|
||||
_, _ = io.Copy(os.Stderr, logbuf)
|
||||
}
|
||||
}
|
||||
|
||||
// Test commands in READY state
|
||||
func TestReadyState(t *testing.T) {
|
||||
// Setup mock objects
|
||||
mds := &datastore.MockDataStore{}
|
||||
|
||||
server, logbuf, teardown := setupSMTPServer(mds)
|
||||
defer teardown()
|
||||
|
||||
// Test out some mangled READY commands
|
||||
script := []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"FOOB", 500},
|
||||
{"HELO", 503},
|
||||
{"DATA", 503},
|
||||
{"MAIL", 501},
|
||||
{"MAIL FROM john@gmail.com", 501},
|
||||
{"MAIL FROM:john@gmail.com", 501},
|
||||
{"MAIL FROM:<john@gmail.com> SIZE=147KB", 501},
|
||||
{"MAIL FROM: <john@gmail.com> SIZE147", 501},
|
||||
{"MAIL FROM:<first@last@gmail.com>", 501},
|
||||
{"MAIL FROM:<first last@gmail.com>", 501},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Test out some valid MAIL commands
|
||||
script = []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM: <john@gmail.com>", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM: <john@gmail.com> BODY=8BITMIME", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<john@gmail.com> SIZE=1024", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<host!host!user/data@foo.com>", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<\"first last\"@space.com>", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<user\\@internal@external.com>", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<user\\>name@host.com>", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<\"user>name\"@host.com>", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<\"user@internal\"@external.com>", 250},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if t.Failed() {
|
||||
// Wait for handler to finish logging
|
||||
time.Sleep(2 * time.Second)
|
||||
// Dump buffered log data if there was a failure
|
||||
_, _ = io.Copy(os.Stderr, logbuf)
|
||||
}
|
||||
}
|
||||
|
||||
// Test commands in MAIL state
|
||||
func TestMailState(t *testing.T) {
|
||||
// Setup mock objects
|
||||
mds := &datastore.MockDataStore{}
|
||||
mb1 := &datastore.MockMailbox{}
|
||||
msg1 := &datastore.MockMessage{}
|
||||
mds.On("MailboxFor", "u1").Return(mb1, nil)
|
||||
mb1.On("NewMessage").Return(msg1, nil)
|
||||
mb1.On("Name").Return("u1")
|
||||
msg1.On("ID").Return("")
|
||||
msg1.On("From").Return("")
|
||||
msg1.On("To").Return(make([]string, 0))
|
||||
msg1.On("Date").Return(time.Time{})
|
||||
msg1.On("Subject").Return("")
|
||||
msg1.On("Size").Return(0)
|
||||
msg1.On("Close").Return(nil)
|
||||
|
||||
server, logbuf, teardown := setupSMTPServer(mds)
|
||||
defer teardown()
|
||||
|
||||
// Test out some mangled READY commands
|
||||
script := []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"FOOB", 500},
|
||||
{"HELO", 503},
|
||||
{"DATA", 503},
|
||||
{"MAIL", 503},
|
||||
{"RCPT", 501},
|
||||
{"RCPT TO", 501},
|
||||
{"RCPT TO james@gmail.com", 501},
|
||||
{"RCPT TO:<first last@host.com>", 501},
|
||||
{"RCPT TO:<fred@fish@host.com", 501},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Test out some good RCPT commands
|
||||
script = []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"RCPT TO:<u1@gmail.com>", 250},
|
||||
{"RCPT TO: <u2@gmail.com>", 250},
|
||||
{"RCPT TO:u3@gmail.com", 250},
|
||||
{"RCPT TO: u4@gmail.com", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"RCPT TO:<user\\@internal@external.com", 250},
|
||||
{"RCPT TO:<\"first last\"@host.com", 250},
|
||||
{"RCPT TO:<user\\>name@host.com>", 250},
|
||||
{"RCPT TO:<\"user>name\"@host.com>", 250},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Test out recipient limit
|
||||
script = []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"RCPT TO:<u1@gmail.com>", 250},
|
||||
{"RCPT TO:<u2@gmail.com>", 250},
|
||||
{"RCPT TO:<u3@gmail.com>", 250},
|
||||
{"RCPT TO:<u4@gmail.com>", 250},
|
||||
{"RCPT TO:<u5@gmail.com>", 250},
|
||||
{"RCPT TO:<u6@gmail.com>", 552},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Test DATA
|
||||
script = []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"RCPT TO:<u1@gmail.com>", 250},
|
||||
{"DATA", 354},
|
||||
{".", 250},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Test RSET
|
||||
script = []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"RCPT TO:<u1@gmail.com>", 250},
|
||||
{"RSET", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Test QUIT
|
||||
script = []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"RCPT TO:<u1@gmail.com>", 250},
|
||||
{"QUIT", 221},
|
||||
}
|
||||
if err := playSession(t, server, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if t.Failed() {
|
||||
// Wait for handler to finish logging
|
||||
time.Sleep(2 * time.Second)
|
||||
// Dump buffered log data if there was a failure
|
||||
_, _ = io.Copy(os.Stderr, logbuf)
|
||||
}
|
||||
}
|
||||
|
||||
// Test commands in DATA state
|
||||
func TestDataState(t *testing.T) {
|
||||
// Setup mock objects
|
||||
mds := &datastore.MockDataStore{}
|
||||
mb1 := &datastore.MockMailbox{}
|
||||
msg1 := &datastore.MockMessage{}
|
||||
mds.On("MailboxFor", "u1").Return(mb1, nil)
|
||||
mb1.On("NewMessage").Return(msg1, nil)
|
||||
mb1.On("Name").Return("u1")
|
||||
msg1.On("ID").Return("")
|
||||
msg1.On("From").Return("")
|
||||
msg1.On("To").Return(make([]string, 0))
|
||||
msg1.On("Date").Return(time.Time{})
|
||||
msg1.On("Subject").Return("")
|
||||
msg1.On("Size").Return(0)
|
||||
msg1.On("Close").Return(nil)
|
||||
|
||||
server, logbuf, teardown := setupSMTPServer(mds)
|
||||
defer teardown()
|
||||
|
||||
var script []scriptStep
|
||||
pipe := setupSMTPSession(server)
|
||||
c := textproto.NewConn(pipe)
|
||||
|
||||
// Get us into DATA state
|
||||
if code, _, err := c.ReadCodeLine(220); err != nil {
|
||||
t.Errorf("Expected a 220 greeting, got %v", code)
|
||||
}
|
||||
script = []scriptStep{
|
||||
{"HELO localhost", 250},
|
||||
{"MAIL FROM:<john@gmail.com>", 250},
|
||||
{"RCPT TO:<u1@gmail.com>", 250},
|
||||
{"DATA", 354},
|
||||
}
|
||||
if err := playScriptAgainst(t, c, script); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// Send a message
|
||||
body := `To: u1@gmail.com
|
||||
From: john@gmail.com
|
||||
Subject: test
|
||||
|
||||
Hi!
|
||||
`
|
||||
dw := c.DotWriter()
|
||||
_, _ = io.WriteString(dw, body)
|
||||
_ = dw.Close()
|
||||
if code, _, err := c.ReadCodeLine(250); err != nil {
|
||||
t.Errorf("Expected a 250 greeting, got %v", code)
|
||||
}
|
||||
|
||||
if t.Failed() {
|
||||
// Wait for handler to finish logging
|
||||
time.Sleep(2 * time.Second)
|
||||
// Dump buffered log data if there was a failure
|
||||
_, _ = io.Copy(os.Stderr, logbuf)
|
||||
}
|
||||
}
|
||||
|
||||
// playSession creates a new session, reads the greeting and then plays the script
|
||||
func playSession(t *testing.T, server *Server, script []scriptStep) error {
|
||||
pipe := setupSMTPSession(server)
|
||||
c := textproto.NewConn(pipe)
|
||||
|
||||
if code, _, err := c.ReadCodeLine(220); err != nil {
|
||||
return fmt.Errorf("Expected a 220 greeting, got %v", code)
|
||||
}
|
||||
|
||||
err := playScriptAgainst(t, c, script)
|
||||
|
||||
// Not all tests leave the session in a clean state, so the following two
|
||||
// calls can fail
|
||||
_, _ = c.Cmd("QUIT")
|
||||
_, _, _ = c.ReadCodeLine(221)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// playScriptAgainst an existing connection, does not handle server greeting
|
||||
func playScriptAgainst(t *testing.T, c *textproto.Conn, script []scriptStep) error {
|
||||
for i, step := range script {
|
||||
id, err := c.Cmd(step.send)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Step %d, failed to send %q: %v", i, step.send, err)
|
||||
}
|
||||
|
||||
c.StartResponse(id)
|
||||
code, msg, err := c.ReadResponse(step.expect)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Step %d, sent %q, expected %v, got %v: %q",
|
||||
i, step.send, step.expect, code, msg)
|
||||
}
|
||||
c.EndResponse(id)
|
||||
|
||||
if err != nil {
|
||||
// Return after c.EndResponse so we don't hang the connection
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// net.Pipe does not implement deadlines
|
||||
type mockConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (m *mockConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func setupSMTPServer(ds datastore.DataStore) (s *Server, buf *bytes.Buffer, teardown func()) {
|
||||
// Test Server Config
|
||||
cfg := config.SMTPConfig{
|
||||
IP4address: net.IPv4(127, 0, 0, 1),
|
||||
IP4port: 2500,
|
||||
Domain: "inbucket.local",
|
||||
DomainNoStore: "bitbucket.local",
|
||||
MaxRecipients: 5,
|
||||
MaxIdleSeconds: 5,
|
||||
MaxMessageBytes: 5000,
|
||||
StoreMessages: true,
|
||||
}
|
||||
|
||||
// Capture log output
|
||||
buf = new(bytes.Buffer)
|
||||
log.SetOutput(buf)
|
||||
|
||||
// Create a server, don't start it
|
||||
shutdownChan := make(chan bool)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
teardown = func() {
|
||||
close(shutdownChan)
|
||||
cancel()
|
||||
}
|
||||
s = NewServer(cfg, shutdownChan, ds, msghub.New(ctx, 100))
|
||||
return s, buf, teardown
|
||||
}
|
||||
|
||||
var sessionNum int
|
||||
|
||||
func setupSMTPSession(server *Server) net.Conn {
|
||||
// Pair of pipes to communicate
|
||||
serverConn, clientConn := net.Pipe()
|
||||
// Start the session
|
||||
server.waitgroup.Add(1)
|
||||
sessionNum++
|
||||
go server.startSession(sessionNum, &mockConn{serverConn})
|
||||
|
||||
return clientConn
|
||||
}
|
||||
199
pkg/server/smtp/listener.go
Normal file
199
pkg/server/smtp/listener.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jhillyerd/inbucket/pkg/config"
|
||||
"github.com/jhillyerd/inbucket/pkg/log"
|
||||
"github.com/jhillyerd/inbucket/pkg/msghub"
|
||||
"github.com/jhillyerd/inbucket/pkg/storage"
|
||||
)
|
||||
|
||||
func init() {
|
||||
m := expvar.NewMap("smtp")
|
||||
m.Set("ConnectsTotal", expConnectsTotal)
|
||||
m.Set("ConnectsHist", expConnectsHist)
|
||||
m.Set("ConnectsCurrent", expConnectsCurrent)
|
||||
m.Set("ReceivedTotal", expReceivedTotal)
|
||||
m.Set("ReceivedHist", expReceivedHist)
|
||||
m.Set("ErrorsTotal", expErrorsTotal)
|
||||
m.Set("ErrorsHist", expErrorsHist)
|
||||
m.Set("WarnsTotal", expWarnsTotal)
|
||||
m.Set("WarnsHist", expWarnsHist)
|
||||
|
||||
log.AddTickerFunc(func() {
|
||||
expReceivedHist.Set(log.PushMetric(deliveredHist, expReceivedTotal))
|
||||
expConnectsHist.Set(log.PushMetric(connectsHist, expConnectsTotal))
|
||||
expErrorsHist.Set(log.PushMetric(errorsHist, expErrorsTotal))
|
||||
expWarnsHist.Set(log.PushMetric(warnsHist, expWarnsTotal))
|
||||
})
|
||||
}
|
||||
|
||||
// Server holds the configuration and state of our SMTP server
|
||||
type Server struct {
|
||||
// Configuration
|
||||
host string
|
||||
domain string
|
||||
domainNoStore string
|
||||
maxRecips int
|
||||
maxIdleSeconds int
|
||||
maxMessageBytes int
|
||||
storeMessages bool
|
||||
|
||||
// Dependencies
|
||||
dataStore datastore.DataStore // Mailbox/message store
|
||||
globalShutdown chan bool // Shuts down Inbucket
|
||||
msgHub *msghub.Hub // Pub/sub for message info
|
||||
retentionScanner *datastore.RetentionScanner // Deletes expired messages
|
||||
|
||||
// State
|
||||
listener net.Listener // Incoming network connections
|
||||
waitgroup *sync.WaitGroup // Waitgroup tracks individual sessions
|
||||
}
|
||||
|
||||
var (
|
||||
// Raw stat collectors
|
||||
expConnectsTotal = new(expvar.Int)
|
||||
expConnectsCurrent = new(expvar.Int)
|
||||
expReceivedTotal = new(expvar.Int)
|
||||
expErrorsTotal = new(expvar.Int)
|
||||
expWarnsTotal = new(expvar.Int)
|
||||
|
||||
// History of certain stats
|
||||
deliveredHist = list.New()
|
||||
connectsHist = list.New()
|
||||
errorsHist = list.New()
|
||||
warnsHist = list.New()
|
||||
|
||||
// History rendered as comma delim string
|
||||
expReceivedHist = new(expvar.String)
|
||||
expConnectsHist = new(expvar.String)
|
||||
expErrorsHist = new(expvar.String)
|
||||
expWarnsHist = new(expvar.String)
|
||||
)
|
||||
|
||||
// NewServer creates a new Server instance with the specificed config
|
||||
func NewServer(
|
||||
cfg config.SMTPConfig,
|
||||
globalShutdown chan bool,
|
||||
ds datastore.DataStore,
|
||||
msgHub *msghub.Hub) *Server {
|
||||
return &Server{
|
||||
host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port),
|
||||
domain: cfg.Domain,
|
||||
domainNoStore: strings.ToLower(cfg.DomainNoStore),
|
||||
maxRecips: cfg.MaxRecipients,
|
||||
maxIdleSeconds: cfg.MaxIdleSeconds,
|
||||
maxMessageBytes: cfg.MaxMessageBytes,
|
||||
storeMessages: cfg.StoreMessages,
|
||||
globalShutdown: globalShutdown,
|
||||
dataStore: ds,
|
||||
msgHub: msgHub,
|
||||
retentionScanner: datastore.NewRetentionScanner(ds, globalShutdown),
|
||||
waitgroup: new(sync.WaitGroup),
|
||||
}
|
||||
}
|
||||
|
||||
// Start the listener and handle incoming connections
|
||||
func (s *Server) Start(ctx context.Context) {
|
||||
addr, err := net.ResolveTCPAddr("tcp4", s.host)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to build tcp4 address: %v", err)
|
||||
s.emergencyShutdown()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("SMTP listening on TCP4 %v", addr)
|
||||
s.listener, err = net.ListenTCP("tcp4", addr)
|
||||
if err != nil {
|
||||
log.Errorf("SMTP failed to start tcp4 listener: %v", err)
|
||||
s.emergencyShutdown()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.storeMessages {
|
||||
log.Infof("Load test mode active, messages will not be stored")
|
||||
} else if s.domainNoStore != "" {
|
||||
log.Infof("Messages sent to domain '%v' will be discarded", s.domainNoStore)
|
||||
}
|
||||
|
||||
// Start retention scanner
|
||||
s.retentionScanner.Start()
|
||||
|
||||
// Listener go routine
|
||||
go s.serve(ctx)
|
||||
|
||||
// Wait for shutdown
|
||||
<-ctx.Done()
|
||||
log.Tracef("SMTP shutdown requested, connections will be drained")
|
||||
|
||||
// Closing the listener will cause the serve() go routine to exit
|
||||
if err := s.listener.Close(); err != nil {
|
||||
log.Errorf("Failed to close SMTP listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// serve is the listen/accept loop
|
||||
func (s *Server) serve(ctx context.Context) {
|
||||
// Handle incoming connections
|
||||
var tempDelay time.Duration
|
||||
for sessionID := 1; ; sessionID++ {
|
||||
if conn, err := s.listener.Accept(); err != nil {
|
||||
// There was an error accepting the connection
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
// Temporary error, sleep for a bit and try again
|
||||
if tempDelay == 0 {
|
||||
tempDelay = 5 * time.Millisecond
|
||||
} else {
|
||||
tempDelay *= 2
|
||||
}
|
||||
if max := 1 * time.Second; tempDelay > max {
|
||||
tempDelay = max
|
||||
}
|
||||
log.Errorf("SMTP accept error: %v; retrying in %v", err, tempDelay)
|
||||
time.Sleep(tempDelay)
|
||||
continue
|
||||
} else {
|
||||
// Permanent error
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// SMTP is shutting down
|
||||
return
|
||||
default:
|
||||
// Something went wrong
|
||||
s.emergencyShutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tempDelay = 0
|
||||
expConnectsTotal.Add(1)
|
||||
s.waitgroup.Add(1)
|
||||
go s.startSession(sessionID, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) emergencyShutdown() {
|
||||
// Shutdown Inbucket
|
||||
select {
|
||||
case <-s.globalShutdown:
|
||||
default:
|
||||
close(s.globalShutdown)
|
||||
}
|
||||
}
|
||||
|
||||
// Drain causes the caller to block until all active SMTP sessions have finished
|
||||
func (s *Server) Drain() {
|
||||
// Wait for sessions to close
|
||||
s.waitgroup.Wait()
|
||||
log.Tracef("SMTP connections have drained")
|
||||
s.retentionScanner.Join()
|
||||
}
|
||||
68
pkg/server/web/context.go
Normal file
68
pkg/server/web/context.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jhillyerd/inbucket/pkg/config"
|
||||
"github.com/jhillyerd/inbucket/pkg/msghub"
|
||||
"github.com/jhillyerd/inbucket/pkg/storage"
|
||||
)
|
||||
|
||||
// Context is passed into every request handler function
|
||||
type Context struct {
|
||||
Vars map[string]string
|
||||
Session *sessions.Session
|
||||
DataStore datastore.DataStore
|
||||
MsgHub *msghub.Hub
|
||||
WebConfig config.WebConfig
|
||||
IsJSON bool
|
||||
}
|
||||
|
||||
// Close the Context (currently does nothing)
|
||||
func (c *Context) Close() {
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
// headerMatch returns true if the request header specified by name contains
|
||||
// the specified value. Case is ignored.
|
||||
func headerMatch(req *http.Request, name string, value string) bool {
|
||||
name = http.CanonicalHeaderKey(name)
|
||||
value = strings.ToLower(value)
|
||||
|
||||
if header := req.Header[name]; header != nil {
|
||||
for _, hv := range header {
|
||||
if value == strings.ToLower(hv) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// NewContext returns a Context for the given HTTP Request
|
||||
func NewContext(req *http.Request) (*Context, error) {
|
||||
vars := mux.Vars(req)
|
||||
sess, err := sessionStore.Get(req, "inbucket")
|
||||
if err != nil {
|
||||
if sess == nil {
|
||||
// No session, must fail
|
||||
return nil, err
|
||||
}
|
||||
// The session cookie was probably signed by an old key, ignore it
|
||||
// gorilla created an empty session for us
|
||||
err = nil
|
||||
}
|
||||
ctx := &Context{
|
||||
Vars: vars,
|
||||
Session: sess,
|
||||
DataStore: DataStore,
|
||||
MsgHub: msgHub,
|
||||
WebConfig: webConfig,
|
||||
IsJSON: headerMatch(req, "Accept", "application/json"),
|
||||
}
|
||||
return ctx, err
|
||||
}
|
||||
64
pkg/server/web/helpers.go
Normal file
64
pkg/server/web/helpers.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"html/template"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jhillyerd/inbucket/pkg/log"
|
||||
)
|
||||
|
||||
// TemplateFuncs declares functions made available to all templates (including partials)
|
||||
var TemplateFuncs = template.FuncMap{
|
||||
"friendlyTime": FriendlyTime,
|
||||
"reverse": Reverse,
|
||||
"textToHtml": TextToHTML,
|
||||
}
|
||||
|
||||
// From http://daringfireball.net/2010/07/improved_regex_for_matching_urls
|
||||
var urlRE = regexp.MustCompile("(?i)\\b((?:[a-z][\\w-]+:(?:/{1,3}|[a-z0-9%])|www\\d{0,3}[.]|[a-z0-9.\\-]+[.][a-z]{2,4}/)(?:[^\\s()<>]+|\\(([^\\s()<>]+|(\\([^\\s()<>]+\\)))*\\))+(?:\\(([^\\s()<>]+|(\\([^\\s()<>]+\\)))*\\)|[^\\s`!()\\[\\]{};:'\".,<>?«»“”‘’]))")
|
||||
|
||||
// FriendlyTime renders a timestamp in a friendly fashion: 03:04:05 PM if same day,
|
||||
// otherwise Mon Jan 2, 2006
|
||||
func FriendlyTime(t time.Time) template.HTML {
|
||||
ty, tm, td := t.Date()
|
||||
ny, nm, nd := time.Now().Date()
|
||||
if (ty == ny) && (tm == nm) && (td == nd) {
|
||||
return template.HTML(t.Format("03:04:05 PM"))
|
||||
}
|
||||
return template.HTML(t.Format("Mon Jan 2, 2006"))
|
||||
}
|
||||
|
||||
// Reverse routing function (shared with templates)
|
||||
func Reverse(name string, things ...interface{}) string {
|
||||
// Convert the things to strings
|
||||
strs := make([]string, len(things))
|
||||
for i, th := range things {
|
||||
strs[i] = fmt.Sprint(th)
|
||||
}
|
||||
// Grab the route
|
||||
u, err := Router.Get(name).URL(strs...)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to reverse route: %v", err)
|
||||
return "/ROUTE-ERROR"
|
||||
}
|
||||
return u.Path
|
||||
}
|
||||
|
||||
// TextToHTML takes plain text, escapes it and tries to pretty it up for
|
||||
// HTML display
|
||||
func TextToHTML(text string) template.HTML {
|
||||
text = html.EscapeString(text)
|
||||
text = urlRE.ReplaceAllStringFunc(text, WrapURL)
|
||||
replacer := strings.NewReplacer("\r\n", "<br/>\n", "\r", "<br/>\n", "\n", "<br/>\n")
|
||||
return template.HTML(replacer.Replace(text))
|
||||
}
|
||||
|
||||
// WrapURL wraps a <a href> tag around the provided URL
|
||||
func WrapURL(url string) string {
|
||||
unescaped := strings.Replace(url, "&", "&", -1)
|
||||
return fmt.Sprintf("<a href=\"%s\" target=\"_blank\">%s</a>", unescaped, url)
|
||||
}
|
||||
30
pkg/server/web/helpers_test.go
Normal file
30
pkg/server/web/helpers_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTextToHtml(t *testing.T) {
|
||||
// Identity
|
||||
assert.Equal(t, TextToHTML("html"), template.HTML("html"))
|
||||
|
||||
// Check it escapes
|
||||
assert.Equal(t, TextToHTML("<html>"), template.HTML("<html>"))
|
||||
|
||||
// Check for linebreaks
|
||||
assert.Equal(t, TextToHTML("line\nbreak"), template.HTML("line<br/>\nbreak"))
|
||||
assert.Equal(t, TextToHTML("line\r\nbreak"), template.HTML("line<br/>\nbreak"))
|
||||
assert.Equal(t, TextToHTML("line\rbreak"), template.HTML("line<br/>\nbreak"))
|
||||
}
|
||||
|
||||
func TestURLDetection(t *testing.T) {
|
||||
assert.Equal(t,
|
||||
TextToHTML("http://google.com/"),
|
||||
template.HTML("<a href=\"http://google.com/\" target=\"_blank\">http://google.com/</a>"))
|
||||
assert.Equal(t,
|
||||
TextToHTML("http://a.com/?q=a&n=v"),
|
||||
template.HTML("<a href=\"http://a.com/?q=a&n=v\" target=\"_blank\">http://a.com/?q=a&n=v</a>"))
|
||||
}
|
||||
15
pkg/server/web/rest.go
Normal file
15
pkg/server/web/rest.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// RenderJSON sets the correct HTTP headers for JSON, then writes the specified
|
||||
// data (typically a struct) encoded in JSON
|
||||
func RenderJSON(w http.ResponseWriter, data interface{}) error {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("Expires", "-1")
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(data)
|
||||
}
|
||||
159
pkg/server/web/server.go
Normal file
159
pkg/server/web/server.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Package web provides the plumbing for Inbucket's web GUI and RESTful API
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jhillyerd/inbucket/pkg/config"
|
||||
"github.com/jhillyerd/inbucket/pkg/log"
|
||||
"github.com/jhillyerd/inbucket/pkg/msghub"
|
||||
"github.com/jhillyerd/inbucket/pkg/storage"
|
||||
)
|
||||
|
||||
// Handler is a function type that handles an HTTP request in Inbucket
|
||||
type Handler func(http.ResponseWriter, *http.Request, *Context) error
|
||||
|
||||
var (
|
||||
// DataStore is where all the mailboxes and messages live
|
||||
DataStore datastore.DataStore
|
||||
|
||||
// msgHub holds a reference to the message pub/sub system
|
||||
msgHub *msghub.Hub
|
||||
|
||||
// Router is shared between httpd, webui and rest packages. It sends
|
||||
// incoming requests to the correct handler function
|
||||
Router = mux.NewRouter()
|
||||
|
||||
webConfig config.WebConfig
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
sessionStore sessions.Store
|
||||
globalShutdown chan bool
|
||||
|
||||
// ExpWebSocketConnectsCurrent tracks the number of open WebSockets
|
||||
ExpWebSocketConnectsCurrent = new(expvar.Int)
|
||||
)
|
||||
|
||||
func init() {
|
||||
m := expvar.NewMap("http")
|
||||
m.Set("WebSocketConnectsCurrent", ExpWebSocketConnectsCurrent)
|
||||
}
|
||||
|
||||
// Initialize sets up things for unit tests or the Start() method
|
||||
func Initialize(
|
||||
cfg config.WebConfig,
|
||||
shutdownChan chan bool,
|
||||
ds datastore.DataStore,
|
||||
mh *msghub.Hub) {
|
||||
|
||||
webConfig = cfg
|
||||
globalShutdown = shutdownChan
|
||||
|
||||
// NewContext() will use this DataStore for the web handlers
|
||||
DataStore = ds
|
||||
msgHub = mh
|
||||
|
||||
// Content Paths
|
||||
log.Infof("HTTP templates mapped to %q", cfg.TemplateDir)
|
||||
log.Infof("HTTP static content mapped to %q", cfg.PublicDir)
|
||||
Router.PathPrefix("/public/").Handler(http.StripPrefix("/public/",
|
||||
http.FileServer(http.Dir(cfg.PublicDir))))
|
||||
http.Handle("/", Router)
|
||||
|
||||
// Session cookie setup
|
||||
if cfg.CookieAuthKey == "" {
|
||||
log.Infof("HTTP generating random cookie.auth.key")
|
||||
sessionStore = sessions.NewCookieStore(securecookie.GenerateRandomKey(64))
|
||||
} else {
|
||||
log.Tracef("HTTP using configured cookie.auth.key")
|
||||
sessionStore = sessions.NewCookieStore([]byte(cfg.CookieAuthKey))
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening for HTTP requests
|
||||
func Start(ctx context.Context) {
|
||||
addr := fmt.Sprintf("%v:%v", webConfig.IP4address, webConfig.IP4port)
|
||||
server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: nil,
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
// We don't use ListenAndServe because it lacks a way to close the listener
|
||||
log.Infof("HTTP listening on TCP4 %v", addr)
|
||||
var err error
|
||||
listener, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("HTTP failed to start TCP4 listener: %v", err)
|
||||
emergencyShutdown()
|
||||
return
|
||||
}
|
||||
|
||||
// Listener go routine
|
||||
go serve(ctx)
|
||||
|
||||
// Wait for shutdown
|
||||
select {
|
||||
case _ = <-ctx.Done():
|
||||
log.Tracef("HTTP server shutting down on request")
|
||||
}
|
||||
|
||||
// Closing the listener will cause the serve() go routine to exit
|
||||
if err := listener.Close(); err != nil {
|
||||
log.Errorf("Failed to close HTTP listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// serve begins serving HTTP requests
|
||||
func serve(ctx context.Context) {
|
||||
// server.Serve blocks until we close the listener
|
||||
err := server.Serve(listener)
|
||||
|
||||
select {
|
||||
case _ = <-ctx.Done():
|
||||
// Nop
|
||||
default:
|
||||
log.Errorf("HTTP server failed: %v", err)
|
||||
emergencyShutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP builds the context and passes onto the real handler
|
||||
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// Create the context
|
||||
ctx, err := NewContext(req)
|
||||
if err != nil {
|
||||
log.Errorf("HTTP failed to create context: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer ctx.Close()
|
||||
|
||||
// Run the handler, grab the error, and report it
|
||||
log.Tracef("HTTP[%v] %v %v %q", req.RemoteAddr, req.Proto, req.Method, req.RequestURI)
|
||||
err = h(w, req, ctx)
|
||||
if err != nil {
|
||||
log.Errorf("HTTP error handling %q: %v", req.RequestURI, err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func emergencyShutdown() {
|
||||
// Shutdown Inbucket
|
||||
select {
|
||||
case _ = <-globalShutdown:
|
||||
default:
|
||||
close(globalShutdown)
|
||||
}
|
||||
}
|
||||
83
pkg/server/web/template.go
Normal file
83
pkg/server/web/template.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jhillyerd/inbucket/pkg/log"
|
||||
)
|
||||
|
||||
var cachedMutex sync.Mutex
|
||||
var cachedTemplates = map[string]*template.Template{}
|
||||
var cachedPartials = map[string]*template.Template{}
|
||||
|
||||
// RenderTemplate fetches the named template and renders it to the provided
|
||||
// ResponseWriter.
|
||||
func RenderTemplate(name string, w http.ResponseWriter, data interface{}) error {
|
||||
t, err := ParseTemplate(name, false)
|
||||
if err != nil {
|
||||
log.Errorf("Error in template '%v': %v", name, err)
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Expires", "-1")
|
||||
return t.Execute(w, data)
|
||||
}
|
||||
|
||||
// RenderPartial fetches the named template and renders it to the provided
|
||||
// ResponseWriter.
|
||||
func RenderPartial(name string, w http.ResponseWriter, data interface{}) error {
|
||||
t, err := ParseTemplate(name, true)
|
||||
if err != nil {
|
||||
log.Errorf("Error in template '%v': %v", name, err)
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Expires", "-1")
|
||||
return t.Execute(w, data)
|
||||
}
|
||||
|
||||
// ParseTemplate loads the requested template along with _base.html, caching
|
||||
// the result (if configured to do so)
|
||||
func ParseTemplate(name string, partial bool) (*template.Template, error) {
|
||||
cachedMutex.Lock()
|
||||
defer cachedMutex.Unlock()
|
||||
|
||||
if t, ok := cachedTemplates[name]; ok {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
tempPath := strings.Replace(name, "/", string(filepath.Separator), -1)
|
||||
tempFile := filepath.Join(webConfig.TemplateDir, tempPath)
|
||||
log.Tracef("Parsing template %v", tempFile)
|
||||
|
||||
var err error
|
||||
var t *template.Template
|
||||
if partial {
|
||||
// Need to get basename of file to make it root template w/ funcs
|
||||
base := path.Base(name)
|
||||
t = template.New(base).Funcs(TemplateFuncs)
|
||||
t, err = t.ParseFiles(tempFile)
|
||||
} else {
|
||||
t = template.New("_base.html").Funcs(TemplateFuncs)
|
||||
t, err = t.ParseFiles(filepath.Join(webConfig.TemplateDir, "_base.html"), tempFile)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Allows us to disable caching for theme development
|
||||
if webConfig.TemplateCache {
|
||||
if partial {
|
||||
log.Tracef("Caching partial %v", name)
|
||||
cachedTemplates[name] = t
|
||||
} else {
|
||||
log.Tracef("Caching template %v", name)
|
||||
cachedTemplates[name] = t
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
Reference in New Issue
Block a user