1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-17 17:47:03 +00:00

Merge branch 'feature/filestore' into develop #67

This commit is contained in:
James Hillyerd
2017-12-26 23:17:01 -08:00
25 changed files with 369 additions and 429 deletions

View File

@@ -1,4 +1,5 @@
package smtpd // Package datastore contains implementation independent datastore logic
package datastore
import ( import (
"errors" "errors"

View File

@@ -1,4 +1,4 @@
package smtpd package datastore
import ( import (
"container/list" "container/list"
@@ -36,6 +36,11 @@ func init() {
rm.Set("Period", expRetentionPeriod) rm.Set("Period", expRetentionPeriod)
rm.Set("RetainedHist", expRetainedHist) rm.Set("RetainedHist", expRetainedHist)
rm.Set("RetainedCurrent", expRetainedCurrent) rm.Set("RetainedCurrent", expRetainedCurrent)
log.AddTickerFunc(func() {
expRetentionDeletesHist.Set(log.PushMetric(retentionDeletesHist, expRetentionDeletesTotal))
expRetainedHist.Set(log.PushMetric(retainedHist, expRetainedCurrent))
})
} }
// RetentionScanner looks for messages older than the configured retention period and deletes them. // RetentionScanner looks for messages older than the configured retention period and deletes them.
@@ -85,9 +90,9 @@ retentionLoop:
dur := time.Minute - since dur := time.Minute - since
log.Tracef("Retention scanner sleeping for %v", dur) log.Tracef("Retention scanner sleeping for %v", dur)
select { select {
case _ = <-rs.globalShutdown: case <-rs.globalShutdown:
break retentionLoop break retentionLoop
case _ = <-time.After(dur): case <-time.After(dur):
} }
} }
// Kickoff scan // Kickoff scan
@@ -97,7 +102,7 @@ retentionLoop:
} }
// Check for global shutdown // Check for global shutdown
select { select {
case _ = <-rs.globalShutdown: case <-rs.globalShutdown:
break retentionLoop break retentionLoop
default: default:
} }
@@ -154,9 +159,7 @@ func (rs *RetentionScanner) doScan() error {
// Join does not retun until the retention scanner has shut down // Join does not retun until the retention scanner has shut down
func (rs *RetentionScanner) Join() { func (rs *RetentionScanner) Join() {
if rs.retentionShutdown != nil { if rs.retentionShutdown != nil {
select { <-rs.retentionShutdown
case <-rs.retentionShutdown:
}
} }
} }

View File

@@ -0,0 +1,67 @@
package datastore
import (
"fmt"
"testing"
"time"
)
func TestDoRetentionScan(t *testing.T) {
// Create mock objects
mds := &MockDataStore{}
mb1 := &MockMailbox{}
mb2 := &MockMailbox{}
mb3 := &MockMailbox{}
// Mockup some different aged messages (num is in hours)
new1 := mockMessage(0)
new2 := mockMessage(1)
new3 := mockMessage(2)
old1 := mockMessage(4)
old2 := mockMessage(12)
old3 := mockMessage(24)
// First it should ask for all mailboxes
mds.On("AllMailboxes").Return([]Mailbox{mb1, mb2, mb3}, nil)
// Then for all messages on each box
mb1.On("GetMessages").Return([]Message{new1, old1, old2}, nil)
mb2.On("GetMessages").Return([]Message{old3, new2}, nil)
mb3.On("GetMessages").Return([]Message{new3}, nil)
// Test 4 hour retention
rs := &RetentionScanner{
ds: mds,
retentionPeriod: 4*time.Hour - time.Minute,
retentionSleep: 0,
}
if err := rs.doScan(); err != nil {
t.Error(err)
}
// Check our assertions
mds.AssertExpectations(t)
mb1.AssertExpectations(t)
mb2.AssertExpectations(t)
mb3.AssertExpectations(t)
// Delete should not have been called on new messages
new1.AssertNotCalled(t, "Delete")
new2.AssertNotCalled(t, "Delete")
new3.AssertNotCalled(t, "Delete")
// Delete should have been called once on old messages
old1.AssertNumberOfCalls(t, "Delete", 1)
old2.AssertNumberOfCalls(t, "Delete", 1)
old3.AssertNumberOfCalls(t, "Delete", 1)
}
// Make a MockMessage of a specific age
func mockMessage(ageHours int) *MockMessage {
msg := &MockMessage{}
msg.On("ID").Return(fmt.Sprintf("MSG[age=%vh]", ageHours))
msg.On("Date").Return(time.Now().Add(time.Duration(ageHours*-1) * time.Hour))
msg.On("Delete").Return(nil)
return msg
}

View File

@@ -1,4 +1,4 @@
package rest package datastore
import ( import (
"io" "io"
@@ -6,130 +6,151 @@ import (
"time" "time"
"github.com/jhillyerd/enmime" "github.com/jhillyerd/enmime"
"github.com/jhillyerd/inbucket/smtpd"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
) )
// Mock DataStore object // MockDataStore is a shared mock for unit testing
type MockDataStore struct { type MockDataStore struct {
mock.Mock mock.Mock
} }
func (m *MockDataStore) MailboxFor(name string) (smtpd.Mailbox, error) { // MailboxFor mock function
func (m *MockDataStore) MailboxFor(name string) (Mailbox, error) {
args := m.Called(name) args := m.Called(name)
return args.Get(0).(smtpd.Mailbox), args.Error(1) return args.Get(0).(Mailbox), args.Error(1)
} }
func (m *MockDataStore) AllMailboxes() ([]smtpd.Mailbox, error) { // AllMailboxes mock function
func (m *MockDataStore) AllMailboxes() ([]Mailbox, error) {
args := m.Called() args := m.Called()
return args.Get(0).([]smtpd.Mailbox), args.Error(1) return args.Get(0).([]Mailbox), args.Error(1)
} }
// Mock Mailbox object // MockMailbox is a shared mock for unit testing
type MockMailbox struct { type MockMailbox struct {
mock.Mock mock.Mock
} }
func (m *MockMailbox) GetMessages() ([]smtpd.Message, error) { // GetMessages mock function
func (m *MockMailbox) GetMessages() ([]Message, error) {
args := m.Called() args := m.Called()
return args.Get(0).([]smtpd.Message), args.Error(1) return args.Get(0).([]Message), args.Error(1)
} }
func (m *MockMailbox) GetMessage(id string) (smtpd.Message, error) { // GetMessage mock function
func (m *MockMailbox) GetMessage(id string) (Message, error) {
args := m.Called(id) args := m.Called(id)
return args.Get(0).(smtpd.Message), args.Error(1) return args.Get(0).(Message), args.Error(1)
} }
// Purge mock function
func (m *MockMailbox) Purge() error { func (m *MockMailbox) Purge() error {
args := m.Called() args := m.Called()
return args.Error(0) return args.Error(0)
} }
func (m *MockMailbox) NewMessage() (smtpd.Message, error) { // NewMessage mock function
func (m *MockMailbox) NewMessage() (Message, error) {
args := m.Called() args := m.Called()
return args.Get(0).(smtpd.Message), args.Error(1) return args.Get(0).(Message), args.Error(1)
} }
// Name mock function
func (m *MockMailbox) Name() string { func (m *MockMailbox) Name() string {
args := m.Called() args := m.Called()
return args.String(0) return args.String(0)
} }
// String mock function
func (m *MockMailbox) String() string { func (m *MockMailbox) String() string {
args := m.Called() args := m.Called()
return args.String(0) return args.String(0)
} }
// Mock Message object // MockMessage is a shared mock for unit testing
type MockMessage struct { type MockMessage struct {
mock.Mock mock.Mock
} }
// ID mock function
func (m *MockMessage) ID() string { func (m *MockMessage) ID() string {
args := m.Called() args := m.Called()
return args.String(0) return args.String(0)
} }
// From mock function
func (m *MockMessage) From() string { func (m *MockMessage) From() string {
args := m.Called() args := m.Called()
return args.String(0) return args.String(0)
} }
// To mock function
func (m *MockMessage) To() []string { func (m *MockMessage) To() []string {
args := m.Called() args := m.Called()
return args.Get(0).([]string) return args.Get(0).([]string)
} }
// Date mock function
func (m *MockMessage) Date() time.Time { func (m *MockMessage) Date() time.Time {
args := m.Called() args := m.Called()
return args.Get(0).(time.Time) return args.Get(0).(time.Time)
} }
// Subject mock function
func (m *MockMessage) Subject() string { func (m *MockMessage) Subject() string {
args := m.Called() args := m.Called()
return args.String(0) return args.String(0)
} }
// ReadHeader mock function
func (m *MockMessage) ReadHeader() (msg *mail.Message, err error) { func (m *MockMessage) ReadHeader() (msg *mail.Message, err error) {
args := m.Called() args := m.Called()
return args.Get(0).(*mail.Message), args.Error(1) return args.Get(0).(*mail.Message), args.Error(1)
} }
// ReadBody mock function
func (m *MockMessage) ReadBody() (body *enmime.Envelope, err error) { func (m *MockMessage) ReadBody() (body *enmime.Envelope, err error) {
args := m.Called() args := m.Called()
return args.Get(0).(*enmime.Envelope), args.Error(1) return args.Get(0).(*enmime.Envelope), args.Error(1)
} }
// ReadRaw mock function
func (m *MockMessage) ReadRaw() (raw *string, err error) { func (m *MockMessage) ReadRaw() (raw *string, err error) {
args := m.Called() args := m.Called()
return args.Get(0).(*string), args.Error(1) return args.Get(0).(*string), args.Error(1)
} }
// RawReader mock function
func (m *MockMessage) RawReader() (reader io.ReadCloser, err error) { func (m *MockMessage) RawReader() (reader io.ReadCloser, err error) {
args := m.Called() args := m.Called()
return args.Get(0).(io.ReadCloser), args.Error(1) return args.Get(0).(io.ReadCloser), args.Error(1)
} }
// Size mock function
func (m *MockMessage) Size() int64 { func (m *MockMessage) Size() int64 {
args := m.Called() args := m.Called()
return int64(args.Int(0)) return int64(args.Int(0))
} }
// Append mock function
func (m *MockMessage) Append(data []byte) error { func (m *MockMessage) Append(data []byte) error {
// []byte arg seems to mess up testify/mock // []byte arg seems to mess up testify/mock
return nil return nil
} }
// Close mock function
func (m *MockMessage) Close() error { func (m *MockMessage) Close() error {
args := m.Called() args := m.Called()
return args.Error(0) return args.Error(0)
} }
// Delete mock function
func (m *MockMessage) Delete() error { func (m *MockMessage) Delete() error {
args := m.Called() args := m.Called()
return args.Error(0) return args.Error(0)
} }
// String mock function
func (m *MockMessage) String() string { func (m *MockMessage) String() string {
args := m.Called() args := m.Called()
return args.String(0) return args.String(0)

View File

@@ -1,4 +1,4 @@
package smtpd package filestore
import ( import (
"bufio" "bufio"
@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/jhillyerd/enmime" "github.com/jhillyerd/enmime"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
) )
@@ -33,7 +34,7 @@ type FileMessage struct {
// NewMessage creates a new FileMessage object and sets the Date and Id fields. // NewMessage creates a new FileMessage object and sets the Date and Id fields.
// It will also delete messages over messageCap if configured. // It will also delete messages over messageCap if configured.
func (mb *FileMailbox) NewMessage() (Message, error) { func (mb *FileMailbox) NewMessage() (datastore.Message, error) {
// Load index // Load index
if !mb.indexLoaded { if !mb.indexLoaded {
if err := mb.readIndex(); err != nil { if err := mb.readIndex(); err != nil {
@@ -71,7 +72,7 @@ func (m *FileMessage) From() string {
return m.Ffrom return m.Ffrom
} }
// From returns the value of the Message To header // To returns the value of the Message To header
func (m *FileMessage) To() []string { func (m *FileMessage) To() []string {
return m.Fto return m.Fto
} }
@@ -165,7 +166,7 @@ func (m *FileMessage) ReadRaw() (raw *string, err error) {
func (m *FileMessage) Append(data []byte) error { func (m *FileMessage) Append(data []byte) error {
// Prevent Appending to a pre-existing Message // Prevent Appending to a pre-existing Message
if !m.writable { if !m.writable {
return ErrNotWritable return datastore.ErrNotWritable
} }
// Open file for writing if we haven't yet // Open file for writing if we haven't yet
if m.writer == nil { if m.writer == nil {

View File

@@ -1,4 +1,4 @@
package smtpd package filestore
import ( import (
"bufio" "bufio"
@@ -12,7 +12,9 @@ import (
"time" "time"
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/stringutil"
) )
// Name of index file in each mailbox // Name of index file in each mailbox
@@ -55,7 +57,7 @@ type FileDataStore struct {
} }
// NewFileDataStore creates a new DataStore object using the specified path // NewFileDataStore creates a new DataStore object using the specified path
func NewFileDataStore(cfg config.DataStoreConfig) DataStore { func NewFileDataStore(cfg config.DataStoreConfig) datastore.DataStore {
path := cfg.Path path := cfg.Path
if path == "" { if path == "" {
log.Errorf("No value configured for datastore path") log.Errorf("No value configured for datastore path")
@@ -73,19 +75,19 @@ func NewFileDataStore(cfg config.DataStoreConfig) DataStore {
// DefaultFileDataStore creates a new DataStore object. It uses the inbucket.Config object to // DefaultFileDataStore creates a new DataStore object. It uses the inbucket.Config object to
// construct it's path. // construct it's path.
func DefaultFileDataStore() DataStore { func DefaultFileDataStore() datastore.DataStore {
cfg := config.GetDataStoreConfig() cfg := config.GetDataStoreConfig()
return NewFileDataStore(cfg) return NewFileDataStore(cfg)
} }
// MailboxFor retrieves the Mailbox object for a specified email address, if the mailbox // MailboxFor retrieves the Mailbox object for a specified email address, if the mailbox
// does not exist, it will attempt to create it. // does not exist, it will attempt to create it.
func (ds *FileDataStore) MailboxFor(emailAddress string) (Mailbox, error) { func (ds *FileDataStore) MailboxFor(emailAddress string) (datastore.Mailbox, error) {
name, err := ParseMailboxName(emailAddress) name, err := stringutil.ParseMailboxName(emailAddress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dir := HashMailboxName(name) dir := stringutil.HashMailboxName(name)
s1 := dir[0:3] s1 := dir[0:3]
s2 := dir[0:6] s2 := dir[0:6]
path := filepath.Join(ds.mailPath, s1, s2, dir) path := filepath.Join(ds.mailPath, s1, s2, dir)
@@ -96,8 +98,8 @@ func (ds *FileDataStore) MailboxFor(emailAddress string) (Mailbox, error) {
} }
// AllMailboxes returns a slice with all Mailboxes // AllMailboxes returns a slice with all Mailboxes
func (ds *FileDataStore) AllMailboxes() ([]Mailbox, error) { func (ds *FileDataStore) AllMailboxes() ([]datastore.Mailbox, error) {
mailboxes := make([]Mailbox, 0, 100) mailboxes := make([]datastore.Mailbox, 0, 100)
infos1, err := ioutil.ReadDir(ds.mailPath) infos1, err := ioutil.ReadDir(ds.mailPath)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -149,24 +151,26 @@ type FileMailbox struct {
messages []*FileMessage messages []*FileMessage
} }
// Name of the mailbox
func (mb *FileMailbox) Name() string { func (mb *FileMailbox) Name() string {
return mb.name return mb.name
} }
// String renders the name and directory path of the mailbox
func (mb *FileMailbox) String() string { func (mb *FileMailbox) String() string {
return mb.name + "[" + mb.dirName + "]" return mb.name + "[" + mb.dirName + "]"
} }
// GetMessages scans the mailbox directory for .gob files and decodes them into // GetMessages scans the mailbox directory for .gob files and decodes them into
// a slice of Message objects. // a slice of Message objects.
func (mb *FileMailbox) GetMessages() ([]Message, error) { func (mb *FileMailbox) GetMessages() ([]datastore.Message, error) {
if !mb.indexLoaded { if !mb.indexLoaded {
if err := mb.readIndex(); err != nil { if err := mb.readIndex(); err != nil {
return nil, err return nil, err
} }
} }
messages := make([]Message, len(mb.messages)) messages := make([]datastore.Message, len(mb.messages))
for i, m := range mb.messages { for i, m := range mb.messages {
messages[i] = m messages[i] = m
} }
@@ -174,7 +178,7 @@ func (mb *FileMailbox) GetMessages() ([]Message, error) {
} }
// GetMessage decodes a single message by Id and returns a Message object // GetMessage decodes a single message by Id and returns a Message object
func (mb *FileMailbox) GetMessage(id string) (Message, error) { func (mb *FileMailbox) GetMessage(id string) (datastore.Message, error) {
if !mb.indexLoaded { if !mb.indexLoaded {
if err := mb.readIndex(); err != nil { if err := mb.readIndex(); err != nil {
return nil, err return nil, err
@@ -183,15 +187,15 @@ func (mb *FileMailbox) GetMessage(id string) (Message, error) {
if id == "latest" && len(mb.messages) != 0 { if id == "latest" && len(mb.messages) != 0 {
return mb.messages[len(mb.messages)-1], nil return mb.messages[len(mb.messages)-1], nil
} else { }
for _, m := range mb.messages {
if m.Fid == id { for _, m := range mb.messages {
return m, nil if m.Fid == id {
} return m, nil
} }
} }
return nil, ErrNotExist return nil, datastore.ErrNotExist
} }
// Purge deletes all messages in this mailbox // Purge deletes all messages in this mailbox

View File

@@ -1,4 +1,4 @@
package smtpd package filestore
import ( import (
"bytes" "bytes"
@@ -470,8 +470,8 @@ func TestGetLatestMessage(t *testing.T) {
mb, err := ds.MailboxFor(mbName) mb, err := ds.MailboxFor(mbName)
assert.Nil(t, err) assert.Nil(t, err)
msg, err := mb.GetMessage("latest") msg, err := mb.GetMessage("latest")
assert.Nil(t, msg)
assert.Error(t, err) assert.Error(t, err)
fmt.Println(msg)
// Deliver test message // Deliver test message
deliverMessage(ds, mbName, "test", time.Now()) deliverMessage(ds, mbName, "test", time.Now())
@@ -496,7 +496,7 @@ func TestGetLatestMessage(t *testing.T) {
assert.True(t, msg.ID() == id3, "Expected %q to be equal to %q", msg.ID(), id3) assert.True(t, msg.ID() == id3, "Expected %q to be equal to %q", msg.ID(), id3)
// Test wrong id // Test wrong id
msg, err = mb.GetMessage("wrongid") _, err = mb.GetMessage("wrongid")
assert.Error(t, err) assert.Error(t, err)
if t.Failed() { if t.Failed() {

View File

@@ -7,15 +7,15 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/msghub" "github.com/jhillyerd/inbucket/msghub"
"github.com/jhillyerd/inbucket/smtpd"
) )
// Context is passed into every request handler function // Context is passed into every request handler function
type Context struct { type Context struct {
Vars map[string]string Vars map[string]string
Session *sessions.Session Session *sessions.Session
DataStore smtpd.DataStore DataStore datastore.DataStore
MsgHub *msghub.Hub MsgHub *msghub.Hub
WebConfig config.WebConfig WebConfig config.WebConfig
IsJSON bool IsJSON bool

View File

@@ -13,9 +13,9 @@ import (
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/msghub" "github.com/jhillyerd/inbucket/msghub"
"github.com/jhillyerd/inbucket/smtpd"
) )
// Handler is a function type that handles an HTTP request in Inbucket // Handler is a function type that handles an HTTP request in Inbucket
@@ -23,7 +23,7 @@ type Handler func(http.ResponseWriter, *http.Request, *Context) error
var ( var (
// DataStore is where all the mailboxes and messages live // DataStore is where all the mailboxes and messages live
DataStore smtpd.DataStore DataStore datastore.DataStore
// msgHub holds a reference to the message pub/sub system // msgHub holds a reference to the message pub/sub system
msgHub *msghub.Hub msgHub *msghub.Hub
@@ -51,7 +51,7 @@ func init() {
func Initialize( func Initialize(
cfg config.WebConfig, cfg config.WebConfig,
shutdownChan chan bool, shutdownChan chan bool,
ds smtpd.DataStore, ds datastore.DataStore,
mh *msghub.Hub) { mh *msghub.Hub) {
webConfig = cfg webConfig = cfg

View File

@@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/filestore"
"github.com/jhillyerd/inbucket/httpd" "github.com/jhillyerd/inbucket/httpd"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/msghub" "github.com/jhillyerd/inbucket/msghub"
@@ -115,7 +116,7 @@ func main() {
msgHub := msghub.New(rootCtx, config.GetWebConfig().MonitorHistory) msgHub := msghub.New(rootCtx, config.GetWebConfig().MonitorHistory)
// Grab our datastore // Grab our datastore
ds := smtpd.DefaultFileDataStore() ds := filestore.DefaultFileDataStore()
// Start HTTP server // Start HTTP server
httpd.Initialize(config.GetWebConfig(), shutdownChan, ds, msgHub) httpd.Initialize(config.GetWebConfig(), shutdownChan, ds, msgHub)

62
log/metrics.go Normal file
View File

@@ -0,0 +1,62 @@
package log
import (
"container/list"
"expvar"
"strings"
"time"
)
// TickerFunc is the type of metrics function accepted by AddTickerFunc
type TickerFunc func()
var tickerFuncChan = make(chan TickerFunc)
func init() {
go metricsTicker()
}
// AddTickerFunc adds a new function callback to the list of metrics TickerFuncs that get
// called each minute.
func AddTickerFunc(f TickerFunc) {
tickerFuncChan <- f
}
// PushMetric adds the metric to the end of the list and returns a comma separated string of the
// previous 61 entries. We return 61 instead of 60 (an hour) because the chart on the client
// tracks deltas between these values - there is nothing to compare the first value against.
func PushMetric(history *list.List, ev expvar.Var) string {
history.PushBack(ev.String())
if history.Len() > 61 {
history.Remove(history.Front())
}
return joinStringList(history)
}
// joinStringList joins a List containing strings by commas
func joinStringList(listOfStrings *list.List) string {
if listOfStrings.Len() == 0 {
return ""
}
s := make([]string, 0, listOfStrings.Len())
for e := listOfStrings.Front(); e != nil; e = e.Next() {
s = append(s, e.Value.(string))
}
return strings.Join(s, ",")
}
func metricsTicker() {
funcs := make([]TickerFunc, 0)
ticker := time.NewTicker(time.Minute)
for {
select {
case <-ticker.C:
for _, f := range funcs {
f()
}
case f := <-tickerFuncChan:
funcs = append(funcs, f)
}
}
}

View File

@@ -11,8 +11,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/smtpd"
) )
// State tracks the current mode of our POP3 state machine // State tracks the current mode of our POP3 state machine
@@ -57,18 +57,18 @@ var commands = map[string]bool{
// Session defines an active POP3 session // Session defines an active POP3 session
type Session struct { type Session struct {
server *Server // Reference to the server we belong to server *Server // Reference to the server we belong to
id int // Session ID number id int // Session ID number
conn net.Conn // Our network connection conn net.Conn // Our network connection
remoteHost string // IP address of client remoteHost string // IP address of client
sendError error // Used to bail out of read loop on send error sendError error // Used to bail out of read loop on send error
state State // Current session state state State // Current session state
reader *bufio.Reader // Buffered reader for our net conn reader *bufio.Reader // Buffered reader for our net conn
user string // Mailbox name user string // Mailbox name
mailbox smtpd.Mailbox // Mailbox instance mailbox datastore.Mailbox // Mailbox instance
messages []smtpd.Message // Slice of messages in mailbox messages []datastore.Message // Slice of messages in mailbox
retain []bool // Messages to retain upon UPDATE (true=retain) retain []bool // Messages to retain upon UPDATE (true=retain)
msgCount int // Number of undeleted messages msgCount int // Number of undeleted messages
} }
// NewSession creates a new POP3 session // NewSession creates a new POP3 session
@@ -432,7 +432,7 @@ func (ses *Session) transactionHandler(cmd string, args []string) {
} }
// Send the contents of the message to the client // Send the contents of the message to the client
func (ses *Session) sendMessage(msg smtpd.Message) { func (ses *Session) sendMessage(msg datastore.Message) {
reader, err := msg.RawReader() reader, err := msg.RawReader()
if err != nil { if err != nil {
ses.logError("Failed to read message for RETR command") ses.logError("Failed to read message for RETR command")
@@ -465,7 +465,7 @@ func (ses *Session) sendMessage(msg smtpd.Message) {
} }
// Send the headers plus the top N lines to the client // Send the headers plus the top N lines to the client
func (ses *Session) sendMessageTop(msg smtpd.Message, lineCount int) { func (ses *Session) sendMessageTop(msg datastore.Message, lineCount int) {
reader, err := msg.RawReader() reader, err := msg.RawReader()
if err != nil { if err != nil {
ses.logError("Failed to read message for RETR command") ses.logError("Failed to read message for RETR command")

View File

@@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/smtpd"
) )
// Server defines an instance of our POP3 server // Server defines an instance of our POP3 server
@@ -17,14 +17,14 @@ type Server struct {
host string host string
domain string domain string
maxIdleSeconds int maxIdleSeconds int
dataStore smtpd.DataStore dataStore datastore.DataStore
listener net.Listener listener net.Listener
globalShutdown chan bool globalShutdown chan bool
waitgroup *sync.WaitGroup waitgroup *sync.WaitGroup
} }
// New creates a new Server struct // New creates a new Server struct
func New(cfg config.POP3Config, shutdownChan chan bool, ds smtpd.DataStore) *Server { func New(cfg config.POP3Config, shutdownChan chan bool, ds datastore.DataStore) *Server {
return &Server{ return &Server{
host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port), host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port),
domain: cfg.Domain, domain: cfg.Domain,

View File

@@ -10,16 +10,17 @@ import (
"io/ioutil" "io/ioutil"
"strconv" "strconv"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/httpd" "github.com/jhillyerd/inbucket/httpd"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/rest/model" "github.com/jhillyerd/inbucket/rest/model"
"github.com/jhillyerd/inbucket/smtpd" "github.com/jhillyerd/inbucket/stringutil"
) )
// MailboxListV1 renders a list of messages in a mailbox // MailboxListV1 renders a list of messages in a mailbox
func MailboxListV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxListV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -54,7 +55,7 @@ func MailboxListV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context)
func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -64,7 +65,7 @@ func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context)
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err) return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
} }
msg, err := mb.GetMessage(id) msg, err := mb.GetMessage(id)
if err == smtpd.ErrNotExist { if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
@@ -116,7 +117,7 @@ func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context)
// MailboxPurgeV1 deletes all messages from a mailbox // MailboxPurgeV1 deletes all messages from a mailbox
func MailboxPurgeV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxPurgeV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -139,7 +140,7 @@ func MailboxPurgeV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context
func MailboxSourceV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxSourceV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -149,7 +150,7 @@ func MailboxSourceV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Contex
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err) return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
} }
message, err := mb.GetMessage(id) message, err := mb.GetMessage(id)
if err == smtpd.ErrNotExist { if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
@@ -173,7 +174,7 @@ func MailboxSourceV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Contex
func MailboxDeleteV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxDeleteV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -183,7 +184,7 @@ func MailboxDeleteV1(w http.ResponseWriter, req *http.Request, ctx *httpd.Contex
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err) return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
} }
message, err := mb.GetMessage(id) message, err := mb.GetMessage(id)
if err == smtpd.ErrNotExist { if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }

View File

@@ -9,7 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/jhillyerd/inbucket/smtpd" "github.com/jhillyerd/inbucket/datastore"
) )
const ( const (
@@ -31,7 +31,7 @@ const (
func TestRestMailboxList(t *testing.T) { func TestRestMailboxList(t *testing.T) {
// Setup // Setup
ds := &MockDataStore{} ds := &datastore.MockDataStore{}
logbuf := setupWebServer(ds) logbuf := setupWebServer(ds)
// Test invalid mailbox name // Test invalid mailbox name
@@ -45,9 +45,9 @@ func TestRestMailboxList(t *testing.T) {
} }
// Test empty mailbox // Test empty mailbox
emptybox := &MockMailbox{} emptybox := &datastore.MockMailbox{}
ds.On("MailboxFor", "empty").Return(emptybox, nil) ds.On("MailboxFor", "empty").Return(emptybox, nil)
emptybox.On("GetMessages").Return([]smtpd.Message{}, nil) emptybox.On("GetMessages").Return([]datastore.Message{}, nil)
w, err = testRestGet(baseURL + "/mailbox/empty") w, err = testRestGet(baseURL + "/mailbox/empty")
expectCode = 200 expectCode = 200
@@ -59,7 +59,7 @@ func TestRestMailboxList(t *testing.T) {
} }
// Test MailboxFor error // Test MailboxFor error
ds.On("MailboxFor", "error").Return(&MockMailbox{}, fmt.Errorf("Internal error")) ds.On("MailboxFor", "error").Return(&datastore.MockMailbox{}, fmt.Errorf("Internal error"))
w, err = testRestGet(baseURL + "/mailbox/error") w, err = testRestGet(baseURL + "/mailbox/error")
expectCode = 500 expectCode = 500
if err != nil { if err != nil {
@@ -77,9 +77,9 @@ func TestRestMailboxList(t *testing.T) {
} }
// Test MailboxFor error // Test MailboxFor error
error2box := &MockMailbox{} error2box := &datastore.MockMailbox{}
ds.On("MailboxFor", "error2").Return(error2box, nil) ds.On("MailboxFor", "error2").Return(error2box, nil)
error2box.On("GetMessages").Return([]smtpd.Message{}, fmt.Errorf("Internal error 2")) error2box.On("GetMessages").Return([]datastore.Message{}, fmt.Errorf("Internal error 2"))
w, err = testRestGet(baseURL + "/mailbox/error2") w, err = testRestGet(baseURL + "/mailbox/error2")
expectCode = 500 expectCode = 500
@@ -107,11 +107,11 @@ func TestRestMailboxList(t *testing.T) {
Subject: "subject 2", Subject: "subject 2",
Date: time.Date(2012, 7, 1, 10, 11, 12, 253, time.FixedZone("PDT", -700)), Date: time.Date(2012, 7, 1, 10, 11, 12, 253, time.FixedZone("PDT", -700)),
} }
goodbox := &MockMailbox{} goodbox := &datastore.MockMailbox{}
ds.On("MailboxFor", "good").Return(goodbox, nil) ds.On("MailboxFor", "good").Return(goodbox, nil)
msg1 := data1.MockMessage() msg1 := data1.MockMessage()
msg2 := data2.MockMessage() msg2 := data2.MockMessage()
goodbox.On("GetMessages").Return([]smtpd.Message{msg1, msg2}, nil) goodbox.On("GetMessages").Return([]datastore.Message{msg1, msg2}, nil)
// Check return code // Check return code
w, err = testRestGet(baseURL + "/mailbox/good") w, err = testRestGet(baseURL + "/mailbox/good")
@@ -155,7 +155,7 @@ func TestRestMailboxList(t *testing.T) {
func TestRestMessage(t *testing.T) { func TestRestMessage(t *testing.T) {
// Setup // Setup
ds := &MockDataStore{} ds := &datastore.MockDataStore{}
logbuf := setupWebServer(ds) logbuf := setupWebServer(ds)
// Test invalid mailbox name // Test invalid mailbox name
@@ -169,9 +169,9 @@ func TestRestMessage(t *testing.T) {
} }
// Test requesting a message that does not exist // Test requesting a message that does not exist
emptybox := &MockMailbox{} emptybox := &datastore.MockMailbox{}
ds.On("MailboxFor", "empty").Return(emptybox, nil) ds.On("MailboxFor", "empty").Return(emptybox, nil)
emptybox.On("GetMessage", "0001").Return(&MockMessage{}, smtpd.ErrNotExist) emptybox.On("GetMessage", "0001").Return(&datastore.MockMessage{}, datastore.ErrNotExist)
w, err = testRestGet(baseURL + "/mailbox/empty/0001") w, err = testRestGet(baseURL + "/mailbox/empty/0001")
expectCode = 404 expectCode = 404
@@ -183,7 +183,7 @@ func TestRestMessage(t *testing.T) {
} }
// Test MailboxFor error // Test MailboxFor error
ds.On("MailboxFor", "error").Return(&MockMailbox{}, fmt.Errorf("Internal error")) ds.On("MailboxFor", "error").Return(&datastore.MockMailbox{}, fmt.Errorf("Internal error"))
w, err = testRestGet(baseURL + "/mailbox/error/0001") w, err = testRestGet(baseURL + "/mailbox/error/0001")
expectCode = 500 expectCode = 500
if err != nil { if err != nil {
@@ -201,9 +201,9 @@ func TestRestMessage(t *testing.T) {
} }
// Test GetMessage error // Test GetMessage error
error2box := &MockMailbox{} error2box := &datastore.MockMailbox{}
ds.On("MailboxFor", "error2").Return(error2box, nil) ds.On("MailboxFor", "error2").Return(error2box, nil)
error2box.On("GetMessage", "0001").Return(&MockMessage{}, fmt.Errorf("Internal error 2")) error2box.On("GetMessage", "0001").Return(&datastore.MockMessage{}, fmt.Errorf("Internal error 2"))
w, err = testRestGet(baseURL + "/mailbox/error2/0001") w, err = testRestGet(baseURL + "/mailbox/error2/0001")
expectCode = 500 expectCode = 500
@@ -228,7 +228,7 @@ func TestRestMessage(t *testing.T) {
Text: "This is some text", Text: "This is some text",
HTML: "This is some HTML", HTML: "This is some HTML",
} }
goodbox := &MockMailbox{} goodbox := &datastore.MockMailbox{}
ds.On("MailboxFor", "good").Return(goodbox, nil) ds.On("MailboxFor", "good").Return(goodbox, nil)
msg1 := data1.MockMessage() msg1 := data1.MockMessage()
goodbox.On("GetMessage", "0001").Return(msg1, nil) goodbox.On("GetMessage", "0001").Return(msg1, nil)

View File

@@ -9,7 +9,7 @@ import (
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/msghub" "github.com/jhillyerd/inbucket/msghub"
"github.com/jhillyerd/inbucket/rest/model" "github.com/jhillyerd/inbucket/rest/model"
"github.com/jhillyerd/inbucket/smtpd" "github.com/jhillyerd/inbucket/stringutil"
) )
const ( const (
@@ -169,7 +169,7 @@ func MonitorAllMessagesV1(
func MonitorMailboxMessagesV1( func MonitorMailboxMessagesV1(
w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }

View File

@@ -11,9 +11,9 @@ import (
"github.com/jhillyerd/enmime" "github.com/jhillyerd/enmime"
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/httpd" "github.com/jhillyerd/inbucket/httpd"
"github.com/jhillyerd/inbucket/msghub" "github.com/jhillyerd/inbucket/msghub"
"github.com/jhillyerd/inbucket/smtpd"
) )
type InputMessageData struct { type InputMessageData struct {
@@ -25,8 +25,8 @@ type InputMessageData struct {
HTML, Text string HTML, Text string
} }
func (d *InputMessageData) MockMessage() *MockMessage { func (d *InputMessageData) MockMessage() *datastore.MockMessage {
msg := &MockMessage{} msg := &datastore.MockMessage{}
msg.On("ID").Return(d.ID) msg.On("ID").Return(d.ID)
msg.On("From").Return(d.From) msg.On("From").Return(d.From)
msg.On("To").Return(d.To) msg.On("To").Return(d.To)
@@ -188,7 +188,7 @@ func testRestGet(url string) (*httptest.ResponseRecorder, error) {
return w, nil return w, nil
} }
func setupWebServer(ds smtpd.DataStore) *bytes.Buffer { func setupWebServer(ds datastore.DataStore) *bytes.Buffer {
// Capture log output // Capture log output
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
log.SetOutput(buf) log.SetOutput(buf)

View File

@@ -12,8 +12,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/msghub" "github.com/jhillyerd/inbucket/msghub"
"github.com/jhillyerd/inbucket/stringutil"
) )
// State tracks the current mode of our SMTP state machine // State tracks the current mode of our SMTP state machine
@@ -71,7 +73,7 @@ var commands = map[string]bool{
// recipientDetails for message delivery // recipientDetails for message delivery
type recipientDetails struct { type recipientDetails struct {
address, localPart, domainPart string address, localPart, domainPart string
mailbox Mailbox mailbox datastore.Mailbox
} }
// Session holds the state of an SMTP session // Session holds the state of an SMTP session
@@ -265,7 +267,7 @@ func (ss *Session) readyHandler(cmd string, arg string) {
return return
} }
from := m[1] from := m[1]
if _, _, err := ParseEmailAddress(from); err != nil { if _, _, err := stringutil.ParseEmailAddress(from); err != nil {
ss.send("501 Bad sender address syntax") ss.send("501 Bad sender address syntax")
ss.logWarn("Bad address as MAIL arg: %q, %s", from, err) ss.logWarn("Bad address as MAIL arg: %q, %s", from, err)
return return
@@ -314,7 +316,7 @@ func (ss *Session) mailHandler(cmd string, arg string) {
} }
// This trim is probably too forgiving // This trim is probably too forgiving
recip := strings.Trim(arg[3:], "<> ") recip := strings.Trim(arg[3:], "<> ")
if _, _, err := ParseEmailAddress(recip); err != nil { if _, _, err := stringutil.ParseEmailAddress(recip); err != nil {
ss.send("501 Bad recipient address syntax") ss.send("501 Bad recipient address syntax")
ss.logWarn("Bad address as RCPT arg: %q, %s", recip, err) ss.logWarn("Bad address as RCPT arg: %q, %s", recip, err)
return return
@@ -354,7 +356,7 @@ func (ss *Session) dataHandler() {
if ss.server.storeMessages { if ss.server.storeMessages {
for e := ss.recipients.Front(); e != nil; e = e.Next() { for e := ss.recipients.Front(); e != nil; e = e.Next() {
recip := e.Value.(string) recip := e.Value.(string)
local, domain, err := ParseEmailAddress(recip) local, domain, err := stringutil.ParseEmailAddress(recip)
if err != nil { if err != nil {
ss.logError("Failed to parse address for %q", recip) ss.logError("Failed to parse address for %q", recip)
ss.send(fmt.Sprintf("451 Failed to open mailbox for %v", recip)) ss.send(fmt.Sprintf("451 Failed to open mailbox for %v", recip))
@@ -510,20 +512,16 @@ func (ss *Session) send(msg string) {
// readByteLine reads a line of input into the provided buffer. Does // readByteLine reads a line of input into the provided buffer. Does
// not reset the Buffer - please do so prior to calling. // not reset the Buffer - please do so prior to calling.
func (ss *Session) readByteLine(buf *bytes.Buffer) error { func (ss *Session) readByteLine(buf io.Writer) error {
if err := ss.conn.SetReadDeadline(ss.nextDeadline()); err != nil { if err := ss.conn.SetReadDeadline(ss.nextDeadline()); err != nil {
return err return err
} }
for { line, err := ss.reader.ReadBytes('\n')
line, err := ss.reader.ReadBytes('\n') if err != nil {
if err != nil { return err
return err
}
if _, err = buf.Write(line); err != nil {
return err
}
return nil
} }
_, err = buf.Write(line)
return err
} }
// Reads a line of input // Reads a line of input
@@ -572,7 +570,7 @@ func (ss *Session) parseCmd(line string) (cmd string, arg string, ok bool) {
// The leading space is mandatory. // The leading space is mandatory.
func (ss *Session) parseArgs(arg string) (args map[string]string, ok bool) { func (ss *Session) parseArgs(arg string) (args map[string]string, ok bool) {
args = make(map[string]string) args = make(map[string]string)
re := regexp.MustCompile(" (\\w+)=(\\w+)") re := regexp.MustCompile(` (\w+)=(\w+)`)
pm := re.FindAllStringSubmatch(arg, -1) pm := re.FindAllStringSubmatch(arg, -1)
if pm == nil { if pm == nil {
ss.logWarn("Failed to parse arg string: %q") ss.logWarn("Failed to parse arg string: %q")

View File

@@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/msghub" "github.com/jhillyerd/inbucket/msghub"
) )
@@ -25,17 +26,13 @@ type scriptStep struct {
// Test commands in GREET state // Test commands in GREET state
func TestGreetState(t *testing.T) { func TestGreetState(t *testing.T) {
// Setup mock objects // Setup mock objects
mds := &MockDataStore{} mds := &datastore.MockDataStore{}
mb1 := &MockMailbox{}
mds.On("MailboxFor").Return(mb1, nil)
server, logbuf, teardown := setupSMTPServer(mds) server, logbuf, teardown := setupSMTPServer(mds)
defer teardown() defer teardown()
var script []scriptStep
// Test out some mangled HELOs // Test out some mangled HELOs
script = []scriptStep{ script := []scriptStep{
{"HELO", 501}, {"HELO", 501},
{"EHLO", 501}, {"EHLO", 501},
{"HELLO", 500}, {"HELLO", 500},
@@ -86,17 +83,13 @@ func TestGreetState(t *testing.T) {
// Test commands in READY state // Test commands in READY state
func TestReadyState(t *testing.T) { func TestReadyState(t *testing.T) {
// Setup mock objects // Setup mock objects
mds := &MockDataStore{} mds := &datastore.MockDataStore{}
mb1 := &MockMailbox{}
mds.On("MailboxFor").Return(mb1, nil)
server, logbuf, teardown := setupSMTPServer(mds) server, logbuf, teardown := setupSMTPServer(mds)
defer teardown() defer teardown()
var script []scriptStep
// Test out some mangled READY commands // Test out some mangled READY commands
script = []scriptStep{ script := []scriptStep{
{"HELO localhost", 250}, {"HELO localhost", 250},
{"FOOB", 500}, {"FOOB", 500},
{"HELO", 503}, {"HELO", 503},
@@ -151,10 +144,10 @@ func TestReadyState(t *testing.T) {
// Test commands in MAIL state // Test commands in MAIL state
func TestMailState(t *testing.T) { func TestMailState(t *testing.T) {
// Setup mock objects // Setup mock objects
mds := &MockDataStore{} mds := &datastore.MockDataStore{}
mb1 := &MockMailbox{} mb1 := &datastore.MockMailbox{}
msg1 := &MockMessage{} msg1 := &datastore.MockMessage{}
mds.On("MailboxFor").Return(mb1, nil) mds.On("MailboxFor", "u1").Return(mb1, nil)
mb1.On("NewMessage").Return(msg1, nil) mb1.On("NewMessage").Return(msg1, nil)
mb1.On("Name").Return("u1") mb1.On("Name").Return("u1")
msg1.On("ID").Return("") msg1.On("ID").Return("")
@@ -168,10 +161,8 @@ func TestMailState(t *testing.T) {
server, logbuf, teardown := setupSMTPServer(mds) server, logbuf, teardown := setupSMTPServer(mds)
defer teardown() defer teardown()
var script []scriptStep
// Test out some mangled READY commands // Test out some mangled READY commands
script = []scriptStep{ script := []scriptStep{
{"HELO localhost", 250}, {"HELO localhost", 250},
{"MAIL FROM:<john@gmail.com>", 250}, {"MAIL FROM:<john@gmail.com>", 250},
{"FOOB", 500}, {"FOOB", 500},
@@ -268,10 +259,10 @@ func TestMailState(t *testing.T) {
// Test commands in DATA state // Test commands in DATA state
func TestDataState(t *testing.T) { func TestDataState(t *testing.T) {
// Setup mock objects // Setup mock objects
mds := &MockDataStore{} mds := &datastore.MockDataStore{}
mb1 := &MockMailbox{} mb1 := &datastore.MockMailbox{}
msg1 := &MockMessage{} msg1 := &datastore.MockMessage{}
mds.On("MailboxFor").Return(mb1, nil) mds.On("MailboxFor", "u1").Return(mb1, nil)
mb1.On("NewMessage").Return(msg1, nil) mb1.On("NewMessage").Return(msg1, nil)
mb1.On("Name").Return("u1") mb1.On("Name").Return("u1")
msg1.On("ID").Return("") msg1.On("ID").Return("")
@@ -376,7 +367,7 @@ func (m *mockConn) SetDeadline(t time.Time) error { return nil }
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
func setupSMTPServer(ds DataStore) (s *Server, buf *bytes.Buffer, teardown func()) { func setupSMTPServer(ds datastore.DataStore) (s *Server, buf *bytes.Buffer, teardown func()) {
// Test Server Config // Test Server Config
cfg := config.SMTPConfig{ cfg := config.SMTPConfig{
IP4address: net.IPv4(127, 0, 0, 1), IP4address: net.IPv4(127, 0, 0, 1),

View File

@@ -11,10 +11,31 @@ import (
"time" "time"
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/msghub" "github.com/jhillyerd/inbucket/msghub"
) )
func init() {
m := expvar.NewMap("smtp")
m.Set("ConnectsTotal", expConnectsTotal)
m.Set("ConnectsHist", expConnectsHist)
m.Set("ConnectsCurrent", expConnectsCurrent)
m.Set("ReceivedTotal", expReceivedTotal)
m.Set("ReceivedHist", expReceivedHist)
m.Set("ErrorsTotal", expErrorsTotal)
m.Set("ErrorsHist", expErrorsHist)
m.Set("WarnsTotal", expWarnsTotal)
m.Set("WarnsHist", expWarnsHist)
log.AddTickerFunc(func() {
expReceivedHist.Set(log.PushMetric(deliveredHist, expReceivedTotal))
expConnectsHist.Set(log.PushMetric(connectsHist, expConnectsTotal))
expErrorsHist.Set(log.PushMetric(errorsHist, expErrorsTotal))
expWarnsHist.Set(log.PushMetric(warnsHist, expWarnsTotal))
})
}
// Server holds the configuration and state of our SMTP server // Server holds the configuration and state of our SMTP server
type Server struct { type Server struct {
// Configuration // Configuration
@@ -27,10 +48,10 @@ type Server struct {
storeMessages bool storeMessages bool
// Dependencies // Dependencies
dataStore DataStore // Mailbox/message store dataStore datastore.DataStore // Mailbox/message store
globalShutdown chan bool // Shuts down Inbucket globalShutdown chan bool // Shuts down Inbucket
msgHub *msghub.Hub // Pub/sub for message info msgHub *msghub.Hub // Pub/sub for message info
retentionScanner *RetentionScanner // Deletes expired messages retentionScanner *datastore.RetentionScanner // Deletes expired messages
// State // State
listener net.Listener // Incoming network connections listener net.Listener // Incoming network connections
@@ -62,7 +83,7 @@ var (
func NewServer( func NewServer(
cfg config.SMTPConfig, cfg config.SMTPConfig,
globalShutdown chan bool, globalShutdown chan bool,
ds DataStore, ds datastore.DataStore,
msgHub *msghub.Hub) *Server { msgHub *msghub.Hub) *Server {
return &Server{ return &Server{
host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port), host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port),
@@ -75,7 +96,7 @@ func NewServer(
globalShutdown: globalShutdown, globalShutdown: globalShutdown,
dataStore: ds, dataStore: ds,
msgHub: msgHub, msgHub: msgHub,
retentionScanner: NewRetentionScanner(ds, globalShutdown), retentionScanner: datastore.NewRetentionScanner(ds, globalShutdown),
waitgroup: new(sync.WaitGroup), waitgroup: new(sync.WaitGroup),
} }
} }
@@ -110,10 +131,8 @@ func (s *Server) Start(ctx context.Context) {
go s.serve(ctx) go s.serve(ctx)
// Wait for shutdown // Wait for shutdown
select { <-ctx.Done()
case <-ctx.Done(): log.Tracef("SMTP shutdown requested, connections will be drained")
log.Tracef("SMTP shutdown requested, connections will be drained")
}
// Closing the listener will cause the serve() go routine to exit // Closing the listener will cause the serve() go routine to exit
if err := s.listener.Close(); err != nil { if err := s.listener.Close(); err != nil {
@@ -165,7 +184,7 @@ func (s *Server) serve(ctx context.Context) {
func (s *Server) emergencyShutdown() { func (s *Server) emergencyShutdown() {
// Shutdown Inbucket // Shutdown Inbucket
select { select {
case _ = <-s.globalShutdown: case <-s.globalShutdown:
default: default:
close(s.globalShutdown) close(s.globalShutdown)
} }
@@ -178,44 +197,3 @@ func (s *Server) Drain() {
log.Tracef("SMTP connections have drained") log.Tracef("SMTP connections have drained")
s.retentionScanner.Join() s.retentionScanner.Join()
} }
// When the provided Ticker ticks, we update our metrics history
func metricsTicker(t *time.Ticker) {
ok := true
for ok {
_, ok = <-t.C
expReceivedHist.Set(pushMetric(deliveredHist, expReceivedTotal))
expConnectsHist.Set(pushMetric(connectsHist, expConnectsTotal))
expErrorsHist.Set(pushMetric(errorsHist, expErrorsTotal))
expWarnsHist.Set(pushMetric(warnsHist, expWarnsTotal))
expRetentionDeletesHist.Set(pushMetric(retentionDeletesHist, expRetentionDeletesTotal))
expRetainedHist.Set(pushMetric(retainedHist, expRetainedCurrent))
}
}
// pushMetric adds the metric to the end of the list and returns a comma separated string of the
// previous 61 entries. We return 61 instead of 60 (an hour) because the chart on the client
// tracks deltas between these values - there is nothing to compare the first value against.
func pushMetric(history *list.List, ev expvar.Var) string {
history.PushBack(ev.String())
if history.Len() > 61 {
history.Remove(history.Front())
}
return JoinStringList(history)
}
func init() {
m := expvar.NewMap("smtp")
m.Set("ConnectsTotal", expConnectsTotal)
m.Set("ConnectsHist", expConnectsHist)
m.Set("ConnectsCurrent", expConnectsCurrent)
m.Set("ReceivedTotal", expReceivedTotal)
m.Set("ReceivedHist", expReceivedHist)
m.Set("ErrorsTotal", expErrorsTotal)
m.Set("ErrorsHist", expErrorsHist)
m.Set("WarnsTotal", expWarnsTotal)
m.Set("WarnsHist", expWarnsHist)
t := time.NewTicker(time.Minute)
go metricsTicker(t)
}

View File

@@ -1,197 +0,0 @@
package smtpd
import (
"fmt"
"io"
"net/mail"
"testing"
"time"
"github.com/jhillyerd/enmime"
"github.com/stretchr/testify/mock"
)
func TestDoRetentionScan(t *testing.T) {
// Create mock objects
mds := &MockDataStore{}
mb1 := &MockMailbox{}
mb2 := &MockMailbox{}
mb3 := &MockMailbox{}
// Mockup some different aged messages (num is in hours)
new1 := mockMessage(0)
new2 := mockMessage(1)
new3 := mockMessage(2)
old1 := mockMessage(4)
old2 := mockMessage(12)
old3 := mockMessage(24)
// First it should ask for all mailboxes
mds.On("AllMailboxes").Return([]Mailbox{mb1, mb2, mb3}, nil)
// Then for all messages on each box
mb1.On("GetMessages").Return([]Message{new1, old1, old2}, nil)
mb2.On("GetMessages").Return([]Message{old3, new2}, nil)
mb3.On("GetMessages").Return([]Message{new3}, nil)
// Test 4 hour retention
rs := &RetentionScanner{
ds: mds,
retentionPeriod: 4*time.Hour - time.Minute,
retentionSleep: 0,
}
if err := rs.doScan(); err != nil {
t.Error(err)
}
// Check our assertions
mds.AssertExpectations(t)
mb1.AssertExpectations(t)
mb2.AssertExpectations(t)
mb3.AssertExpectations(t)
// Delete should not have been called on new messages
new1.AssertNotCalled(t, "Delete")
new2.AssertNotCalled(t, "Delete")
new3.AssertNotCalled(t, "Delete")
// Delete should have been called once on old messages
old1.AssertNumberOfCalls(t, "Delete", 1)
old2.AssertNumberOfCalls(t, "Delete", 1)
old3.AssertNumberOfCalls(t, "Delete", 1)
}
// Make a MockMessage of a specific age
func mockMessage(ageHours int) *MockMessage {
msg := &MockMessage{}
msg.On("ID").Return(fmt.Sprintf("MSG[age=%vh]", ageHours))
msg.On("Date").Return(time.Now().Add(time.Duration(ageHours*-1) * time.Hour))
msg.On("Delete").Return(nil)
return msg
}
// Mock DataStore object
type MockDataStore struct {
mock.Mock
}
func (m *MockDataStore) MailboxFor(name string) (Mailbox, error) {
args := m.Called()
return args.Get(0).(Mailbox), args.Error(1)
}
func (m *MockDataStore) AllMailboxes() ([]Mailbox, error) {
args := m.Called()
return args.Get(0).([]Mailbox), args.Error(1)
}
// Mock Mailbox object
type MockMailbox struct {
mock.Mock
}
func (m *MockMailbox) GetMessages() ([]Message, error) {
args := m.Called()
return args.Get(0).([]Message), args.Error(1)
}
func (m *MockMailbox) GetMessage(id string) (Message, error) {
args := m.Called(id)
return args.Get(0).(Message), args.Error(1)
}
func (m *MockMailbox) Purge() error {
args := m.Called()
return args.Error(0)
}
func (m *MockMailbox) NewMessage() (Message, error) {
args := m.Called()
return args.Get(0).(Message), args.Error(1)
}
func (m *MockMailbox) Name() string {
args := m.Called()
return args.String(0)
}
func (m *MockMailbox) String() string {
args := m.Called()
return args.String(0)
}
// Mock Message object
type MockMessage struct {
mock.Mock
}
func (m *MockMessage) ID() string {
args := m.Called()
return args.String(0)
}
func (m *MockMessage) From() string {
args := m.Called()
return args.String(0)
}
func (m *MockMessage) To() []string {
args := m.Called()
return args.Get(0).([]string)
}
func (m *MockMessage) Date() time.Time {
args := m.Called()
return args.Get(0).(time.Time)
}
func (m *MockMessage) Subject() string {
args := m.Called()
return args.String(0)
}
func (m *MockMessage) ReadHeader() (msg *mail.Message, err error) {
args := m.Called()
return args.Get(0).(*mail.Message), args.Error(1)
}
func (m *MockMessage) ReadBody() (body *enmime.Envelope, err error) {
args := m.Called()
return args.Get(0).(*enmime.Envelope), args.Error(1)
}
func (m *MockMessage) ReadRaw() (raw *string, err error) {
args := m.Called()
return args.Get(0).(*string), args.Error(1)
}
func (m *MockMessage) RawReader() (reader io.ReadCloser, err error) {
args := m.Called()
return args.Get(0).(io.ReadCloser), args.Error(1)
}
func (m *MockMessage) Size() int64 {
args := m.Called()
return int64(args.Int(0))
}
func (m *MockMessage) Append(data []byte) error {
// []byte arg seems to mess up testify/mock
return nil
}
func (m *MockMessage) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockMessage) Delete() error {
args := m.Called()
return args.Error(0)
}
func (m *MockMessage) String() string {
args := m.Called()
return args.String(0)
}

View File

@@ -1,8 +1,7 @@
package smtpd package stringutil
import ( import (
"bytes" "bytes"
"container/list"
"crypto/sha1" "crypto/sha1"
"fmt" "fmt"
"io" "io"
@@ -42,7 +41,7 @@ func ParseMailboxName(localPart string) (result string, err error) {
return result, nil return result, nil
} }
// HashMailboxName accepts a mailbox name and hashes it. Inbucket uses this as // HashMailboxName accepts a mailbox name and hashes it. filestore uses this as
// the directory to house the mailbox // the directory to house the mailbox
func HashMailboxName(mailbox string) string { func HashMailboxName(mailbox string) string {
h := sha1.New() h := sha1.New()
@@ -53,18 +52,6 @@ func HashMailboxName(mailbox string) string {
return fmt.Sprintf("%x", h.Sum(nil)) return fmt.Sprintf("%x", h.Sum(nil))
} }
// JoinStringList joins a List containing strings by commas
func JoinStringList(listOfStrings *list.List) string {
if listOfStrings.Len() == 0 {
return ""
}
s := make([]string, 0, listOfStrings.Len())
for e := listOfStrings.Front(); e != nil; e = e.Next() {
s = append(s, e.Value.(string))
}
return strings.Join(s, ",")
}
// ValidateDomainPart returns true if the domain part complies to RFC3696, RFC1035 // ValidateDomainPart returns true if the domain part complies to RFC3696, RFC1035
func ValidateDomainPart(domain string) bool { func ValidateDomainPart(domain string) bool {
if len(domain) == 0 { if len(domain) == 0 {
@@ -143,15 +130,24 @@ LOOP:
switch { switch {
case ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'): case ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'):
// Letters are OK // Letters are OK
_ = buf.WriteByte(c) err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false inCharQuote = false
case '0' <= c && c <= '9': case '0' <= c && c <= '9':
// Numbers are OK // Numbers are OK
_ = buf.WriteByte(c) err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false inCharQuote = false
case bytes.IndexByte([]byte("!#$%&'*+-/=?^_`{|}~"), c) >= 0: case bytes.IndexByte([]byte("!#$%&'*+-/=?^_`{|}~"), c) >= 0:
// These specials can be used unquoted // These specials can be used unquoted
_ = buf.WriteByte(c) err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false inCharQuote = false
case c == '.': case c == '.':
// A single period is OK // A single period is OK
@@ -159,13 +155,19 @@ LOOP:
// Sequence of periods is not permitted // Sequence of periods is not permitted
return "", "", fmt.Errorf("Sequence of periods is not permitted") return "", "", fmt.Errorf("Sequence of periods is not permitted")
} }
_ = buf.WriteByte(c) err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false inCharQuote = false
case c == '\\': case c == '\\':
inCharQuote = true inCharQuote = true
case c == '"': case c == '"':
if inCharQuote { if inCharQuote {
_ = buf.WriteByte(c) err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false inCharQuote = false
} else if inStringQuote { } else if inStringQuote {
inStringQuote = false inStringQuote = false
@@ -178,7 +180,10 @@ LOOP:
} }
case c == '@': case c == '@':
if inCharQuote || inStringQuote { if inCharQuote || inStringQuote {
_ = buf.WriteByte(c) err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false inCharQuote = false
} else { } else {
// End of local-part // End of local-part
@@ -195,7 +200,10 @@ LOOP:
return "", "", fmt.Errorf("Characters outside of US-ASCII range not permitted") return "", "", fmt.Errorf("Characters outside of US-ASCII range not permitted")
default: default:
if inCharQuote || inStringQuote { if inCharQuote || inStringQuote {
_ = buf.WriteByte(c) err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false inCharQuote = false
} else { } else {
return "", "", fmt.Errorf("Character %q must be quoted", c) return "", "", fmt.Errorf("Character %q must be quoted", c)

View File

@@ -1,4 +1,4 @@
package smtpd package stringutil
import ( import (
"strings" "strings"

View File

@@ -7,9 +7,10 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/jhillyerd/inbucket/datastore"
"github.com/jhillyerd/inbucket/httpd" "github.com/jhillyerd/inbucket/httpd"
"github.com/jhillyerd/inbucket/log" "github.com/jhillyerd/inbucket/log"
"github.com/jhillyerd/inbucket/smtpd" "github.com/jhillyerd/inbucket/stringutil"
) )
// MailboxIndex renders the index page for a particular mailbox // MailboxIndex renders the index page for a particular mailbox
@@ -23,7 +24,7 @@ func MailboxIndex(w http.ResponseWriter, req *http.Request, ctx *httpd.Context)
http.Redirect(w, req, httpd.Reverse("RootIndex"), http.StatusSeeOther) http.Redirect(w, req, httpd.Reverse("RootIndex"), http.StatusSeeOther)
return nil return nil
} }
name, err = smtpd.ParseMailboxName(name) name, err = stringutil.ParseMailboxName(name)
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -50,7 +51,7 @@ func MailboxIndex(w http.ResponseWriter, req *http.Request, ctx *httpd.Context)
func MailboxLink(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxLink(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -66,7 +67,7 @@ func MailboxLink(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (
// MailboxList renders a list of messages in a mailbox. Renders a partial // MailboxList renders a list of messages in a mailbox. Renders a partial
func MailboxList(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxList(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -93,7 +94,7 @@ func MailboxList(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (
func MailboxShow(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxShow(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -103,7 +104,7 @@ func MailboxShow(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err) return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
} }
msg, err := mb.GetMessage(id) msg, err := mb.GetMessage(id)
if err == smtpd.ErrNotExist { if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
@@ -133,7 +134,7 @@ func MailboxShow(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (
func MailboxHTML(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxHTML(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -143,7 +144,7 @@ func MailboxHTML(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err) return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
} }
message, err := mb.GetMessage(id) message, err := mb.GetMessage(id)
if err == smtpd.ErrNotExist { if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
@@ -170,7 +171,7 @@ func MailboxHTML(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (
func MailboxSource(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxSource(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
@@ -180,7 +181,7 @@ func MailboxSource(w http.ResponseWriter, req *http.Request, ctx *httpd.Context)
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err) return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
} }
message, err := mb.GetMessage(id) message, err := mb.GetMessage(id)
if err == smtpd.ErrNotExist { if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
@@ -205,7 +206,7 @@ func MailboxSource(w http.ResponseWriter, req *http.Request, ctx *httpd.Context)
func MailboxDownloadAttach(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxDownloadAttach(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -226,7 +227,7 @@ func MailboxDownloadAttach(w http.ResponseWriter, req *http.Request, ctx *httpd.
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err) return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
} }
message, err := mb.GetMessage(id) message, err := mb.GetMessage(id)
if err == smtpd.ErrNotExist { if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
@@ -257,7 +258,7 @@ func MailboxDownloadAttach(w http.ResponseWriter, req *http.Request, ctx *httpd.
// MailboxViewAttach sends the attachment to the client for online viewing // MailboxViewAttach sends the attachment to the client for online viewing
func MailboxViewAttach(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) { func MailboxViewAttach(w http.ResponseWriter, req *http.Request, ctx *httpd.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -279,7 +280,7 @@ func MailboxViewAttach(w http.ResponseWriter, req *http.Request, ctx *httpd.Cont
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err) return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
} }
message, err := mb.GetMessage(id) message, err := mb.GetMessage(id)
if err == smtpd.ErrNotExist { if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }

View File

@@ -8,7 +8,7 @@ import (
"github.com/jhillyerd/inbucket/config" "github.com/jhillyerd/inbucket/config"
"github.com/jhillyerd/inbucket/httpd" "github.com/jhillyerd/inbucket/httpd"
"github.com/jhillyerd/inbucket/smtpd" "github.com/jhillyerd/inbucket/stringutil"
) )
// RootIndex serves the Inbucket landing page // RootIndex serves the Inbucket landing page
@@ -58,7 +58,7 @@ func RootMonitorMailbox(w http.ResponseWriter, req *http.Request, ctx *httpd.Con
http.Redirect(w, req, httpd.Reverse("RootIndex"), http.StatusSeeOther) http.Redirect(w, req, httpd.Reverse("RootIndex"), http.StatusSeeOther)
return nil return nil
} }
name, err := smtpd.ParseMailboxName(ctx.Vars["name"]) name, err := stringutil.ParseMailboxName(ctx.Vars["name"])
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)