1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-18 18:17:03 +00:00

Merge giant storage/service layer refactor #69 #81

This commit is contained in:
James Hillyerd
2018-03-18 15:24:21 -07:00
40 changed files with 2185 additions and 2185 deletions

View File

@@ -6,6 +6,7 @@ env:
before_script: before_script:
- go get github.com/golang/lint/golint - go get github.com/golang/lint/golint
- make deps
go: go:
- "1.10" - "1.10"

View File

@@ -4,6 +4,12 @@ Change Log
All notable changes to this project will be documented in this file. All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](http://semver.org/). This project adheres to [Semantic Versioning](http://semver.org/).
## [Unreleased]
### Changed
- Massive refactor of back-end code. Inbucket should now be both easier and
more enjoyable to work on.
## [v1.3.1] - 2018-03-10 ## [v1.3.1] - 2018-03-10
### Fixed ### Fixed

View File

@@ -3,7 +3,7 @@ SHELL = /bin/sh
SRC := $(shell find . -type f -name '*.go' -not -path "./vendor/*") SRC := $(shell find . -type f -name '*.go' -not -path "./vendor/*")
PKGS := $(shell go list ./... | grep -v /vendor/) PKGS := $(shell go list ./... | grep -v /vendor/)
.PHONY: all build clean fmt lint simplify test .PHONY: all build clean fmt lint reflex simplify test
commands = client inbucket commands = client inbucket
@@ -19,9 +19,9 @@ clean:
deps: deps:
go get -t ./... go get -t ./...
build: deps $(commands) build: $(commands)
test: deps test:
go test -race ./... go test -race ./...
fmt: fmt:
@@ -34,3 +34,6 @@ lint:
@test -z "$(shell gofmt -l . | tee /dev/stderr)" || echo "[WARN] Fix formatting issues with 'make fmt'" @test -z "$(shell gofmt -l . | tee /dev/stderr)" || echo "[WARN] Fix formatting issues with 'make fmt'"
@golint -set_exit_status $(PKGS) @golint -set_exit_status $(PKGS)
@go vet $(PKGS) @go vet $(PKGS)
reflex:
reflex -r '\.go$$' -- sh -c 'echo; date; echo; go test ./... && echo ALL PASS'

View File

@@ -14,11 +14,14 @@ import (
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/log" "github.com/jhillyerd/inbucket/pkg/log"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/msghub" "github.com/jhillyerd/inbucket/pkg/msghub"
"github.com/jhillyerd/inbucket/pkg/policy"
"github.com/jhillyerd/inbucket/pkg/rest" "github.com/jhillyerd/inbucket/pkg/rest"
"github.com/jhillyerd/inbucket/pkg/server/pop3" "github.com/jhillyerd/inbucket/pkg/server/pop3"
"github.com/jhillyerd/inbucket/pkg/server/smtp" "github.com/jhillyerd/inbucket/pkg/server/smtp"
"github.com/jhillyerd/inbucket/pkg/server/web" "github.com/jhillyerd/inbucket/pkg/server/web"
"github.com/jhillyerd/inbucket/pkg/storage"
"github.com/jhillyerd/inbucket/pkg/storage/file" "github.com/jhillyerd/inbucket/pkg/storage/file"
"github.com/jhillyerd/inbucket/pkg/webui" "github.com/jhillyerd/inbucket/pkg/webui"
) )
@@ -112,27 +115,27 @@ func main() {
} }
} }
// Create message hub // Configure internal services.
msgHub := msghub.New(rootCtx, config.GetWebConfig().MonitorHistory) msgHub := msghub.New(rootCtx, config.GetWebConfig().MonitorHistory)
dscfg := config.GetDataStoreConfig()
// Grab our datastore store := file.New(dscfg)
ds := filestore.DefaultFileDataStore() apolicy := &policy.Addressing{Config: config.GetSMTPConfig()}
mmanager := &message.StoreManager{Store: store, Hub: msgHub}
// Start HTTP server // Start Retention scanner.
web.Initialize(config.GetWebConfig(), shutdownChan, ds, msgHub) retentionScanner := storage.NewRetentionScanner(dscfg, store, shutdownChan)
retentionScanner.Start()
// Start HTTP server.
web.Initialize(config.GetWebConfig(), shutdownChan, mmanager, msgHub)
webui.SetupRoutes(web.Router) webui.SetupRoutes(web.Router)
rest.SetupRoutes(web.Router) rest.SetupRoutes(web.Router)
go web.Start(rootCtx) go web.Start(rootCtx)
// Start POP3 server.
// Start POP3 server pop3Server = pop3.New(config.GetPOP3Config(), shutdownChan, store)
pop3Server = pop3.New(config.GetPOP3Config(), shutdownChan, ds)
go pop3Server.Start(rootCtx) go pop3Server.Start(rootCtx)
// Start SMTP server.
// Startup SMTP server smtpServer = smtp.NewServer(config.GetSMTPConfig(), shutdownChan, mmanager, apolicy)
smtpServer = smtp.NewServer(config.GetSMTPConfig(), shutdownChan, ds, msgHub)
go smtpServer.Start(rootCtx) go smtpServer.Start(rootCtx)
// Loop forever waiting for signals or shutdown channel.
// Loop forever waiting for signals or shutdown channel
signalLoop: signalLoop:
for { for {
select { select {
@@ -160,6 +163,7 @@ signalLoop:
go timedExit() go timedExit()
smtpServer.Drain() smtpServer.Drain()
pop3Server.Drain() pop3Server.Drain()
retentionScanner.Join()
removePIDFile() removePIDFile()
} }

View File

@@ -5,6 +5,7 @@ import (
golog "log" golog "log"
"os" "os"
"strings" "strings"
"sync"
) )
// Level is used to indicate the severity of a log entry // Level is used to indicate the severity of a log entry
@@ -30,12 +31,16 @@ var (
// logf is the file we send log output to, will be nil for stderr or stdout // logf is the file we send log output to, will be nil for stderr or stdout
logf *os.File logf *os.File
mu sync.RWMutex
) )
// Initialize logging. If logfile is equal to "stderr" or "stdout", then // Initialize logging. If logfile is equal to "stderr" or "stdout", then
// we will log to that output stream. Otherwise the specificed file will // we will log to that output stream. Otherwise the specificed file will
// opened for writing, and all log data will be placed in it. // opened for writing, and all log data will be placed in it.
func Initialize(logfile string) error { func Initialize(logfile string) error {
mu.Lock()
defer mu.Unlock()
if logfile != "stderr" { if logfile != "stderr" {
// stderr is the go logging default // stderr is the go logging default
if logfile == "stdout" { if logfile == "stdout" {
@@ -55,6 +60,8 @@ func Initialize(logfile string) error {
// SetLogLevel sets MaxLevel based on the provided string // SetLogLevel sets MaxLevel based on the provided string
func SetLogLevel(level string) (ok bool) { func SetLogLevel(level string) (ok bool) {
mu.Lock()
defer mu.Unlock()
switch strings.ToUpper(level) { switch strings.ToUpper(level) {
case "ERROR": case "ERROR":
MaxLevel = ERROR MaxLevel = ERROR
@@ -73,12 +80,16 @@ func SetLogLevel(level string) (ok bool) {
// Errorf logs a message to the 'standard' Logger (always), accepts format strings // Errorf logs a message to the 'standard' Logger (always), accepts format strings
func Errorf(msg string, args ...interface{}) { func Errorf(msg string, args ...interface{}) {
mu.RLock()
defer mu.RUnlock()
msg = "[ERROR] " + msg msg = "[ERROR] " + msg
golog.Printf(msg, args...) golog.Printf(msg, args...)
} }
// Warnf logs a message to the 'standard' Logger if MaxLevel is >= WARN, accepts format strings // Warnf logs a message to the 'standard' Logger if MaxLevel is >= WARN, accepts format strings
func Warnf(msg string, args ...interface{}) { func Warnf(msg string, args ...interface{}) {
mu.RLock()
defer mu.RUnlock()
if MaxLevel >= WARN { if MaxLevel >= WARN {
msg = "[WARN ] " + msg msg = "[WARN ] " + msg
golog.Printf(msg, args...) golog.Printf(msg, args...)
@@ -87,6 +98,8 @@ func Warnf(msg string, args ...interface{}) {
// Infof logs a message to the 'standard' Logger if MaxLevel is >= INFO, accepts format strings // Infof logs a message to the 'standard' Logger if MaxLevel is >= INFO, accepts format strings
func Infof(msg string, args ...interface{}) { func Infof(msg string, args ...interface{}) {
mu.RLock()
defer mu.RUnlock()
if MaxLevel >= INFO { if MaxLevel >= INFO {
msg = "[INFO ] " + msg msg = "[INFO ] " + msg
golog.Printf(msg, args...) golog.Printf(msg, args...)
@@ -95,6 +108,8 @@ func Infof(msg string, args ...interface{}) {
// Tracef logs a message to the 'standard' Logger if MaxLevel is >= TRACE, accepts format strings // Tracef logs a message to the 'standard' Logger if MaxLevel is >= TRACE, accepts format strings
func Tracef(msg string, args ...interface{}) { func Tracef(msg string, args ...interface{}) {
mu.RLock()
defer mu.RUnlock()
if MaxLevel >= TRACE { if MaxLevel >= TRACE {
msg = "[TRACE] " + msg msg = "[TRACE] " + msg
golog.Printf(msg, args...) golog.Printf(msg, args...)
@@ -105,6 +120,8 @@ func Tracef(msg string, args ...interface{}) {
// log rotation system the opportunity to move the existing log file out of the // log rotation system the opportunity to move the existing log file out of the
// way and have Inbucket create a new one. // way and have Inbucket create a new one.
func Rotate() { func Rotate() {
mu.Lock()
defer mu.Unlock()
// Rotate logs if configured // Rotate logs if configured
if logf != nil { if logf != nil {
closeLogFile() closeLogFile()
@@ -117,6 +134,8 @@ func Rotate() {
// Close the log file if we have one open // Close the log file if we have one open
func Close() { func Close() {
mu.Lock()
defer mu.Unlock()
if logf != nil { if logf != nil {
closeLogFile() closeLogFile()
} }

162
pkg/message/manager.go Normal file
View File

@@ -0,0 +1,162 @@
package message
import (
"bytes"
"io"
"net/mail"
"strings"
"time"
"github.com/jhillyerd/enmime"
"github.com/jhillyerd/inbucket/pkg/msghub"
"github.com/jhillyerd/inbucket/pkg/policy"
"github.com/jhillyerd/inbucket/pkg/storage"
"github.com/jhillyerd/inbucket/pkg/stringutil"
)
// Manager is the interface controllers use to interact with messages.
type Manager interface {
Deliver(
to *policy.Recipient,
from string,
recipients []*policy.Recipient,
prefix string,
content []byte,
) (id string, err error)
GetMetadata(mailbox string) ([]*Metadata, error)
GetMessage(mailbox, id string) (*Message, error)
PurgeMessages(mailbox string) error
RemoveMessage(mailbox, id string) error
SourceReader(mailbox, id string) (io.ReadCloser, error)
MailboxForAddress(address string) (string, error)
}
// StoreManager is a message Manager backed by the storage.Store.
type StoreManager struct {
Store storage.Store
Hub *msghub.Hub
}
// Deliver submits a new message to the store.
func (s *StoreManager) Deliver(
to *policy.Recipient,
from string,
recipients []*policy.Recipient,
prefix string,
source []byte,
) (string, error) {
// TODO enmime is too heavy for this step, only need header.
// Go's header parsing isn't good enough, so this is blocked on enmime issue #64.
env, err := enmime.ReadEnvelope(bytes.NewReader(source))
if err != nil {
return "", err
}
fromaddr, err := env.AddressList("From")
if err != nil || len(fromaddr) == 0 {
fromaddr = []*mail.Address{{Address: from}}
}
toaddr, err := env.AddressList("To")
if err != nil {
toaddr = make([]*mail.Address, len(recipients))
for i, torecip := range recipients {
toaddr[i] = &torecip.Address
}
}
delivery := &Delivery{
Meta: Metadata{
Mailbox: to.Mailbox,
From: fromaddr[0],
To: toaddr,
Date: time.Now(),
Subject: env.GetHeader("Subject"),
},
Reader: io.MultiReader(strings.NewReader(prefix), bytes.NewReader(source)),
}
id, err := s.Store.AddMessage(delivery)
if err != nil {
return "", err
}
if s.Hub != nil {
// Broadcast message information.
broadcast := msghub.Message{
Mailbox: to.Mailbox,
ID: id,
From: delivery.From().String(),
To: stringutil.StringAddressList(delivery.To()),
Subject: delivery.Subject(),
Date: delivery.Date(),
Size: delivery.Size(),
}
s.Hub.Dispatch(broadcast)
}
return id, nil
}
// GetMetadata returns a slice of metadata for the specified mailbox.
func (s *StoreManager) GetMetadata(mailbox string) ([]*Metadata, error) {
messages, err := s.Store.GetMessages(mailbox)
if err != nil {
return nil, err
}
metas := make([]*Metadata, len(messages))
for i, sm := range messages {
metas[i] = makeMetadata(sm)
}
return metas, nil
}
// GetMessage returns the specified message.
func (s *StoreManager) GetMessage(mailbox, id string) (*Message, error) {
sm, err := s.Store.GetMessage(mailbox, id)
if err != nil {
return nil, err
}
r, err := sm.Source()
if err != nil {
return nil, err
}
env, err := enmime.ReadEnvelope(r)
if err != nil {
return nil, err
}
_ = r.Close()
header := makeMetadata(sm)
return &Message{Metadata: *header, Envelope: env}, nil
}
// PurgeMessages removes all messages from the specified mailbox.
func (s *StoreManager) PurgeMessages(mailbox string) error {
return s.Store.PurgeMessages(mailbox)
}
// RemoveMessage deletes the specified message.
func (s *StoreManager) RemoveMessage(mailbox, id string) error {
return s.Store.RemoveMessage(mailbox, id)
}
// SourceReader allows the stored message source to be read.
func (s *StoreManager) SourceReader(mailbox, id string) (io.ReadCloser, error) {
sm, err := s.Store.GetMessage(mailbox, id)
if err != nil {
return nil, err
}
return sm.Source()
}
// MailboxForAddress parses an email address to return the canonical mailbox name.
func (s *StoreManager) MailboxForAddress(mailbox string) (string, error) {
return policy.ParseMailboxName(mailbox)
}
// makeMetadata populates Metadata from a storage.Message.
func makeMetadata(m storage.Message) *Metadata {
return &Metadata{
Mailbox: m.Mailbox(),
ID: m.ID(),
From: m.From(),
To: m.To(),
Date: m.Date(),
Subject: m.Subject(),
Size: m.Size(),
}
}

77
pkg/message/message.go Normal file
View File

@@ -0,0 +1,77 @@
// Package message contains message handling logic.
package message
import (
"io"
"io/ioutil"
"net/mail"
"time"
"github.com/jhillyerd/enmime"
"github.com/jhillyerd/inbucket/pkg/storage"
)
// Metadata holds information about a message, but not the content.
type Metadata struct {
Mailbox string
ID string
From *mail.Address
To []*mail.Address
Date time.Time
Subject string
Size int64
}
// Message holds both the metadata and content of a message.
type Message struct {
Metadata
Envelope *enmime.Envelope
}
// Delivery is used to add a message to storage.
type Delivery struct {
Meta Metadata
Reader io.Reader
}
var _ storage.Message = &Delivery{}
// Mailbox getter.
func (d *Delivery) Mailbox() string {
return d.Meta.Mailbox
}
// ID getter.
func (d *Delivery) ID() string {
return d.Meta.ID
}
// From getter.
func (d *Delivery) From() *mail.Address {
return d.Meta.From
}
// To getter.
func (d *Delivery) To() []*mail.Address {
return d.Meta.To
}
// Date getter.
func (d *Delivery) Date() time.Time {
return d.Meta.Date
}
// Subject getter.
func (d *Delivery) Subject() string {
return d.Meta.Subject
}
// Size getter.
func (d *Delivery) Size() int64 {
return d.Meta.Size
}
// Source contains the raw content of the message.
func (d *Delivery) Source() (io.ReadCloser, error) {
return ioutil.NopCloser(d.Reader), nil
}

244
pkg/policy/address.go Normal file
View File

@@ -0,0 +1,244 @@
package policy
import (
"bytes"
"fmt"
"net/mail"
"strings"
"github.com/jhillyerd/inbucket/pkg/config"
)
// Addressing handles email address policy.
type Addressing struct {
Config config.SMTPConfig
}
// NewRecipient parses an address into a Recipient.
func (a *Addressing) NewRecipient(address string) (*Recipient, error) {
local, domain, err := ParseEmailAddress(address)
if err != nil {
return nil, err
}
mailbox, err := ParseMailboxName(local)
if err != nil {
return nil, err
}
ar, err := mail.ParseAddress(address)
if err != nil {
return nil, err
}
return &Recipient{
Address: *ar,
apolicy: a,
LocalPart: local,
Domain: domain,
Mailbox: mailbox,
}, nil
}
// ShouldStoreDomain indicates if Inbucket stores email destined for the specified domain.
func (a *Addressing) ShouldStoreDomain(domain string) bool {
if a.Config.StoreMessages {
return strings.ToLower(domain) != strings.ToLower(a.Config.DomainNoStore)
}
return false
}
// ParseMailboxName takes a localPart string (ex: "user+ext" without "@domain")
// and returns just the mailbox name (ex: "user"). Returns an error if
// localPart contains invalid characters; it won't accept any that must be
// quoted according to RFC3696.
func ParseMailboxName(localPart string) (result string, err error) {
if localPart == "" {
return "", fmt.Errorf("Mailbox name cannot be empty")
}
result = strings.ToLower(localPart)
invalid := make([]byte, 0, 10)
for i := 0; i < len(result); i++ {
c := result[i]
switch {
case 'a' <= c && c <= 'z':
case '0' <= c && c <= '9':
case bytes.IndexByte([]byte("!#$%&'*+-=/?^_`.{|}~"), c) >= 0:
default:
invalid = append(invalid, c)
}
}
if len(invalid) > 0 {
return "", fmt.Errorf("Mailbox name contained invalid character(s): %q", invalid)
}
if idx := strings.Index(result, "+"); idx > -1 {
result = result[0:idx]
}
return result, nil
}
// ParseEmailAddress unescapes an email address, and splits the local part from the domain part.
// An error is returned if the local or domain parts fail validation following the guidelines
// in RFC3696.
func ParseEmailAddress(address string) (local string, domain string, err error) {
if address == "" {
return "", "", fmt.Errorf("Empty address")
}
if len(address) > 320 {
return "", "", fmt.Errorf("Address exceeds 320 characters")
}
if address[0] == '@' {
return "", "", fmt.Errorf("Address cannot start with @ symbol")
}
if address[0] == '.' {
return "", "", fmt.Errorf("Address cannot start with a period")
}
// Loop over address parsing out local part.
buf := new(bytes.Buffer)
prev := byte('.')
inCharQuote := false
inStringQuote := false
LOOP:
for i := 0; i < len(address); i++ {
c := address[i]
switch {
case ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'):
// Letters are OK.
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
case '0' <= c && c <= '9':
// Numbers are OK.
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
case bytes.IndexByte([]byte("!#$%&'*+-/=?^_`{|}~"), c) >= 0:
// These specials can be used unquoted.
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
case c == '.':
// A single period is OK.
if prev == '.' {
// Sequence of periods is not permitted.
return "", "", fmt.Errorf("Sequence of periods is not permitted")
}
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
case c == '\\':
inCharQuote = true
case c == '"':
if inCharQuote {
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
} else if inStringQuote {
inStringQuote = false
} else {
if i == 0 {
inStringQuote = true
} else {
return "", "", fmt.Errorf("Quoted string can only begin at start of address")
}
}
case c == '@':
if inCharQuote || inStringQuote {
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
} else {
// End of local-part.
if i > 128 {
return "", "", fmt.Errorf("Local part must not exceed 128 characters")
}
if prev == '.' {
return "", "", fmt.Errorf("Local part cannot end with a period")
}
domain = address[i+1:]
break LOOP
}
case c > 127:
return "", "", fmt.Errorf("Characters outside of US-ASCII range not permitted")
default:
if inCharQuote || inStringQuote {
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
} else {
return "", "", fmt.Errorf("Character %q must be quoted", c)
}
}
prev = c
}
if inCharQuote {
return "", "", fmt.Errorf("Cannot end address with unterminated quoted-pair")
}
if inStringQuote {
return "", "", fmt.Errorf("Cannot end address with unterminated string quote")
}
if !ValidateDomainPart(domain) {
return "", "", fmt.Errorf("Domain part validation failed")
}
return buf.String(), domain, nil
}
// ValidateDomainPart returns true if the domain part complies to RFC3696, RFC1035. Used by
// ParseEmailAddress().
func ValidateDomainPart(domain string) bool {
if len(domain) == 0 {
return false
}
if len(domain) > 255 {
return false
}
if domain[len(domain)-1] != '.' {
domain += "."
}
prev := '.'
labelLen := 0
hasAlphaNum := false
for _, c := range domain {
switch {
case ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
('0' <= c && c <= '9') || c == '_':
// Must contain some of these to be a valid label.
hasAlphaNum = true
labelLen++
case c == '-':
if prev == '.' {
// Cannot lead with hyphen.
return false
}
case c == '.':
if prev == '.' || prev == '-' {
// Cannot end with hyphen or double-dot.
return false
}
if labelLen > 63 {
return false
}
if !hasAlphaNum {
return false
}
labelLen = 0
hasAlphaNum = false
default:
// Unknown character.
return false
}
prev = c
}
return true
}

191
pkg/policy/address_test.go Normal file
View File

@@ -0,0 +1,191 @@
package policy_test
import (
"strings"
"testing"
"github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/policy"
)
func TestShouldStoreDomain(t *testing.T) {
// Test with storage enabled.
ap := &policy.Addressing{
Config: config.SMTPConfig{
DomainNoStore: "Foo.Com",
StoreMessages: true,
},
}
testCases := []struct {
domain string
want bool
}{
{domain: "bar.com", want: true},
{domain: "foo.com", want: false},
{domain: "FOO.com", want: false},
{domain: "bar.foo.com", want: true},
}
for _, tc := range testCases {
t.Run(tc.domain, func(t *testing.T) {
got := ap.ShouldStoreDomain(tc.domain)
if got != tc.want {
t.Errorf("Got %v for %q, want: %v", got, tc.domain, tc.want)
}
})
}
// Test with storage disabled.
ap = &policy.Addressing{
Config: config.SMTPConfig{
StoreMessages: false,
},
}
testCases = []struct {
domain string
want bool
}{
{domain: "bar.com", want: false},
{domain: "foo.com", want: false},
{domain: "FOO.com", want: false},
{domain: "bar.foo.com", want: false},
}
for _, tc := range testCases {
t.Run(tc.domain, func(t *testing.T) {
got := ap.ShouldStoreDomain(tc.domain)
if got != tc.want {
t.Errorf("Got %v for %q, want: %v", got, tc.domain, tc.want)
}
})
}
}
func TestParseMailboxName(t *testing.T) {
var validTable = []struct {
input string
expect string
}{
{"mailbox", "mailbox"},
{"user123", "user123"},
{"MailBOX", "mailbox"},
{"First.Last", "first.last"},
{"user+label", "user"},
{"chars!#$%", "chars!#$%"},
{"chars&'*-", "chars&'*-"},
{"chars=/?^", "chars=/?^"},
{"chars_`.{", "chars_`.{"},
{"chars|}~", "chars|}~"},
}
for _, tt := range validTable {
if result, err := policy.ParseMailboxName(tt.input); err != nil {
t.Errorf("Error while parsing %q: %v", tt.input, err)
} else {
if result != tt.expect {
t.Errorf("Parsing %q, expected %q, got %q", tt.input, tt.expect, result)
}
}
}
var invalidTable = []struct {
input, msg string
}{
{"", "Empty mailbox name is not permitted"},
{"user@host", "@ symbol not permitted"},
{"first last", "Space not permitted"},
{"first\"last", "Double quote not permitted"},
{"first\nlast", "Control chars not permitted"},
}
for _, tt := range invalidTable {
if _, err := policy.ParseMailboxName(tt.input); err == nil {
t.Errorf("Didn't get an error while parsing %q: %v", tt.input, tt.msg)
}
}
}
func TestValidateDomain(t *testing.T) {
var testTable = []struct {
input string
expect bool
msg string
}{
{"", false, "Empty domain is not valid"},
{"hostname", true, "Just a hostname is valid"},
{"github.com", true, "Two labels should be just fine"},
{"my-domain.com", true, "Hyphen is allowed mid-label"},
{"_domainkey.foo.com", true, "Underscores are allowed"},
{"bar.com.", true, "Must be able to end with a dot"},
{"ABC.6DBS.com", true, "Mixed case is OK"},
{"mail.123.com", true, "Number only label valid"},
{"123.com", true, "Number only label valid"},
{"google..com", false, "Double dot not valid"},
{".foo.com", false, "Cannot start with a dot"},
{"google\r.com", false, "Special chars not allowed"},
{"foo.-bar.com", false, "Label cannot start with hyphen"},
{"foo-.bar.com", false, "Label cannot end with hyphen"},
{strings.Repeat("a", 256), false, "Max domain length is 255"},
{strings.Repeat("a", 63) + ".com", true, "Should allow 63 char domain label"},
{strings.Repeat("a", 64) + ".com", false, "Max domain label length is 63"},
}
for _, tt := range testTable {
if policy.ValidateDomainPart(tt.input) != tt.expect {
t.Errorf("Expected %v for %q: %s", tt.expect, tt.input, tt.msg)
}
}
}
func TestValidateLocal(t *testing.T) {
var testTable = []struct {
input string
expect bool
msg string
}{
{"", false, "Empty local is not valid"},
{"a", true, "Single letter should be fine"},
{strings.Repeat("a", 128), true, "Valid up to 128 characters"},
{strings.Repeat("a", 129), false, "Only valid up to 128 characters"},
{"FirstLast", true, "Mixed case permitted"},
{"user123", true, "Numbers permitted"},
{"a!#$%&'*+-/=?^_`{|}~", true, "Any of !#$%&'*+-/=?^_`{|}~ are permitted"},
{"first.last", true, "Embedded period is permitted"},
{"first..last", false, "Sequence of periods is not allowed"},
{".user", false, "Cannot lead with a period"},
{"user.", false, "Cannot end with a period"},
{"james@mail", false, "Unquoted @ not permitted"},
{"first last", false, "Unquoted space not permitted"},
{"tricky\\. ", false, "Unquoted space not permitted"},
{"no,commas", false, "Unquoted comma not allowed"},
{"t[es]t", false, "Unquoted square brackets not allowed"},
{"james\\", false, "Cannot end with backslash quote"},
{"james\\@mail", true, "Quoted @ permitted"},
{"quoted\\ space", true, "Quoted space permitted"},
{"no\\,commas", true, "Quoted comma is OK"},
{"t\\[es\\]t", true, "Quoted brackets are OK"},
{"user\\name", true, "Should be able to quote a-z"},
{"USER\\NAME", true, "Should be able to quote A-Z"},
{"user\\1", true, "Should be able to quote a digit"},
{"one\\$\\|", true, "Should be able to quote plain specials"},
{"return\\\r", true, "Should be able to quote ASCII control chars"},
{"high\\\x80", false, "Should not accept > 7-bit quoted chars"},
{"quote\\\"", true, "Quoted double quote is permitted"},
{"\"james\"", true, "Quoted a-z is permitted"},
{"\"first last\"", true, "Quoted space is permitted"},
{"\"quoted@sign\"", true, "Quoted @ is allowed"},
{"\"qp\\\"quote\"", true, "Quoted quote within quoted string is OK"},
{"\"unterminated", false, "Quoted string must be terminated"},
{"\"unterminated\\\"", false, "Quoted string must be terminated"},
{"embed\"quote\"string", false, "Embedded quoted string is illegal"},
{"user+mailbox", true, "RFC3696 test case should be valid"},
{"customer/department=shipping", true, "RFC3696 test case should be valid"},
{"$A12345", true, "RFC3696 test case should be valid"},
{"!def!xyz%abc", true, "RFC3696 test case should be valid"},
{"_somename", true, "RFC3696 test case should be valid"},
}
for _, tt := range testTable {
_, _, err := policy.ParseEmailAddress(tt.input + "@domain.com")
if (err != nil) == tt.expect {
if err != nil {
t.Logf("Got error: %s", err)
}
t.Errorf("Expected %v for %q: %s", tt.expect, tt.input, tt.msg)
}
}
}

25
pkg/policy/recipient.go Normal file
View File

@@ -0,0 +1,25 @@
package policy
import "net/mail"
// Recipient represents a potential email recipient, allows policies for it to be queried.
type Recipient struct {
mail.Address
apolicy *Addressing
// LocalPart is the part of the address before @, including +extension.
LocalPart string
// Domain is the part of the address after @.
Domain string
// Mailbox is the canonical mailbox name for this recipient.
Mailbox string
}
// ShouldAccept returns true if Inbucket should accept mail for this recipient.
func (r *Recipient) ShouldAccept() bool {
return true
}
// ShouldStore returns true if Inbucket should store mail for this recipient.
func (r *Recipient) ShouldStore() bool {
return r.apolicy.ShouldStoreDomain(r.Domain)
}

View File

@@ -20,16 +20,11 @@ import (
// MailboxListV1 renders a list of messages in a mailbox // MailboxListV1 renders a list of messages in a mailbox
func MailboxListV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxListV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name) messages, err := ctx.Manager.GetMetadata(name)
if err != nil {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
messages, err := mb.GetMessages()
if err != nil { if err != nil {
// This doesn't indicate empty, likely an IO error // This doesn't indicate empty, likely an IO error
return fmt.Errorf("Failed to get messages for %v: %v", name, err) return fmt.Errorf("Failed to get messages for %v: %v", name, err)
@@ -40,12 +35,12 @@ func MailboxListV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (
for i, msg := range messages { for i, msg := range messages {
jmessages[i] = &model.JSONMessageHeaderV1{ jmessages[i] = &model.JSONMessageHeaderV1{
Mailbox: name, Mailbox: name,
ID: msg.ID(), ID: msg.ID,
From: msg.From(), From: msg.From.String(),
To: msg.To(), To: stringutil.StringAddressList(msg.To),
Subject: msg.Subject(), Subject: msg.Subject,
Date: msg.Date(), Date: msg.Date,
Size: msg.Size(), Size: msg.Size,
} }
} }
return web.RenderJSON(w, jmessages) return web.RenderJSON(w, jmessages)
@@ -55,17 +50,12 @@ func MailboxListV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (
func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name) msg, err := ctx.Manager.GetMessage(name, id)
if err != nil { if err == storage.ErrNotExist {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
msg, err := mb.GetMessage(id)
if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
@@ -73,14 +63,7 @@ func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (
// This doesn't indicate empty, likely an IO error // This doesn't indicate empty, likely an IO error
return fmt.Errorf("GetMessage(%q) failed: %v", id, err) return fmt.Errorf("GetMessage(%q) failed: %v", id, err)
} }
header, err := msg.ReadHeader() mime := msg.Envelope
if err != nil {
return fmt.Errorf("ReadHeader(%q) failed: %v", id, err)
}
mime, err := msg.ReadBody()
if err != nil {
return fmt.Errorf("ReadBody(%q) failed: %v", id, err)
}
attachments := make([]*model.JSONMessageAttachmentV1, len(mime.Attachments)) attachments := make([]*model.JSONMessageAttachmentV1, len(mime.Attachments))
for i, att := range mime.Attachments { for i, att := range mime.Attachments {
@@ -99,13 +82,13 @@ func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (
return web.RenderJSON(w, return web.RenderJSON(w,
&model.JSONMessageV1{ &model.JSONMessageV1{
Mailbox: name, Mailbox: name,
ID: msg.ID(), ID: msg.ID,
From: msg.From(), From: msg.From.String(),
To: msg.To(), To: stringutil.StringAddressList(msg.To),
Subject: msg.Subject(), Subject: msg.Subject,
Date: msg.Date(), Date: msg.Date,
Size: msg.Size(), Size: msg.Size,
Header: header.Header, Header: mime.Root.Header,
Body: &model.JSONMessageBodyV1{ Body: &model.JSONMessageBodyV1{
Text: mime.Text, Text: mime.Text,
HTML: mime.HTML, HTML: mime.HTML,
@@ -117,17 +100,12 @@ func MailboxShowV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (
// MailboxPurgeV1 deletes all messages from a mailbox // MailboxPurgeV1 deletes all messages from a mailbox
func MailboxPurgeV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxPurgeV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name)
if err != nil {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
// Delete all messages // Delete all messages
err = mb.Purge() err = ctx.Manager.PurgeMessages(name)
if err != nil { if err != nil {
return fmt.Errorf("Mailbox(%q) purge failed: %v", name, err) return fmt.Errorf("Mailbox(%q) purge failed: %v", name, err)
} }
@@ -140,61 +118,42 @@ func MailboxPurgeV1(w http.ResponseWriter, req *http.Request, ctx *web.Context)
func MailboxSourceV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxSourceV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name)
if err != nil { r, err := ctx.Manager.SourceReader(name, id)
// This doesn't indicate not found, likely an IO error if err == storage.ErrNotExist {
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
message, err := mb.GetMessage(id)
if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
if err != nil { if err != nil {
// This doesn't indicate missing, likely an IO error // This doesn't indicate missing, likely an IO error
return fmt.Errorf("GetMessage(%q) failed: %v", id, err) return fmt.Errorf("SourceReader(%q) failed: %v", id, err)
} }
raw, err := message.ReadRaw() // Output message source
if err != nil {
return fmt.Errorf("ReadRaw(%q) failed: %v", id, err)
}
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
if _, err := io.WriteString(w, *raw); err != nil { _, err = io.Copy(w, r)
return err return err
}
return nil
} }
// MailboxDeleteV1 removes a particular message from a mailbox // MailboxDeleteV1 removes a particular message from a mailbox
func MailboxDeleteV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxDeleteV1(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name) err = ctx.Manager.RemoveMessage(name, id)
if err != nil { if err == storage.ErrNotExist {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
message, err := mb.GetMessage(id)
if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
if err != nil { if err != nil {
// This doesn't indicate missing, likely an IO error // This doesn't indicate missing, likely an IO error
return fmt.Errorf("GetMessage(%q) failed: %v", id, err) return fmt.Errorf("RemoveMessage(%q) failed: %v", id, err)
}
err = message.Delete()
if err != nil {
return fmt.Errorf("Delete(%q) failed: %v", id, err)
} }
return web.RenderJSON(w, "OK") return web.RenderJSON(w, "OK")

View File

@@ -2,14 +2,16 @@ package rest
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/mail" "net/mail"
"net/textproto"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/jhillyerd/inbucket/pkg/storage" "github.com/jhillyerd/enmime"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/test"
) )
const ( const (
@@ -31,8 +33,8 @@ const (
func TestRestMailboxList(t *testing.T) { func TestRestMailboxList(t *testing.T) {
// Setup // Setup
ds := &datastore.MockDataStore{} mm := test.NewManager()
logbuf := setupWebServer(ds) logbuf := setupWebServer(mm)
// Test invalid mailbox name // Test invalid mailbox name
w, err := testRestGet(baseURL + "/mailbox/foo@bar") w, err := testRestGet(baseURL + "/mailbox/foo@bar")
@@ -45,10 +47,6 @@ func TestRestMailboxList(t *testing.T) {
} }
// Test empty mailbox // Test empty mailbox
emptybox := &datastore.MockMailbox{}
ds.On("MailboxFor", "empty").Return(emptybox, nil)
emptybox.On("GetMessages").Return([]datastore.Message{}, nil)
w, err = testRestGet(baseURL + "/mailbox/empty") w, err = testRestGet(baseURL + "/mailbox/empty")
expectCode = 200 expectCode = 200
if err != nil { if err != nil {
@@ -58,30 +56,8 @@ func TestRestMailboxList(t *testing.T) {
t.Errorf("Expected code %v, got %v", expectCode, w.Code) t.Errorf("Expected code %v, got %v", expectCode, w.Code)
} }
// Test MailboxFor error // Test Mailbox error
ds.On("MailboxFor", "error").Return(&datastore.MockMailbox{}, fmt.Errorf("Internal error")) w, err = testRestGet(baseURL + "/mailbox/messageserr")
w, err = testRestGet(baseURL + "/mailbox/error")
expectCode = 500
if err != nil {
t.Fatal(err)
}
if w.Code != expectCode {
t.Errorf("Expected code %v, got %v", expectCode, w.Code)
}
if t.Failed() {
// Wait for handler to finish logging
time.Sleep(2 * time.Second)
// Dump buffered log data if there was a failure
_, _ = io.Copy(os.Stderr, logbuf)
}
// Test MailboxFor error
error2box := &datastore.MockMailbox{}
ds.On("MailboxFor", "error2").Return(error2box, nil)
error2box.On("GetMessages").Return([]datastore.Message{}, fmt.Errorf("Internal error 2"))
w, err = testRestGet(baseURL + "/mailbox/error2")
expectCode = 500 expectCode = 500
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -91,27 +67,24 @@ func TestRestMailboxList(t *testing.T) {
} }
// Test JSON message headers // Test JSON message headers
data1 := &InputMessageData{ meta1 := message.Metadata{
Mailbox: "good", Mailbox: "good",
ID: "0001", ID: "0001",
From: "from1", From: &mail.Address{Name: "", Address: "from1@host"},
To: []string{"to1"}, To: []*mail.Address{{Name: "", Address: "to1@host"}},
Subject: "subject 1", Subject: "subject 1",
Date: time.Date(2012, 2, 1, 10, 11, 12, 253, time.FixedZone("PST", -800)), Date: time.Date(2012, 2, 1, 10, 11, 12, 253, time.FixedZone("PST", -800)),
} }
data2 := &InputMessageData{ meta2 := message.Metadata{
Mailbox: "good", Mailbox: "good",
ID: "0002", ID: "0002",
From: "from2", From: &mail.Address{Name: "", Address: "from2@host"},
To: []string{"to1"}, To: []*mail.Address{{Name: "", Address: "to1@host"}},
Subject: "subject 2", Subject: "subject 2",
Date: time.Date(2012, 7, 1, 10, 11, 12, 253, time.FixedZone("PDT", -700)), Date: time.Date(2012, 7, 1, 10, 11, 12, 253, time.FixedZone("PDT", -700)),
} }
goodbox := &datastore.MockMailbox{} mm.AddMessage("good", &message.Message{Metadata: meta1})
ds.On("MailboxFor", "good").Return(goodbox, nil) mm.AddMessage("good", &message.Message{Metadata: meta2})
msg1 := data1.MockMessage()
msg2 := data2.MockMessage()
goodbox.On("GetMessages").Return([]datastore.Message{msg1, msg2}, nil)
// Check return code // Check return code
w, err = testRestGet(baseURL + "/mailbox/good") w, err = testRestGet(baseURL + "/mailbox/good")
@@ -130,21 +103,24 @@ func TestRestMailboxList(t *testing.T) {
t.Errorf("Failed to decode JSON: %v", err) t.Errorf("Failed to decode JSON: %v", err)
} }
if len(result) != 2 { if len(result) != 2 {
t.Errorf("Expected 2 results, got %v", len(result)) t.Fatalf("Expected 2 results, got %v", len(result))
}
if errors := data1.CompareToJSONHeaderMap(result[0]); len(errors) > 0 {
t.Logf("%v", result[0])
for _, e := range errors {
t.Error(e)
}
}
if errors := data2.CompareToJSONHeaderMap(result[1]); len(errors) > 0 {
t.Logf("%v", result[1])
for _, e := range errors {
t.Error(e)
}
} }
decodedStringEquals(t, result, "[0]/mailbox", "good")
decodedStringEquals(t, result, "[0]/id", "0001")
decodedStringEquals(t, result, "[0]/from", "<from1@host>")
decodedStringEquals(t, result, "[0]/to/[0]", "<to1@host>")
decodedStringEquals(t, result, "[0]/subject", "subject 1")
decodedStringEquals(t, result, "[0]/date", "2012-02-01T10:11:12.000000253-00:13")
decodedNumberEquals(t, result, "[0]/size", 0)
decodedStringEquals(t, result, "[1]/mailbox", "good")
decodedStringEquals(t, result, "[1]/id", "0002")
decodedStringEquals(t, result, "[1]/from", "<from2@host>")
decodedStringEquals(t, result, "[1]/to/[0]", "<to1@host>")
decodedStringEquals(t, result, "[1]/subject", "subject 2")
decodedStringEquals(t, result, "[1]/date", "2012-07-01T10:11:12.000000253-00:11")
decodedNumberEquals(t, result, "[1]/size", 0)
if t.Failed() { if t.Failed() {
// Wait for handler to finish logging // Wait for handler to finish logging
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
@@ -155,8 +131,8 @@ func TestRestMailboxList(t *testing.T) {
func TestRestMessage(t *testing.T) { func TestRestMessage(t *testing.T) {
// Setup // Setup
ds := &datastore.MockDataStore{} mm := test.NewManager()
logbuf := setupWebServer(ds) logbuf := setupWebServer(mm)
// Test invalid mailbox name // Test invalid mailbox name
w, err := testRestGet(baseURL + "/mailbox/foo@bar/0001") w, err := testRestGet(baseURL + "/mailbox/foo@bar/0001")
@@ -169,10 +145,6 @@ func TestRestMessage(t *testing.T) {
} }
// Test requesting a message that does not exist // Test requesting a message that does not exist
emptybox := &datastore.MockMailbox{}
ds.On("MailboxFor", "empty").Return(emptybox, nil)
emptybox.On("GetMessage", "0001").Return(&datastore.MockMessage{}, datastore.ErrNotExist)
w, err = testRestGet(baseURL + "/mailbox/empty/0001") w, err = testRestGet(baseURL + "/mailbox/empty/0001")
expectCode = 404 expectCode = 404
if err != nil { if err != nil {
@@ -182,9 +154,8 @@ func TestRestMessage(t *testing.T) {
t.Errorf("Expected code %v, got %v", expectCode, w.Code) t.Errorf("Expected code %v, got %v", expectCode, w.Code)
} }
// Test MailboxFor error // Test GetMessage error
ds.On("MailboxFor", "error").Return(&datastore.MockMailbox{}, fmt.Errorf("Internal error")) w, err = testRestGet(baseURL + "/mailbox/messageerr/0001")
w, err = testRestGet(baseURL + "/mailbox/error/0001")
expectCode = 500 expectCode = 500
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -200,38 +171,28 @@ func TestRestMessage(t *testing.T) {
_, _ = io.Copy(os.Stderr, logbuf) _, _ = io.Copy(os.Stderr, logbuf)
} }
// Test GetMessage error
error2box := &datastore.MockMailbox{}
ds.On("MailboxFor", "error2").Return(error2box, nil)
error2box.On("GetMessage", "0001").Return(&datastore.MockMessage{}, fmt.Errorf("Internal error 2"))
w, err = testRestGet(baseURL + "/mailbox/error2/0001")
expectCode = 500
if err != nil {
t.Fatal(err)
}
if w.Code != expectCode {
t.Errorf("Expected code %v, got %v", expectCode, w.Code)
}
// Test JSON message headers // Test JSON message headers
data1 := &InputMessageData{ msg1 := &message.Message{
Mailbox: "good", Metadata: message.Metadata{
ID: "0001", Mailbox: "good",
From: "from1", ID: "0001",
Subject: "subject 1", From: &mail.Address{Name: "", Address: "from1@host"},
Date: time.Date(2012, 2, 1, 10, 11, 12, 253, time.FixedZone("PST", -800)), To: []*mail.Address{{Name: "", Address: "to1@host"}},
Header: mail.Header{ Subject: "subject 1",
"To": []string{"fred@fish.com", "keyword@nsa.gov"}, Date: time.Date(2012, 2, 1, 10, 11, 12, 253, time.FixedZone("PST", -800)),
"From": []string{"noreply@inbucket.org"}, },
Envelope: &enmime.Envelope{
Text: "This is some text",
HTML: "This is some HTML",
Root: &enmime.Part{
Header: textproto.MIMEHeader{
"To": []string{"fred@fish.com", "keyword@nsa.gov"},
"From": []string{"noreply@inbucket.org"},
},
},
}, },
Text: "This is some text",
HTML: "This is some HTML",
} }
goodbox := &datastore.MockMailbox{} mm.AddMessage("good", msg1)
ds.On("MailboxFor", "good").Return(goodbox, nil)
msg1 := data1.MockMessage()
goodbox.On("GetMessage", "0001").Return(msg1, nil)
// Check return code // Check return code
w, err = testRestGet(baseURL + "/mailbox/good/0001") w, err = testRestGet(baseURL + "/mailbox/good/0001")
@@ -250,12 +211,18 @@ func TestRestMessage(t *testing.T) {
t.Errorf("Failed to decode JSON: %v", err) t.Errorf("Failed to decode JSON: %v", err)
} }
if errors := data1.CompareToJSONMessageMap(result); len(errors) > 0 { decodedStringEquals(t, result, "mailbox", "good")
t.Logf("%v", result) decodedStringEquals(t, result, "id", "0001")
for _, e := range errors { decodedStringEquals(t, result, "from", "<from1@host>")
t.Error(e) decodedStringEquals(t, result, "to/[0]", "<to1@host>")
} decodedStringEquals(t, result, "subject", "subject 1")
} decodedStringEquals(t, result, "date", "2012-02-01T10:11:12.000000253-00:13")
decodedNumberEquals(t, result, "size", 0)
decodedStringEquals(t, result, "body/text", "This is some text")
decodedStringEquals(t, result, "body/html", "This is some HTML")
decodedStringEquals(t, result, "header/To/[0]", "fred@fish.com")
decodedStringEquals(t, result, "header/To/[1]", "keyword@nsa.gov")
decodedStringEquals(t, result, "header/From/[0]", "noreply@inbucket.org")
if t.Failed() { if t.Failed() {
// Wait for handler to finish logging // Wait for handler to finish logging

View File

@@ -1,7 +1,6 @@
package model package model
import ( import (
"net/mail"
"time" "time"
) )
@@ -26,7 +25,7 @@ type JSONMessageV1 struct {
Date time.Time `json:"date"` Date time.Time `json:"date"`
Size int64 `json:"size"` Size int64 `json:"size"`
Body *JSONMessageBodyV1 `json:"body"` Body *JSONMessageBodyV1 `json:"body"`
Header mail.Header `json:"header"` Header map[string][]string `json:"header"`
Attachments []*JSONMessageAttachmentV1 `json:"attachments"` Attachments []*JSONMessageAttachmentV1 `json:"attachments"`
} }

View File

@@ -9,7 +9,6 @@ import (
"github.com/jhillyerd/inbucket/pkg/msghub" "github.com/jhillyerd/inbucket/pkg/msghub"
"github.com/jhillyerd/inbucket/pkg/rest/model" "github.com/jhillyerd/inbucket/pkg/rest/model"
"github.com/jhillyerd/inbucket/pkg/server/web" "github.com/jhillyerd/inbucket/pkg/server/web"
"github.com/jhillyerd/inbucket/pkg/stringutil"
) )
const ( const (
@@ -173,7 +172,7 @@ func MonitorAllMessagesV1(
// notifies the client of messages received by a particular mailbox. // notifies the client of messages received by a particular mailbox.
func MonitorMailboxMessagesV1( func MonitorMailboxMessagesV1(
w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }

View File

@@ -2,180 +2,19 @@ package rest
import ( import (
"bytes" "bytes"
"fmt"
"log" "log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/mail" "strconv"
"time" "strings"
"testing"
"github.com/jhillyerd/enmime"
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/msghub" "github.com/jhillyerd/inbucket/pkg/msghub"
"github.com/jhillyerd/inbucket/pkg/server/web" "github.com/jhillyerd/inbucket/pkg/server/web"
"github.com/jhillyerd/inbucket/pkg/storage"
) )
type InputMessageData struct {
Mailbox, ID, From, Subject string
To []string
Date time.Time
Size int
Header mail.Header
HTML, Text string
}
func (d *InputMessageData) MockMessage() *datastore.MockMessage {
msg := &datastore.MockMessage{}
msg.On("ID").Return(d.ID)
msg.On("From").Return(d.From)
msg.On("To").Return(d.To)
msg.On("Subject").Return(d.Subject)
msg.On("Date").Return(d.Date)
msg.On("Size").Return(d.Size)
gomsg := &mail.Message{
Header: d.Header,
}
msg.On("ReadHeader").Return(gomsg, nil)
body := &enmime.Envelope{
Text: d.Text,
HTML: d.HTML,
}
msg.On("ReadBody").Return(body, nil)
return msg
}
// isJSONStringEqual is a utility function to return a nicely formatted message when
// comparing a string to a value received from a JSON map.
func isJSONStringEqual(key, expected string, received interface{}) (message string, ok bool) {
if value, ok := received.(string); ok {
if expected == value {
return "", true
}
return fmt.Sprintf("Expected value of key %v to be %q, got %q", key, expected, value), false
}
return fmt.Sprintf("Expected value of key %v to be a string, got %T", key, received), false
}
// isJSONNumberEqual is a utility function to return a nicely formatted message when
// comparing an float64 to a value received from a JSON map.
func isJSONNumberEqual(key string, expected float64, received interface{}) (message string, ok bool) {
if value, ok := received.(float64); ok {
if expected == value {
return "", true
}
return fmt.Sprintf("Expected %v to be %v, got %v", key, expected, value), false
}
return fmt.Sprintf("Expected %v to be a string, got %T", key, received), false
}
// CompareToJSONHeaderMap compares InputMessageData to a header map decoded from JSON,
// returning a list of things that did not match.
func (d *InputMessageData) CompareToJSONHeaderMap(json interface{}) (errors []string) {
if m, ok := json.(map[string]interface{}); ok {
if msg, ok := isJSONStringEqual(mailboxKey, d.Mailbox, m[mailboxKey]); !ok {
errors = append(errors, msg)
}
if msg, ok := isJSONStringEqual(idKey, d.ID, m[idKey]); !ok {
errors = append(errors, msg)
}
if msg, ok := isJSONStringEqual(fromKey, d.From, m[fromKey]); !ok {
errors = append(errors, msg)
}
for i, inputTo := range d.To {
if msg, ok := isJSONStringEqual(toKey, inputTo, m[toKey].([]interface{})[i]); !ok {
errors = append(errors, msg)
}
}
if msg, ok := isJSONStringEqual(subjectKey, d.Subject, m[subjectKey]); !ok {
errors = append(errors, msg)
}
exDate := d.Date.Format("2006-01-02T15:04:05.999999999-07:00")
if msg, ok := isJSONStringEqual(dateKey, exDate, m[dateKey]); !ok {
errors = append(errors, msg)
}
if msg, ok := isJSONNumberEqual(sizeKey, float64(d.Size), m[sizeKey]); !ok {
errors = append(errors, msg)
}
return errors
}
panic(fmt.Sprintf("Expected map[string]interface{} in json, got %T", json))
}
// CompareToJSONMessageMap compares InputMessageData to a message map decoded from JSON,
// returning a list of things that did not match.
func (d *InputMessageData) CompareToJSONMessageMap(json interface{}) (errors []string) {
// We need to check the same values as header first
errors = d.CompareToJSONHeaderMap(json)
if m, ok := json.(map[string]interface{}); ok {
// Get nested body map
if m[bodyKey] != nil {
if body, ok := m[bodyKey].(map[string]interface{}); ok {
if msg, ok := isJSONStringEqual(textKey, d.Text, body[textKey]); !ok {
errors = append(errors, msg)
}
if msg, ok := isJSONStringEqual(htmlKey, d.HTML, body[htmlKey]); !ok {
errors = append(errors, msg)
}
} else {
panic(fmt.Sprintf("Expected map[string]interface{} in json key %q, got %T",
bodyKey, m[bodyKey]))
}
} else {
errors = append(errors, fmt.Sprintf("Expected body in JSON %q but it was nil", bodyKey))
}
exDate := d.Date.Format("2006-01-02T15:04:05.999999999-07:00")
if msg, ok := isJSONStringEqual(dateKey, exDate, m[dateKey]); !ok {
errors = append(errors, msg)
}
if msg, ok := isJSONNumberEqual(sizeKey, float64(d.Size), m[sizeKey]); !ok {
errors = append(errors, msg)
}
// Get nested header map
if m[headerKey] != nil {
if header, ok := m[headerKey].(map[string]interface{}); ok {
// Loop over input (expected) header names
for name, keyInputHeaders := range d.Header {
// Make sure expected header name exists in received JSON
if keyOutputVals, ok := header[name]; ok {
if keyOutputHeaders, ok := keyOutputVals.([]interface{}); ok {
// Loop over input (expected) header values
for _, inputHeader := range keyInputHeaders {
hasValue := false
// Look for expected value in received headers
for _, outputHeader := range keyOutputHeaders {
if inputHeader == outputHeader {
hasValue = true
break
}
}
if !hasValue {
errors = append(errors, fmt.Sprintf(
"JSON %v[%q] missing value %q", headerKey, name, inputHeader))
}
}
} else {
// keyOutputValues was not a slice of interface{}
panic(fmt.Sprintf("Expected []interface{} in %v[%q], got %T", headerKey,
name, keyOutputVals))
}
} else {
errors = append(errors, fmt.Sprintf("JSON %v missing key %q", headerKey, name))
}
}
}
} else {
errors = append(errors, fmt.Sprintf("Expected header in JSON %q but it was nil", headerKey))
}
} else {
panic(fmt.Sprintf("Expected map[string]interface{} in json, got %T", json))
}
return errors
}
func testRestGet(url string) (*httptest.ResponseRecorder, error) { func testRestGet(url string) (*httptest.ResponseRecorder, error) {
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
@@ -188,7 +27,7 @@ func testRestGet(url string) (*httptest.ResponseRecorder, error) {
return w, nil return w, nil
} }
func setupWebServer(ds datastore.DataStore) *bytes.Buffer { func setupWebServer(mm message.Manager) *bytes.Buffer {
// Capture log output // Capture log output
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
log.SetOutput(buf) log.SetOutput(buf)
@@ -200,8 +39,94 @@ func setupWebServer(ds datastore.DataStore) *bytes.Buffer {
PublicDir: "../themes/bootstrap/public", PublicDir: "../themes/bootstrap/public",
} }
shutdownChan := make(chan bool) shutdownChan := make(chan bool)
web.Initialize(cfg, shutdownChan, ds, &msghub.Hub{}) web.Initialize(cfg, shutdownChan, mm, &msghub.Hub{})
SetupRoutes(web.Router) SetupRoutes(web.Router)
return buf return buf
} }
func decodedNumberEquals(t *testing.T, json interface{}, path string, want float64) {
t.Helper()
els := strings.Split(path, "/")
val, msg := getDecodedPath(json, els...)
if msg != "" {
t.Errorf("JSON result%s", msg)
return
}
if got, ok := val.(float64); ok {
if got == want {
return
}
}
t.Errorf("JSON result/%s == %v (%T), want: %v", path, val, val, want)
}
func decodedStringEquals(t *testing.T, json interface{}, path string, want string) {
t.Helper()
els := strings.Split(path, "/")
val, msg := getDecodedPath(json, els...)
if msg != "" {
t.Errorf("JSON result%s", msg)
return
}
if got, ok := val.(string); ok {
if got == want {
return
}
}
t.Errorf("JSON result/%s == %v (%T), want: %v", path, val, val, want)
}
// getDecodedPath recursively navigates the specified path, returing the requested element. If
// something goes wrong, the returned string will contain an explanation.
//
// Named path elements require the parent element to be a map[string]interface{}, numbers in square
// brackets require the parent element to be a []interface{}.
//
// getDecodedPath(o, "users", "[1]", "name")
//
// is equivalent to the JavaScript:
//
// o.users[1].name
//
func getDecodedPath(o interface{}, path ...string) (interface{}, string) {
if len(path) == 0 {
return o, ""
}
if o == nil {
return nil, " is nil"
}
key := path[0]
present := false
var val interface{}
if key[0] == '[' {
// Expecting slice.
index, err := strconv.Atoi(strings.Trim(key, "[]"))
if err != nil {
return nil, "/" + key + " is not a slice index"
}
oslice, ok := o.([]interface{})
if !ok {
return nil, " is not a slice"
}
if index >= len(oslice) {
return nil, "/" + key + " is out of bounds"
}
val, present = oslice[index], true
} else {
// Expecting map.
omap, ok := o.(map[string]interface{})
if !ok {
return nil, " is not a map"
}
val, present = omap[key]
}
if !present {
return nil, "/" + key + " is missing"
}
result, msg := getDecodedPath(val, path[1:]...)
if msg != "" {
return nil, "/" + key + msg
}
return result, ""
}

View File

@@ -57,18 +57,17 @@ var commands = map[string]bool{
// Session defines an active POP3 session // Session defines an active POP3 session
type Session struct { type Session struct {
server *Server // Reference to the server we belong to server *Server // Reference to the server we belong to
id int // Session ID number id int // Session ID number
conn net.Conn // Our network connection conn net.Conn // Our network connection
remoteHost string // IP address of client remoteHost string // IP address of client
sendError error // Used to bail out of read loop on send error sendError error // Used to bail out of read loop on send error
state State // Current session state state State // Current session state
reader *bufio.Reader // Buffered reader for our net conn reader *bufio.Reader // Buffered reader for our net conn
user string // Mailbox name user string // Mailbox name
mailbox datastore.Mailbox // Mailbox instance messages []storage.Message // Slice of messages in mailbox
messages []datastore.Message // Slice of messages in mailbox retain []bool // Messages to retain upon UPDATE (true=retain)
retain []bool // Messages to retain upon UPDATE (true=retain) msgCount int // Number of undeleted messages
msgCount int // Number of undeleted messages
} }
// NewSession creates a new POP3 session // NewSession creates a new POP3 session
@@ -195,14 +194,6 @@ func (ses *Session) authorizationHandler(cmd string, args []string) {
if ses.user == "" { if ses.user == "" {
ses.ooSeq(cmd) ses.ooSeq(cmd)
} else { } else {
var err error
ses.mailbox, err = ses.server.dataStore.MailboxFor(ses.user)
if err != nil {
ses.logError("Failed to open mailbox for %v", ses.user)
ses.send(fmt.Sprintf("-ERR Failed to open mailbox for %v", ses.user))
ses.enterState(QUIT)
return
}
ses.loadMailbox() ses.loadMailbox()
ses.send(fmt.Sprintf("+OK Found %v messages for %v", ses.msgCount, ses.user)) ses.send(fmt.Sprintf("+OK Found %v messages for %v", ses.msgCount, ses.user))
ses.enterState(TRANSACTION) ses.enterState(TRANSACTION)
@@ -214,14 +205,6 @@ func (ses *Session) authorizationHandler(cmd string, args []string) {
return return
} }
ses.user = args[0] ses.user = args[0]
var err error
ses.mailbox, err = ses.server.dataStore.MailboxFor(ses.user)
if err != nil {
ses.logError("Failed to open mailbox for %v", ses.user)
ses.send(fmt.Sprintf("-ERR Failed to open mailbox for %v", ses.user))
ses.enterState(QUIT)
return
}
ses.loadMailbox() ses.loadMailbox()
ses.send(fmt.Sprintf("+OK Found %v messages for %v", ses.msgCount, ses.user)) ses.send(fmt.Sprintf("+OK Found %v messages for %v", ses.msgCount, ses.user))
ses.enterState(TRANSACTION) ses.enterState(TRANSACTION)
@@ -432,8 +415,8 @@ func (ses *Session) transactionHandler(cmd string, args []string) {
} }
// Send the contents of the message to the client // Send the contents of the message to the client
func (ses *Session) sendMessage(msg datastore.Message) { func (ses *Session) sendMessage(msg storage.Message) {
reader, err := msg.RawReader() reader, err := msg.Source()
if err != nil { if err != nil {
ses.logError("Failed to read message for RETR command") ses.logError("Failed to read message for RETR command")
ses.send("-ERR Failed to RETR that message, internal error") ses.send("-ERR Failed to RETR that message, internal error")
@@ -465,8 +448,8 @@ func (ses *Session) sendMessage(msg datastore.Message) {
} }
// Send the headers plus the top N lines to the client // Send the headers plus the top N lines to the client
func (ses *Session) sendMessageTop(msg datastore.Message, lineCount int) { func (ses *Session) sendMessageTop(msg storage.Message, lineCount int) {
reader, err := msg.RawReader() reader, err := msg.Source()
if err != nil { if err != nil {
ses.logError("Failed to read message for RETR command") ses.logError("Failed to read message for RETR command")
ses.send("-ERR Failed to RETR that message, internal error") ses.send("-ERR Failed to RETR that message, internal error")
@@ -513,12 +496,11 @@ func (ses *Session) sendMessageTop(msg datastore.Message, lineCount int) {
// Load the users mailbox // Load the users mailbox
func (ses *Session) loadMailbox() { func (ses *Session) loadMailbox() {
var err error m, err := ses.server.dataStore.GetMessages(ses.user)
ses.messages, err = ses.mailbox.GetMessages()
if err != nil { if err != nil {
ses.logError("Failed to load messages for %v", ses.user) ses.logError("Failed to load messages for %v: %v", ses.user, err)
} }
ses.messages = m
ses.retainAll() ses.retainAll()
} }
@@ -540,7 +522,7 @@ func (ses *Session) processDeletes() {
for i, msg := range ses.messages { for i, msg := range ses.messages {
if !ses.retain[i] { if !ses.retain[i] {
ses.logTrace("Deleting %v", msg) ses.logTrace("Deleting %v", msg)
if err := msg.Delete(); err != nil { if err := ses.server.dataStore.RemoveMessage(ses.user, msg.ID()); err != nil {
ses.logWarn("Error deleting %v: %v", msg, err) ses.logWarn("Error deleting %v: %v", msg, err)
} }
} }

View File

@@ -17,14 +17,14 @@ type Server struct {
host string host string
domain string domain string
maxIdleSeconds int maxIdleSeconds int
dataStore datastore.DataStore dataStore storage.Store
listener net.Listener listener net.Listener
globalShutdown chan bool globalShutdown chan bool
waitgroup *sync.WaitGroup waitgroup *sync.WaitGroup
} }
// New creates a new Server struct // New creates a new Server struct
func New(cfg config.POP3Config, shutdownChan chan bool, ds datastore.DataStore) *Server { func New(cfg config.POP3Config, shutdownChan chan bool, ds storage.Store) *Server {
return &Server{ return &Server{
host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port), host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port),
domain: cfg.Domain, domain: cfg.Domain,

View File

@@ -3,7 +3,6 @@ package smtp
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"container/list"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -13,9 +12,7 @@ import (
"time" "time"
"github.com/jhillyerd/inbucket/pkg/log" "github.com/jhillyerd/inbucket/pkg/log"
"github.com/jhillyerd/inbucket/pkg/msghub" "github.com/jhillyerd/inbucket/pkg/policy"
"github.com/jhillyerd/inbucket/pkg/storage"
"github.com/jhillyerd/inbucket/pkg/stringutil"
) )
// State tracks the current mode of our SMTP state machine // State tracks the current mode of our SMTP state machine
@@ -70,12 +67,6 @@ var commands = map[string]bool{
"TURN": true, "TURN": true,
} }
// recipientDetails for message delivery
type recipientDetails struct {
address, localPart, domainPart string
mailbox datastore.Mailbox
}
// Session holds the state of an SMTP session // Session holds the state of an SMTP session
type Session struct { type Session struct {
server *Server server *Server
@@ -87,14 +78,22 @@ type Session struct {
state State state State
reader *bufio.Reader reader *bufio.Reader
from string from string
recipients *list.List recipients []*policy.Recipient
} }
// NewSession creates a new Session for the given connection // NewSession creates a new Session for the given connection
func NewSession(server *Server, id int, conn net.Conn) *Session { func NewSession(server *Server, id int, conn net.Conn) *Session {
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
return &Session{server: server, id: id, conn: conn, state: GREET, reader: reader, remoteHost: host} return &Session{
server: server,
id: id,
conn: conn,
state: GREET,
reader: reader,
remoteHost: host,
recipients: make([]*policy.Recipient, 0),
}
} }
func (ss *Session) String() string { func (ss *Session) String() string {
@@ -267,7 +266,7 @@ func (ss *Session) readyHandler(cmd string, arg string) {
return return
} }
from := m[1] from := m[1]
if _, _, err := stringutil.ParseEmailAddress(from); err != nil { if _, _, err := policy.ParseEmailAddress(from); err != nil {
ss.send("501 Bad sender address syntax") ss.send("501 Bad sender address syntax")
ss.logWarn("Bad address as MAIL arg: %q, %s", from, err) ss.logWarn("Bad address as MAIL arg: %q, %s", from, err)
return return
@@ -296,7 +295,6 @@ func (ss *Session) readyHandler(cmd string, arg string) {
} }
} }
ss.from = from ss.from = from
ss.recipients = list.New()
ss.logInfo("Mail from: %v", from) ss.logInfo("Mail from: %v", from)
ss.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from)) ss.send(fmt.Sprintf("250 Roger, accepting mail from <%v>", from))
ss.enterState(MAIL) ss.enterState(MAIL)
@@ -315,20 +313,21 @@ func (ss *Session) mailHandler(cmd string, arg string) {
return return
} }
// This trim is probably too forgiving // This trim is probably too forgiving
recip := strings.Trim(arg[3:], "<> ") addr := strings.Trim(arg[3:], "<> ")
if _, _, err := stringutil.ParseEmailAddress(recip); err != nil { recip, err := ss.server.apolicy.NewRecipient(addr)
if err != nil {
ss.send("501 Bad recipient address syntax") ss.send("501 Bad recipient address syntax")
ss.logWarn("Bad address as RCPT arg: %q, %s", recip, err) ss.logWarn("Bad address as RCPT arg: %q, %s", addr, err)
return return
} }
if ss.recipients.Len() >= ss.server.maxRecips { if len(ss.recipients) >= ss.server.maxRecips {
ss.logWarn("Maximum limit of %v recipients reached", ss.server.maxRecips) ss.logWarn("Maximum limit of %v recipients reached", ss.server.maxRecips)
ss.send(fmt.Sprintf("552 Maximum limit of %v recipients reached", ss.server.maxRecips)) ss.send(fmt.Sprintf("552 Maximum limit of %v recipients reached", ss.server.maxRecips))
return return
} }
ss.recipients.PushBack(recip) ss.recipients = append(ss.recipients, recip)
ss.logInfo("Recipient: %v", recip) ss.logInfo("Recipient: %v", addr)
ss.send(fmt.Sprintf("250 I'll make sure <%v> gets this", recip)) ss.send(fmt.Sprintf("250 I'll make sure <%v> gets this", addr))
return return
case "DATA": case "DATA":
if arg != "" { if arg != "" {
@@ -336,7 +335,7 @@ func (ss *Session) mailHandler(cmd string, arg string) {
ss.logWarn("Got unexpected args on DATA: %q", arg) ss.logWarn("Got unexpected args on DATA: %q", arg)
return return
} }
if ss.recipients.Len() > 0 { if len(ss.recipients) > 0 {
// We have recipients, go to accept data // We have recipients, go to accept data
ss.enterState(DATA) ss.enterState(DATA)
return return
@@ -350,41 +349,10 @@ func (ss *Session) mailHandler(cmd string, arg string) {
// DATA // DATA
func (ss *Session) dataHandler() { func (ss *Session) dataHandler() {
recipients := make([]recipientDetails, 0, ss.recipients.Len())
// Get a Mailbox and a new Message for each recipient
msgSize := 0
if ss.server.storeMessages {
for e := ss.recipients.Front(); e != nil; e = e.Next() {
recip := e.Value.(string)
local, domain, err := stringutil.ParseEmailAddress(recip)
if err != nil {
ss.logError("Failed to parse address for %q", recip)
ss.send(fmt.Sprintf("451 Failed to open mailbox for %v", recip))
ss.reset()
return
}
if strings.ToLower(domain) != ss.server.domainNoStore {
// Not our "no store" domain, so store the message
mb, err := ss.server.dataStore.MailboxFor(local)
if err != nil {
ss.logError("Failed to open mailbox for %q: %s", local, err)
ss.send(fmt.Sprintf("451 Failed to open mailbox for %v", local))
ss.reset()
return
}
recipients = append(recipients, recipientDetails{recip, local, domain, mb})
} else {
log.Tracef("Not storing message for %q", recip)
}
}
}
ss.send("354 Start mail input; end with <CRLF>.<CRLF>") ss.send("354 Start mail input; end with <CRLF>.<CRLF>")
var lineBuf bytes.Buffer msgBuf := &bytes.Buffer{}
msgBuf := make([][]byte, 0, 1024)
for { for {
lineBuf.Reset() lineBuf, err := ss.readByteLine()
err := ss.readByteLine(&lineBuf)
if err != nil { if err != nil {
if netErr, ok := err.(net.Error); ok { if netErr, ok := err.(net.Error); ok {
if netErr.Timeout() { if netErr.Timeout() {
@@ -395,103 +363,44 @@ func (ss *Session) dataHandler() {
ss.enterState(QUIT) ss.enterState(QUIT)
return return
} }
line := lineBuf.Bytes() if bytes.Equal(lineBuf, []byte(".\r\n")) || bytes.Equal(lineBuf, []byte(".\n")) {
// ss.logTrace("DATA: %q", line) // Mail data complete.
if string(line) == ".\r\n" || string(line) == ".\n" { tstamp := time.Now().Format(timeStampFormat)
// Mail data complete for _, recip := range ss.recipients {
if ss.server.storeMessages { if recip.ShouldStore() {
// Create a message for each valid recipient // Generate Received header.
for _, r := range recipients { prefix := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n",
// TODO temporary hack to fix #77 until datastore revamp ss.remoteDomain, ss.remoteHost, ss.server.domain, recip.Address.Address,
mu, err := ss.server.dataStore.LockFor(r.localPart) tstamp)
// Deliver message.
_, err := ss.server.manager.Deliver(
recip, ss.from, ss.recipients, prefix, msgBuf.Bytes())
if err != nil { if err != nil {
ss.logError("Failed to get lock for %q: %s", r.localPart, err) ss.logError("delivery for %v: %v", recip.LocalPart, err)
// Delivery failure ss.send(fmt.Sprintf("451 Failed to store message for %v", recip.LocalPart))
ss.send(fmt.Sprintf("451 Failed to store message for %v", r.localPart))
ss.reset()
return
}
mu.Lock()
ok := ss.deliverMessage(r, msgBuf)
mu.Unlock()
if ok {
expReceivedTotal.Add(1)
} else {
// Delivery failure
ss.send(fmt.Sprintf("451 Failed to store message for %v", r.localPart))
ss.reset() ss.reset()
return return
} }
} }
} else {
expReceivedTotal.Add(1) expReceivedTotal.Add(1)
} }
ss.send("250 Mail accepted for delivery") ss.send("250 Mail accepted for delivery")
ss.logInfo("Message size %v bytes", msgSize) ss.logInfo("Message size %v bytes", msgBuf.Len())
ss.reset() ss.reset()
return return
} }
// SMTP RFC says remove leading periods from input // RFC: remove leading periods from DATA.
if len(line) > 0 && line[0] == '.' { if len(lineBuf) > 0 && lineBuf[0] == '.' {
line = line[1:] lineBuf = lineBuf[1:]
} }
// Second append copies line/lineBuf so we can reuse it msgBuf.Write(lineBuf)
msgBuf = append(msgBuf, append([]byte{}, line...)) if msgBuf.Len() > ss.server.maxMessageBytes {
msgSize += len(line)
if msgSize > ss.server.maxMessageBytes {
// Max message size exceeded
ss.send("552 Maximum message size exceeded") ss.send("552 Maximum message size exceeded")
ss.logWarn("Max message size exceeded while in DATA") ss.logWarn("Max message size exceeded while in DATA")
ss.reset() ss.reset()
// Should really cleanup the crap on filesystem (after issue #23)
return return
} }
} // end for
}
// deliverMessage creates and populates a new Message for the specified recipient
func (ss *Session) deliverMessage(r recipientDetails, msgBuf [][]byte) (ok bool) {
msg, err := r.mailbox.NewMessage()
if err != nil {
ss.logError("Failed to create message for %q: %s", r.localPart, err)
return false
} }
// Generate Received header
stamp := time.Now().Format(timeStampFormat)
recd := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n",
ss.remoteDomain, ss.remoteHost, ss.server.domain, r.address, stamp)
if err := msg.Append([]byte(recd)); err != nil {
ss.logError("Failed to write received header for %q: %s", r.localPart, err)
return false
}
// Append lines from msgBuf
for _, line := range msgBuf {
if err := msg.Append(line); err != nil {
ss.logError("Failed to append to mailbox %v: %v", r.mailbox, err)
// Should really cleanup the crap on filesystem
return false
}
}
if err := msg.Close(); err != nil {
ss.logError("Error while closing message for %v: %v", r.mailbox, err)
return false
}
// Broadcast message information
broadcast := msghub.Message{
Mailbox: r.mailbox.Name(),
ID: msg.ID(),
From: msg.From(),
To: msg.To(),
Subject: msg.Subject(),
Date: msg.Date(),
Size: msg.Size(),
}
ss.server.msgHub.Dispatch(broadcast)
return true
} }
func (ss *Session) enterState(state State) { func (ss *Session) enterState(state State) {
@@ -522,18 +431,12 @@ func (ss *Session) send(msg string) {
ss.logTrace(">> %v >>", msg) ss.logTrace(">> %v >>", msg)
} }
// readByteLine reads a line of input into the provided buffer. Does // readByteLine reads a line of input, returns byte slice.
// not reset the Buffer - please do so prior to calling. func (ss *Session) readByteLine() ([]byte, error) {
func (ss *Session) readByteLine(buf io.Writer) error {
if err := ss.conn.SetReadDeadline(ss.nextDeadline()); err != nil { if err := ss.conn.SetReadDeadline(ss.nextDeadline()); err != nil {
return err return nil, err
} }
line, err := ss.reader.ReadBytes('\n') return ss.reader.ReadBytes('\n')
if err != nil {
return err
}
_, err = buf.Write(line)
return err
} }
// Reads a line of input // Reads a line of input

View File

@@ -2,7 +2,6 @@ package smtp
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
@@ -14,8 +13,10 @@ import (
"time" "time"
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/msghub" "github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/policy"
"github.com/jhillyerd/inbucket/pkg/storage" "github.com/jhillyerd/inbucket/pkg/storage"
"github.com/jhillyerd/inbucket/pkg/test"
) )
type scriptStep struct { type scriptStep struct {
@@ -25,10 +26,8 @@ type scriptStep struct {
// Test commands in GREET state // Test commands in GREET state
func TestGreetState(t *testing.T) { func TestGreetState(t *testing.T) {
// Setup mock objects ds := test.NewStore()
mds := &datastore.MockDataStore{} server, logbuf, teardown := setupSMTPServer(ds)
server, logbuf, teardown := setupSMTPServer(mds)
defer teardown() defer teardown()
// Test out some mangled HELOs // Test out some mangled HELOs
@@ -82,10 +81,8 @@ func TestGreetState(t *testing.T) {
// Test commands in READY state // Test commands in READY state
func TestReadyState(t *testing.T) { func TestReadyState(t *testing.T) {
// Setup mock objects ds := test.NewStore()
mds := &datastore.MockDataStore{} server, logbuf, teardown := setupSMTPServer(ds)
server, logbuf, teardown := setupSMTPServer(mds)
defer teardown() defer teardown()
// Test out some mangled READY commands // Test out some mangled READY commands
@@ -143,21 +140,7 @@ func TestReadyState(t *testing.T) {
// Test commands in MAIL state // Test commands in MAIL state
func TestMailState(t *testing.T) { func TestMailState(t *testing.T) {
// Setup mock objects mds := test.NewStore()
mds := &datastore.MockDataStore{}
mb1 := &datastore.MockMailbox{}
msg1 := &datastore.MockMessage{}
mds.On("MailboxFor", "u1").Return(mb1, nil)
mb1.On("NewMessage").Return(msg1, nil)
mb1.On("Name").Return("u1")
msg1.On("ID").Return("")
msg1.On("From").Return("")
msg1.On("To").Return(make([]string, 0))
msg1.On("Date").Return(time.Time{})
msg1.On("Subject").Return("")
msg1.On("Size").Return(0)
msg1.On("Close").Return(nil)
server, logbuf, teardown := setupSMTPServer(mds) server, logbuf, teardown := setupSMTPServer(mds)
defer teardown() defer teardown()
@@ -189,10 +172,7 @@ func TestMailState(t *testing.T) {
{"RCPT TO: u4@gmail.com", 250}, {"RCPT TO: u4@gmail.com", 250},
{"RSET", 250}, {"RSET", 250},
{"MAIL FROM:<john@gmail.com>", 250}, {"MAIL FROM:<john@gmail.com>", 250},
{"RCPT TO:<user\\@internal@external.com", 250}, {`RCPT TO:<"first/last"@host.com`, 250},
{"RCPT TO:<\"first last\"@host.com", 250},
{"RCPT TO:<user\\>name@host.com>", 250},
{"RCPT TO:<\"user>name\"@host.com>", 250},
} }
if err := playSession(t, server, script); err != nil { if err := playSession(t, server, script); err != nil {
t.Error(err) t.Error(err)
@@ -258,21 +238,7 @@ func TestMailState(t *testing.T) {
// Test commands in DATA state // Test commands in DATA state
func TestDataState(t *testing.T) { func TestDataState(t *testing.T) {
// Setup mock objects mds := test.NewStore()
mds := &datastore.MockDataStore{}
mb1 := &datastore.MockMailbox{}
msg1 := &datastore.MockMessage{}
mds.On("MailboxFor", "u1").Return(mb1, nil)
mb1.On("NewMessage").Return(msg1, nil)
mb1.On("Name").Return("u1")
msg1.On("ID").Return("")
msg1.On("From").Return("")
msg1.On("To").Return(make([]string, 0))
msg1.On("Date").Return(time.Time{})
msg1.On("Subject").Return("")
msg1.On("Size").Return(0)
msg1.On("Close").Return(nil)
server, logbuf, teardown := setupSMTPServer(mds) server, logbuf, teardown := setupSMTPServer(mds)
defer teardown() defer teardown()
@@ -280,7 +246,6 @@ func TestDataState(t *testing.T) {
pipe := setupSMTPSession(server) pipe := setupSMTPSession(server)
c := textproto.NewConn(pipe) c := textproto.NewConn(pipe)
// Get us into DATA state
if code, _, err := c.ReadCodeLine(220); err != nil { if code, _, err := c.ReadCodeLine(220); err != nil {
t.Errorf("Expected a 220 greeting, got %v", code) t.Errorf("Expected a 220 greeting, got %v", code)
} }
@@ -307,6 +272,33 @@ Hi!
t.Errorf("Expected a 250 greeting, got %v", code) t.Errorf("Expected a 250 greeting, got %v", code)
} }
// Test with no useful headers.
pipe = setupSMTPSession(server)
c = textproto.NewConn(pipe)
if code, _, err := c.ReadCodeLine(220); err != nil {
t.Errorf("Expected a 220 greeting, got %v", code)
}
script = []scriptStep{
{"HELO localhost", 250},
{"MAIL FROM:<john@gmail.com>", 250},
{"RCPT TO:<u1@gmail.com>", 250},
{"DATA", 354},
}
if err := playScriptAgainst(t, c, script); err != nil {
t.Error(err)
}
// Send a message
body = `X-Useless-Header: true
Hi! Can you still deliver this?
`
dw = c.DotWriter()
_, _ = io.WriteString(dw, body)
_ = dw.Close()
if code, _, err := c.ReadCodeLine(250); err != nil {
t.Errorf("Expected a 250 greeting, got %v", code)
}
if t.Failed() { if t.Failed() {
// Wait for handler to finish logging // Wait for handler to finish logging
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
@@ -367,7 +359,7 @@ func (m *mockConn) SetDeadline(t time.Time) error { return nil }
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
func setupSMTPServer(ds datastore.DataStore) (s *Server, buf *bytes.Buffer, teardown func()) { func setupSMTPServer(ds storage.Store) (s *Server, buf *bytes.Buffer, teardown func()) {
// Test Server Config // Test Server Config
cfg := config.SMTPConfig{ cfg := config.SMTPConfig{
IP4address: net.IPv4(127, 0, 0, 1), IP4address: net.IPv4(127, 0, 0, 1),
@@ -386,12 +378,12 @@ func setupSMTPServer(ds datastore.DataStore) (s *Server, buf *bytes.Buffer, tear
// Create a server, don't start it // Create a server, don't start it
shutdownChan := make(chan bool) shutdownChan := make(chan bool)
ctx, cancel := context.WithCancel(context.Background())
teardown = func() { teardown = func() {
close(shutdownChan) close(shutdownChan)
cancel()
} }
s = NewServer(cfg, shutdownChan, ds, msghub.New(ctx, 100)) apolicy := &policy.Addressing{Config: cfg}
manager := &message.StoreManager{Store: ds}
s = NewServer(cfg, shutdownChan, manager, apolicy)
return s, buf, teardown return s, buf, teardown
} }

View File

@@ -12,8 +12,8 @@ import (
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/log" "github.com/jhillyerd/inbucket/pkg/log"
"github.com/jhillyerd/inbucket/pkg/msghub" "github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/storage" "github.com/jhillyerd/inbucket/pkg/policy"
) )
func init() { func init() {
@@ -48,10 +48,9 @@ type Server struct {
storeMessages bool storeMessages bool
// Dependencies // Dependencies
dataStore datastore.DataStore // Mailbox/message store apolicy *policy.Addressing // Address policy.
globalShutdown chan bool // Shuts down Inbucket globalShutdown chan bool // Shuts down Inbucket.
msgHub *msghub.Hub // Pub/sub for message info manager message.Manager // Used to deliver messages.
retentionScanner *datastore.RetentionScanner // Deletes expired messages
// State // State
listener net.Listener // Incoming network connections listener net.Listener // Incoming network connections
@@ -83,21 +82,21 @@ var (
func NewServer( func NewServer(
cfg config.SMTPConfig, cfg config.SMTPConfig,
globalShutdown chan bool, globalShutdown chan bool,
ds datastore.DataStore, manager message.Manager,
msgHub *msghub.Hub) *Server { apolicy *policy.Addressing,
) *Server {
return &Server{ return &Server{
host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port), host: fmt.Sprintf("%v:%v", cfg.IP4address, cfg.IP4port),
domain: cfg.Domain, domain: cfg.Domain,
domainNoStore: strings.ToLower(cfg.DomainNoStore), domainNoStore: strings.ToLower(cfg.DomainNoStore),
maxRecips: cfg.MaxRecipients, maxRecips: cfg.MaxRecipients,
maxIdleSeconds: cfg.MaxIdleSeconds, maxIdleSeconds: cfg.MaxIdleSeconds,
maxMessageBytes: cfg.MaxMessageBytes, maxMessageBytes: cfg.MaxMessageBytes,
storeMessages: cfg.StoreMessages, storeMessages: cfg.StoreMessages,
globalShutdown: globalShutdown, globalShutdown: globalShutdown,
dataStore: ds, manager: manager,
msgHub: msgHub, apolicy: apolicy,
retentionScanner: datastore.NewRetentionScanner(ds, globalShutdown), waitgroup: new(sync.WaitGroup),
waitgroup: new(sync.WaitGroup),
} }
} }
@@ -124,9 +123,6 @@ func (s *Server) Start(ctx context.Context) {
log.Infof("Messages sent to domain '%v' will be discarded", s.domainNoStore) log.Infof("Messages sent to domain '%v' will be discarded", s.domainNoStore)
} }
// Start retention scanner
s.retentionScanner.Start()
// Listener go routine // Listener go routine
go s.serve(ctx) go s.serve(ctx)
@@ -195,5 +191,4 @@ func (s *Server) Drain() {
// Wait for sessions to close // Wait for sessions to close
s.waitgroup.Wait() s.waitgroup.Wait()
log.Tracef("SMTP connections have drained") log.Tracef("SMTP connections have drained")
s.retentionScanner.Join()
} }

View File

@@ -7,16 +7,16 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/msghub" "github.com/jhillyerd/inbucket/pkg/msghub"
"github.com/jhillyerd/inbucket/pkg/storage"
) )
// Context is passed into every request handler function // Context is passed into every request handler function
type Context struct { type Context struct {
Vars map[string]string Vars map[string]string
Session *sessions.Session Session *sessions.Session
DataStore datastore.DataStore
MsgHub *msghub.Hub MsgHub *msghub.Hub
Manager message.Manager
WebConfig config.WebConfig WebConfig config.WebConfig
IsJSON bool IsJSON bool
} }
@@ -59,8 +59,8 @@ func NewContext(req *http.Request) (*Context, error) {
ctx := &Context{ ctx := &Context{
Vars: vars, Vars: vars,
Session: sess, Session: sess,
DataStore: DataStore,
MsgHub: msgHub, MsgHub: msgHub,
Manager: manager,
WebConfig: webConfig, WebConfig: webConfig,
IsJSON: headerMatch(req, "Accept", "application/json"), IsJSON: headerMatch(req, "Accept", "application/json"),
} }

View File

@@ -14,19 +14,17 @@ import (
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/log" "github.com/jhillyerd/inbucket/pkg/log"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/msghub" "github.com/jhillyerd/inbucket/pkg/msghub"
"github.com/jhillyerd/inbucket/pkg/storage"
) )
// Handler is a function type that handles an HTTP request in Inbucket // Handler is a function type that handles an HTTP request in Inbucket
type Handler func(http.ResponseWriter, *http.Request, *Context) error type Handler func(http.ResponseWriter, *http.Request, *Context) error
var ( var (
// DataStore is where all the mailboxes and messages live
DataStore datastore.DataStore
// msgHub holds a reference to the message pub/sub system // msgHub holds a reference to the message pub/sub system
msgHub *msghub.Hub msgHub *msghub.Hub
manager message.Manager
// Router is shared between httpd, webui and rest packages. It sends // Router is shared between httpd, webui and rest packages. It sends
// incoming requests to the correct handler function // incoming requests to the correct handler function
@@ -51,15 +49,15 @@ func init() {
func Initialize( func Initialize(
cfg config.WebConfig, cfg config.WebConfig,
shutdownChan chan bool, shutdownChan chan bool,
ds datastore.DataStore, mm message.Manager,
mh *msghub.Hub) { mh *msghub.Hub) {
webConfig = cfg webConfig = cfg
globalShutdown = shutdownChan globalShutdown = shutdownChan
// NewContext() will use this DataStore for the web handlers // NewContext() will use this DataStore for the web handlers
DataStore = ds
msgHub = mh msgHub = mh
manager = mm
// Content Paths // Content Paths
log.Infof("HTTP templates mapped to %q", cfg.TemplateDir) log.Infof("HTTP templates mapped to %q", cfg.TemplateDir)

View File

@@ -1,56 +0,0 @@
// Package datastore contains implementation independent datastore logic
package datastore
import (
"errors"
"io"
"net/mail"
"sync"
"time"
"github.com/jhillyerd/enmime"
)
var (
// ErrNotExist indicates the requested message does not exist
ErrNotExist = errors.New("Message does not exist")
// ErrNotWritable indicates the message is closed; no longer writable
ErrNotWritable = errors.New("Message not writable")
)
// DataStore is an interface to get Mailboxes stored in Inbucket
type DataStore interface {
MailboxFor(emailAddress string) (Mailbox, error)
AllMailboxes() ([]Mailbox, error)
// LockFor is a temporary hack to fix #77 until Datastore revamp
LockFor(emailAddress string) (*sync.RWMutex, error)
}
// Mailbox is an interface to get and manipulate messages in a DataStore
type Mailbox interface {
GetMessages() ([]Message, error)
GetMessage(id string) (Message, error)
Purge() error
NewMessage() (Message, error)
Name() string
String() string
}
// Message is an interface for a single message in a Mailbox
type Message interface {
ID() string
From() string
To() []string
Date() time.Time
Subject() string
RawReader() (reader io.ReadCloser, err error)
ReadHeader() (msg *mail.Message, err error)
ReadBody() (body *enmime.Envelope, err error)
ReadRaw() (raw *string, err error)
Append(data []byte) error
Close() error
Delete() error
String() string
Size() int64
}

View File

@@ -1,270 +1,95 @@
package filestore package file
import ( import (
"bufio"
"fmt"
"io" "io"
"io/ioutil"
"net/mail" "net/mail"
"os" "os"
"path/filepath" "path/filepath"
"time" "time"
"github.com/jhillyerd/enmime"
"github.com/jhillyerd/inbucket/pkg/log" "github.com/jhillyerd/inbucket/pkg/log"
"github.com/jhillyerd/inbucket/pkg/storage"
) )
// FileMessage implements Message and contains a little bit of data about a // Message implements Message and contains a little bit of data about a
// particular email message, and methods to retrieve the rest of it from disk. // particular email message, and methods to retrieve the rest of it from disk.
type FileMessage struct { type Message struct {
mailbox *FileMailbox mailbox *mbox
// Stored in GOB // Stored in GOB
Fid string Fid string
Fdate time.Time Fdate time.Time
Ffrom string Ffrom *mail.Address
Fto []string Fto []*mail.Address
Fsubject string Fsubject string
Fsize int64 Fsize int64
// These are for creating new messages only
writable bool
writerFile *os.File
writer *bufio.Writer
} }
// NewMessage creates a new FileMessage object and sets the Date and Id fields. // newMessage creates a new FileMessage object and sets the Date and ID fields.
// It will also delete messages over messageCap if configured. // It will also delete messages over messageCap if configured.
func (mb *FileMailbox) NewMessage() (datastore.Message, error) { func (mb *mbox) newMessage() (*Message, error) {
// Load index // Load index
if !mb.indexLoaded { if !mb.indexLoaded {
if err := mb.readIndex(); err != nil { if err := mb.readIndex(); err != nil {
return nil, err return nil, err
} }
} }
// Delete old messages over messageCap // Delete old messages over messageCap
if mb.store.messageCap > 0 { if mb.store.messageCap > 0 {
for len(mb.messages) >= mb.store.messageCap { for len(mb.messages) >= mb.store.messageCap {
log.Infof("Mailbox %q over configured message cap", mb.name) log.Infof("Mailbox %q over configured message cap", mb.name)
if err := mb.messages[0].Delete(); err != nil { if err := mb.removeMessage(mb.messages[0].ID()); err != nil {
log.Errorf("Error deleting message: %s", err) log.Errorf("Error deleting message: %s", err)
} }
} }
} }
date := time.Now() date := time.Now()
id := generateID(date) id := generateID(date)
return &FileMessage{mailbox: mb, Fid: id, Fdate: date, writable: true}, nil return &Message{mailbox: mb, Fid: id, Fdate: date}, nil
}
// Mailbox returns the name of the mailbox this message resides in.
func (m *Message) Mailbox() string {
return m.mailbox.name
} }
// ID gets the ID of the Message // ID gets the ID of the Message
func (m *FileMessage) ID() string { func (m *Message) ID() string {
return m.Fid return m.Fid
} }
// Date returns the date/time this Message was received by Inbucket // Date returns the date/time this Message was received by Inbucket
func (m *FileMessage) Date() time.Time { func (m *Message) Date() time.Time {
return m.Fdate return m.Fdate
} }
// From returns the value of the Message From header // From returns the value of the Message From header
func (m *FileMessage) From() string { func (m *Message) From() *mail.Address {
return m.Ffrom return m.Ffrom
} }
// To returns the value of the Message To header // To returns the value of the Message To header
func (m *FileMessage) To() []string { func (m *Message) To() []*mail.Address {
return m.Fto return m.Fto
} }
// Subject returns the value of the Message Subject header // Subject returns the value of the Message Subject header
func (m *FileMessage) Subject() string { func (m *Message) Subject() string {
return m.Fsubject return m.Fsubject
} }
// String returns a string in the form: "Subject()" from From()
func (m *FileMessage) String() string {
return fmt.Sprintf("\"%v\" from %v", m.Fsubject, m.Ffrom)
}
// Size returns the size of the Message on disk in bytes // Size returns the size of the Message on disk in bytes
func (m *FileMessage) Size() int64 { func (m *Message) Size() int64 {
return m.Fsize return m.Fsize
} }
func (m *FileMessage) rawPath() string { func (m *Message) rawPath() string {
return filepath.Join(m.mailbox.path, m.Fid+".raw") return filepath.Join(m.mailbox.path, m.Fid+".raw")
} }
// ReadHeader opens the .raw portion of a Message and returns a standard Go mail.Message object // Source opens the .raw portion of a Message as an io.ReadCloser
func (m *FileMessage) ReadHeader() (msg *mail.Message, err error) { func (m *Message) Source() (reader io.ReadCloser, err error) {
file, err := os.Open(m.rawPath())
if err != nil {
return nil, err
}
defer func() {
if err := file.Close(); err != nil {
log.Errorf("Failed to close %q: %v", m.rawPath(), err)
}
}()
reader := bufio.NewReader(file)
return mail.ReadMessage(reader)
}
// ReadBody opens the .raw portion of a Message and returns a MIMEBody object
func (m *FileMessage) ReadBody() (body *enmime.Envelope, err error) {
file, err := os.Open(m.rawPath())
if err != nil {
return nil, err
}
defer func() {
if err := file.Close(); err != nil {
log.Errorf("Failed to close %q: %v", m.rawPath(), err)
}
}()
reader := bufio.NewReader(file)
mime, err := enmime.ReadEnvelope(reader)
if err != nil {
return nil, err
}
return mime, nil
}
// RawReader opens the .raw portion of a Message as an io.ReadCloser
func (m *FileMessage) RawReader() (reader io.ReadCloser, err error) {
file, err := os.Open(m.rawPath()) file, err := os.Open(m.rawPath())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return file, nil return file, nil
} }
// ReadRaw opens the .raw portion of a Message and returns it as a string
func (m *FileMessage) ReadRaw() (raw *string, err error) {
reader, err := m.RawReader()
if err != nil {
return nil, err
}
defer func() {
if err := reader.Close(); err != nil {
log.Errorf("Failed to close %q: %v", m.rawPath(), err)
}
}()
bodyBytes, err := ioutil.ReadAll(bufio.NewReader(reader))
if err != nil {
return nil, err
}
bodyString := string(bodyBytes)
return &bodyString, nil
}
// Append data to a newly opened Message, this will fail on a pre-existing Message and
// after Close() is called.
func (m *FileMessage) Append(data []byte) error {
// Prevent Appending to a pre-existing Message
if !m.writable {
return datastore.ErrNotWritable
}
// Open file for writing if we haven't yet
if m.writer == nil {
// Ensure mailbox directory exists
if err := m.mailbox.createDir(); err != nil {
return err
}
file, err := os.Create(m.rawPath())
if err != nil {
// Set writable false just in case something calls me a million times
m.writable = false
return err
}
m.writerFile = file
m.writer = bufio.NewWriter(file)
}
_, err := m.writer.Write(data)
m.Fsize += int64(len(data))
return err
}
// Close this Message for writing - no more data may be Appended. Close() will also
// trigger the creation of the .gob file.
func (m *FileMessage) Close() error {
// nil out the writer fields so they can't be used
writer := m.writer
writerFile := m.writerFile
m.writer = nil
m.writerFile = nil
if writer != nil {
if err := writer.Flush(); err != nil {
return err
}
}
if writerFile != nil {
if err := writerFile.Close(); err != nil {
return err
}
}
// Fetch headers
body, err := m.ReadBody()
if err != nil {
return err
}
// Only public fields are stored in gob, hence starting with capital F
// Parse From address
if address, err := mail.ParseAddress(body.GetHeader("From")); err == nil {
m.Ffrom = address.String()
} else {
m.Ffrom = body.GetHeader("From")
}
m.Fsubject = body.GetHeader("Subject")
// Turn the To header into a slice
if addresses, err := body.AddressList("To"); err == nil {
for _, a := range addresses {
m.Fto = append(m.Fto, a.String())
}
} else {
m.Fto = []string{body.GetHeader("To")}
}
// Refresh the index before adding our message
err = m.mailbox.readIndex()
if err != nil {
return err
}
// Made it this far without errors, add it to the index
m.mailbox.messages = append(m.mailbox.messages, m)
return m.mailbox.writeIndex()
}
// Delete this Message from disk by removing it from the index and deleting the
// raw files.
func (m *FileMessage) Delete() error {
messages := m.mailbox.messages
for i, mm := range messages {
if m == mm {
// Slice around message we are deleting
m.mailbox.messages = append(messages[:i], messages[i+1:]...)
break
}
}
if err := m.mailbox.writeIndex(); err != nil {
return err
}
if len(m.mailbox.messages) == 0 {
// This was the last message, thus writeIndex() has removed the entire
// directory; we don't need to delete the raw file.
return nil
}
// There are still messages in the index
log.Tracef("Deleting %v", m.rawPath())
return os.Remove(m.rawPath())
}

View File

@@ -1,18 +1,17 @@
package filestore package file
import ( import (
"bufio" "bufio"
"encoding/gob"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"time" "time"
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/log" "github.com/jhillyerd/inbucket/pkg/log"
"github.com/jhillyerd/inbucket/pkg/policy"
"github.com/jhillyerd/inbucket/pkg/storage" "github.com/jhillyerd/inbucket/pkg/storage"
"github.com/jhillyerd/inbucket/pkg/stringutil" "github.com/jhillyerd/inbucket/pkg/stringutil"
) )
@@ -21,15 +20,6 @@ import (
const indexFileName = "index.gob" const indexFileName = "index.gob"
var ( var (
// indexMx is locked while reading/writing an index file
//
// NOTE: This is a bottleneck because it's a single lock even if we have a
// million index files
indexMx = new(sync.RWMutex)
// dirMx is locked while creating/removing directories
dirMx = new(sync.Mutex)
// countChannel is filled with a sequential numbers (0000..9999), which are // countChannel is filled with a sequential numbers (0000..9999), which are
// used by generateID() to generate unique message IDs. It's global // used by generateID() to generate unique message IDs. It's global
// because we only want one regardless of the number of DataStore objects // because we only want one regardless of the number of DataStore objects
@@ -48,17 +38,17 @@ func countGenerator(c chan int) {
} }
} }
// FileDataStore implements DataStore aand is the root of the mail storage // Store implements DataStore aand is the root of the mail storage
// hiearchy. It provides access to Mailbox objects // hiearchy. It provides access to Mailbox objects
type FileDataStore struct { type Store struct {
hashLock datastore.HashLock hashLock storage.HashLock
path string path string
mailPath string mailPath string
messageCap int messageCap int
} }
// NewFileDataStore creates a new DataStore object using the specified path // New creates a new DataStore object using the specified path
func NewFileDataStore(cfg config.DataStoreConfig) datastore.DataStore { func New(cfg config.DataStoreConfig) storage.Store {
path := cfg.Path path := cfg.Path
if path == "" { if path == "" {
log.Errorf("No value configured for datastore path") log.Errorf("No value configured for datastore path")
@@ -71,287 +61,193 @@ func NewFileDataStore(cfg config.DataStoreConfig) datastore.DataStore {
log.Errorf("Error creating dir %q: %v", mailPath, err) log.Errorf("Error creating dir %q: %v", mailPath, err)
} }
} }
return &FileDataStore{path: path, mailPath: mailPath, messageCap: cfg.MailboxMsgCap} return &Store{path: path, mailPath: mailPath, messageCap: cfg.MailboxMsgCap}
} }
// DefaultFileDataStore creates a new DataStore object. It uses the inbucket.Config object to // AddMessage adds a message to the specified mailbox.
// construct it's path. func (fs *Store) AddMessage(m storage.Message) (id string, err error) {
func DefaultFileDataStore() datastore.DataStore { mb, err := fs.mbox(m.Mailbox())
cfg := config.GetDataStoreConfig() if err != nil {
return NewFileDataStore(cfg) return "", err
}
mb.Lock()
defer mb.Unlock()
r, err := m.Source()
if err != nil {
return "", err
}
// Create a new message.
fm, err := mb.newMessage()
if err != nil {
return "", err
}
// Ensure mailbox directory exists.
if err := mb.createDir(); err != nil {
return "", err
}
// Write the message content
file, err := os.Create(fm.rawPath())
if err != nil {
return "", err
}
w := bufio.NewWriter(file)
size, err := io.Copy(w, r)
if err != nil {
// Try to remove the file
_ = file.Close()
_ = os.Remove(fm.rawPath())
return "", err
}
_ = r.Close()
if err := w.Flush(); err != nil {
// Try to remove the file
_ = file.Close()
_ = os.Remove(fm.rawPath())
return "", err
}
if err := file.Close(); err != nil {
// Try to remove the file
_ = os.Remove(fm.rawPath())
return "", err
}
// Update the index.
fm.Fdate = m.Date()
fm.Ffrom = m.From()
fm.Fto = m.To()
fm.Fsize = size
fm.Fsubject = m.Subject()
mb.messages = append(mb.messages, fm)
if err := mb.writeIndex(); err != nil {
// Try to remove the file
_ = os.Remove(fm.rawPath())
return "", err
}
return fm.Fid, nil
} }
// MailboxFor retrieves the Mailbox object for a specified email address, if the mailbox // GetMessage returns the messages in the named mailbox, or an error.
// does not exist, it will attempt to create it. func (fs *Store) GetMessage(mailbox, id string) (storage.Message, error) {
func (ds *FileDataStore) MailboxFor(emailAddress string) (datastore.Mailbox, error) { mb, err := fs.mbox(mailbox)
name, err := stringutil.ParseMailboxName(emailAddress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dir := stringutil.HashMailboxName(name) mb.RLock()
s1 := dir[0:3] defer mb.RUnlock()
s2 := dir[0:6] return mb.getMessage(id)
path := filepath.Join(ds.mailPath, s1, s2, dir)
indexPath := filepath.Join(path, indexFileName)
return &FileMailbox{store: ds, name: name, dirName: dir, path: path,
indexPath: indexPath}, nil
} }
// AllMailboxes returns a slice with all Mailboxes // GetMessages returns the messages in the named mailbox, or an error.
func (ds *FileDataStore) AllMailboxes() ([]datastore.Mailbox, error) { func (fs *Store) GetMessages(mailbox string) ([]storage.Message, error) {
mailboxes := make([]datastore.Mailbox, 0, 100) mb, err := fs.mbox(mailbox)
infos1, err := ioutil.ReadDir(ds.mailPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mb.RLock()
defer mb.RUnlock()
return mb.getMessages()
}
// RemoveMessage deletes a message by ID from the specified mailbox.
func (fs *Store) RemoveMessage(mailbox, id string) error {
mb, err := fs.mbox(mailbox)
if err != nil {
return err
}
mb.Lock()
defer mb.Unlock()
return mb.removeMessage(id)
}
// PurgeMessages deletes all messages in the named mailbox, or returns an error.
func (fs *Store) PurgeMessages(mailbox string) error {
mb, err := fs.mbox(mailbox)
if err != nil {
return err
}
mb.Lock()
defer mb.Unlock()
return mb.purge()
}
// VisitMailboxes accepts a function that will be called with the messages in each mailbox while it
// continues to return true.
func (fs *Store) VisitMailboxes(f func([]storage.Message) (cont bool)) error {
infos1, err := ioutil.ReadDir(fs.mailPath)
if err != nil {
return err
}
// Loop over level 1 directories // Loop over level 1 directories
for _, inf1 := range infos1 { for _, inf1 := range infos1 {
if inf1.IsDir() { if inf1.IsDir() {
l1 := inf1.Name() l1 := inf1.Name()
infos2, err := ioutil.ReadDir(filepath.Join(ds.mailPath, l1)) infos2, err := ioutil.ReadDir(filepath.Join(fs.mailPath, l1))
if err != nil { if err != nil {
return nil, err return err
} }
// Loop over level 2 directories // Loop over level 2 directories
for _, inf2 := range infos2 { for _, inf2 := range infos2 {
if inf2.IsDir() { if inf2.IsDir() {
l2 := inf2.Name() l2 := inf2.Name()
infos3, err := ioutil.ReadDir(filepath.Join(ds.mailPath, l1, l2)) infos3, err := ioutil.ReadDir(filepath.Join(fs.mailPath, l1, l2))
if err != nil { if err != nil {
return nil, err return err
} }
// Loop over mailboxes // Loop over mailboxes
for _, inf3 := range infos3 { for _, inf3 := range infos3 {
if inf3.IsDir() { if inf3.IsDir() {
mbdir := inf3.Name() mb := fs.mboxFromHash(inf3.Name())
mbpath := filepath.Join(ds.mailPath, l1, l2, mbdir) mb.RLock()
idx := filepath.Join(mbpath, indexFileName) msgs, err := mb.getMessages()
mb := &FileMailbox{store: ds, dirName: mbdir, path: mbpath, mb.RUnlock()
indexPath: idx} if err != nil {
mailboxes = append(mailboxes, mb) return err
}
if !f(msgs) {
return nil
}
} }
} }
} }
} }
} }
} }
return nil
return mailboxes, nil
} }
// LockFor returns the RWMutex for this mailbox, or an error. // mbox returns the named mailbox.
func (ds *FileDataStore) LockFor(emailAddress string) (*sync.RWMutex, error) { func (fs *Store) mbox(mailbox string) (*mbox, error) {
name, err := stringutil.ParseMailboxName(emailAddress) name, err := policy.ParseMailboxName(mailbox)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hash := stringutil.HashMailboxName(name) hash := stringutil.HashMailboxName(name)
return ds.hashLock.Get(hash), nil s1 := hash[0:3]
s2 := hash[0:6]
path := filepath.Join(fs.mailPath, s1, s2, hash)
indexPath := filepath.Join(path, indexFileName)
return &mbox{
RWMutex: fs.hashLock.Get(hash),
store: fs,
name: name,
dirName: hash,
path: path,
indexPath: indexPath,
}, nil
} }
// FileMailbox implements Mailbox, manages the mail for a specific user and // mboxFromPath constructs a mailbox based on name hash.
// correlates to a particular directory on disk. func (fs *Store) mboxFromHash(hash string) *mbox {
type FileMailbox struct { s1 := hash[0:3]
store *FileDataStore s2 := hash[0:6]
name string path := filepath.Join(fs.mailPath, s1, s2, hash)
dirName string indexPath := filepath.Join(path, indexFileName)
path string return &mbox{
indexLoaded bool RWMutex: fs.hashLock.Get(hash),
indexPath string store: fs,
messages []*FileMessage dirName: hash,
} path: path,
indexPath: indexPath,
// Name of the mailbox
func (mb *FileMailbox) Name() string {
return mb.name
}
// String renders the name and directory path of the mailbox
func (mb *FileMailbox) String() string {
return mb.name + "[" + mb.dirName + "]"
}
// GetMessages scans the mailbox directory for .gob files and decodes them into
// a slice of Message objects.
func (mb *FileMailbox) GetMessages() ([]datastore.Message, error) {
if !mb.indexLoaded {
if err := mb.readIndex(); err != nil {
return nil, err
}
} }
messages := make([]datastore.Message, len(mb.messages))
for i, m := range mb.messages {
messages[i] = m
}
return messages, nil
}
// GetMessage decodes a single message by Id and returns a Message object
func (mb *FileMailbox) GetMessage(id string) (datastore.Message, error) {
if !mb.indexLoaded {
if err := mb.readIndex(); err != nil {
return nil, err
}
}
if id == "latest" && len(mb.messages) != 0 {
return mb.messages[len(mb.messages)-1], nil
}
for _, m := range mb.messages {
if m.Fid == id {
return m, nil
}
}
return nil, datastore.ErrNotExist
}
// Purge deletes all messages in this mailbox
func (mb *FileMailbox) Purge() error {
mb.messages = mb.messages[:0]
return mb.writeIndex()
}
// readIndex loads the mailbox index data from disk
func (mb *FileMailbox) readIndex() error {
// Clear message slice, open index
mb.messages = mb.messages[:0]
// Lock for reading
indexMx.RLock()
defer indexMx.RUnlock()
// Check if index exists
if _, err := os.Stat(mb.indexPath); err != nil {
// Does not exist, but that's not an error in our world
log.Tracef("Index %v does not exist (yet)", mb.indexPath)
mb.indexLoaded = true
return nil
}
file, err := os.Open(mb.indexPath)
if err != nil {
return err
}
defer func() {
if err := file.Close(); err != nil {
log.Errorf("Failed to close %q: %v", mb.indexPath, err)
}
}()
// Decode gob data
dec := gob.NewDecoder(bufio.NewReader(file))
for {
msg := new(FileMessage)
if err = dec.Decode(msg); err != nil {
if err == io.EOF {
// It's OK to get an EOF here
break
}
return fmt.Errorf("Corrupt mailbox %q: %v", mb.indexPath, err)
}
msg.mailbox = mb
mb.messages = append(mb.messages, msg)
}
mb.indexLoaded = true
return nil
}
// writeIndex overwrites the index on disk with the current mailbox data
func (mb *FileMailbox) writeIndex() error {
// Lock for writing
indexMx.Lock()
defer indexMx.Unlock()
if len(mb.messages) > 0 {
// Ensure mailbox directory exists
if err := mb.createDir(); err != nil {
return err
}
// Open index for writing
file, err := os.Create(mb.indexPath)
if err != nil {
return err
}
writer := bufio.NewWriter(file)
// Write each message and then flush
enc := gob.NewEncoder(writer)
for _, m := range mb.messages {
err = enc.Encode(m)
if err != nil {
_ = file.Close()
return err
}
}
if err := writer.Flush(); err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
log.Errorf("Failed to close %q: %v", mb.indexPath, err)
return err
}
} else {
// No messages, delete index+maildir
log.Tracef("Removing mailbox %v", mb.path)
return mb.removeDir()
}
return nil
}
// createDir checks for the presence of the path for this mailbox, creates it if needed
func (mb *FileMailbox) createDir() error {
dirMx.Lock()
defer dirMx.Unlock()
if _, err := os.Stat(mb.path); err != nil {
if err := os.MkdirAll(mb.path, 0770); err != nil {
log.Errorf("Failed to create directory %v, %v", mb.path, err)
return err
}
}
return nil
}
// removeDir removes the mailbox, plus empty higher level directories
func (mb *FileMailbox) removeDir() error {
dirMx.Lock()
defer dirMx.Unlock()
// remove mailbox dir, including index file
if err := os.RemoveAll(mb.path); err != nil {
return err
}
// remove parents if empty
dir := filepath.Dir(mb.path)
if removeDirIfEmpty(dir) {
removeDirIfEmpty(filepath.Dir(dir))
}
return nil
}
// removeDirIfEmpty will remove the specified directory if it contains no files or directories.
// Caller should hold dirMx. Returns true if dir was removed.
func removeDirIfEmpty(path string) (removed bool) {
f, err := os.Open(path)
if err != nil {
return false
}
files, err := f.Readdirnames(0)
_ = f.Close()
if err != nil {
return false
}
if len(files) > 0 {
// Dir not empty
return false
}
log.Tracef("Removing dir %v", path)
err = os.Remove(path)
if err != nil {
log.Errorf("Failed to remove %q: %v", path, err)
return false
}
return true
} }
// generatePrefix converts a Time object into the ISO style format we use // generatePrefix converts a Time object into the ISO style format we use

View File

@@ -1,4 +1,4 @@
package filestore package file
import ( import (
"bytes" "bytes"
@@ -6,15 +6,31 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"net/mail"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time" "time"
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/storage"
"github.com/jhillyerd/inbucket/pkg/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
// TestSuite runs storage package test suite on file store.
func TestSuite(t *testing.T) {
test.StoreSuite(t, func() (storage.Store, func(), error) {
ds, _ := setupDataStore(config.DataStoreConfig{})
destroy := func() {
teardownDataStore(ds)
}
return ds, destroy, nil
})
}
// Test directory structure created by filestore // Test directory structure created by filestore
func TestFSDirStructure(t *testing.T) { func TestFSDirStructure(t *testing.T) {
ds, logbuf := setupDataStore(config.DataStoreConfig{}) ds, logbuf := setupDataStore(config.DataStoreConfig{})
@@ -62,11 +78,7 @@ func TestFSDirStructure(t *testing.T) {
assert.True(t, isFile(expect), "Expected %q to be a file", expect) assert.True(t, isFile(expect), "Expected %q to be a file", expect)
// Delete message // Delete message
mb, err := ds.MailboxFor(mbName) err := ds.RemoveMessage(mbName, id1)
assert.Nil(t, err)
msg, err := mb.GetMessage(id1)
assert.Nil(t, err)
err = msg.Delete()
assert.Nil(t, err) assert.Nil(t, err)
// Message should be removed // Message should be removed
@@ -76,9 +88,7 @@ func TestFSDirStructure(t *testing.T) {
assert.True(t, isFile(expect), "Expected %q to be a file", expect) assert.True(t, isFile(expect), "Expected %q to be a file", expect)
// Delete message // Delete message
msg, err = mb.GetMessage(id2) err = ds.RemoveMessage(mbName, id2)
assert.Nil(t, err)
err = msg.Delete()
assert.Nil(t, err) assert.Nil(t, err)
// Message should be removed // Message should be removed
@@ -99,243 +109,6 @@ func TestFSDirStructure(t *testing.T) {
} }
} }
// Test FileDataStore.AllMailboxes()
func TestFSAllMailboxes(t *testing.T) {
ds, logbuf := setupDataStore(config.DataStoreConfig{})
defer teardownDataStore(ds)
for _, name := range []string{"abby", "bill", "christa", "donald", "evelyn"} {
// Create day old message
date := time.Now().Add(-24 * time.Hour)
deliverMessage(ds, name, "Old Message", date)
// Create current message
date = time.Now()
deliverMessage(ds, name, "New Message", date)
}
mboxes, err := ds.AllMailboxes()
assert.Nil(t, err)
assert.Equal(t, len(mboxes), 5)
if t.Failed() {
// Wait for handler to finish logging
time.Sleep(2 * time.Second)
// Dump buffered log data if there was a failure
_, _ = io.Copy(os.Stderr, logbuf)
}
}
// Test delivering several messages to the same mailbox, meanwhile querying its
// contents with a new mailbox object each time
func TestFSDeliverMany(t *testing.T) {
ds, logbuf := setupDataStore(config.DataStoreConfig{})
defer teardownDataStore(ds)
mbName := "fred"
subjects := []string{"alpha", "bravo", "charlie", "delta", "echo"}
for i, subj := range subjects {
// Check number of messages
mb, err := ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err := mb.GetMessages()
if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
}
assert.Equal(t, i, len(msgs), "Expected %v message(s), but got %v", i, len(msgs))
// Add a message
deliverMessage(ds, mbName, subj, time.Now())
}
mb, err := ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err := mb.GetMessages()
if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
}
assert.Equal(t, len(subjects), len(msgs), "Expected %v message(s), but got %v",
len(subjects), len(msgs))
// Confirm delivery order
for i, expect := range subjects {
subj := msgs[i].Subject()
assert.Equal(t, expect, subj, "Expected subject %q, got %q", expect, subj)
}
if t.Failed() {
// Wait for handler to finish logging
time.Sleep(2 * time.Second)
// Dump buffered log data if there was a failure
_, _ = io.Copy(os.Stderr, logbuf)
}
}
// Test deleting messages
func TestFSDelete(t *testing.T) {
ds, logbuf := setupDataStore(config.DataStoreConfig{})
defer teardownDataStore(ds)
mbName := "fred"
subjects := []string{"alpha", "bravo", "charlie", "delta", "echo"}
for _, subj := range subjects {
// Add a message
deliverMessage(ds, mbName, subj, time.Now())
}
mb, err := ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err := mb.GetMessages()
if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
}
assert.Equal(t, len(subjects), len(msgs), "Expected %v message(s), but got %v",
len(subjects), len(msgs))
// Delete a couple messages
_ = msgs[1].Delete()
_ = msgs[3].Delete()
// Confirm deletion
mb, err = ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err = mb.GetMessages()
if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
}
subjects = []string{"alpha", "charlie", "echo"}
assert.Equal(t, len(subjects), len(msgs), "Expected %v message(s), but got %v",
len(subjects), len(msgs))
for i, expect := range subjects {
subj := msgs[i].Subject()
assert.Equal(t, expect, subj, "Expected subject %q, got %q", expect, subj)
}
// Try appending one more
deliverMessage(ds, mbName, "foxtrot", time.Now())
mb, err = ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err = mb.GetMessages()
if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
}
subjects = []string{"alpha", "charlie", "echo", "foxtrot"}
assert.Equal(t, len(subjects), len(msgs), "Expected %v message(s), but got %v",
len(subjects), len(msgs))
for i, expect := range subjects {
subj := msgs[i].Subject()
assert.Equal(t, expect, subj, "Expected subject %q, got %q", expect, subj)
}
if t.Failed() {
// Wait for handler to finish logging
time.Sleep(2 * time.Second)
// Dump buffered log data if there was a failure
_, _ = io.Copy(os.Stderr, logbuf)
}
}
// Test purging a mailbox
func TestFSPurge(t *testing.T) {
ds, logbuf := setupDataStore(config.DataStoreConfig{})
defer teardownDataStore(ds)
mbName := "fred"
subjects := []string{"alpha", "bravo", "charlie", "delta", "echo"}
for _, subj := range subjects {
// Add a message
deliverMessage(ds, mbName, subj, time.Now())
}
mb, err := ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err := mb.GetMessages()
if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
}
assert.Equal(t, len(subjects), len(msgs), "Expected %v message(s), but got %v",
len(subjects), len(msgs))
// Purge mailbox
err = mb.Purge()
assert.Nil(t, err)
// Confirm deletion
mb, err = ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err = mb.GetMessages()
if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
}
assert.Equal(t, len(msgs), 0, "Expected mailbox to have zero messages, got %v", len(msgs))
if t.Failed() {
// Wait for handler to finish logging
time.Sleep(2 * time.Second)
// Dump buffered log data if there was a failure
_, _ = io.Copy(os.Stderr, logbuf)
}
}
// Test message size calculation
func TestFSSize(t *testing.T) {
ds, logbuf := setupDataStore(config.DataStoreConfig{})
defer teardownDataStore(ds)
mbName := "fred"
subjects := []string{"a", "br", "much longer than the others"}
sentIds := make([]string, len(subjects))
sentSizes := make([]int64, len(subjects))
for i, subj := range subjects {
// Add a message
id, size := deliverMessage(ds, mbName, subj, time.Now())
sentIds[i] = id
sentSizes[i] = size
}
mb, err := ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
for i, id := range sentIds {
msg, err := mb.GetMessage(id)
assert.Nil(t, err)
expect := sentSizes[i]
size := msg.Size()
assert.Equal(t, expect, size, "Expected size of %v, got %v", expect, size)
}
if t.Failed() {
// Wait for handler to finish logging
time.Sleep(2 * time.Second)
// Dump buffered log data if there was a failure
_, _ = io.Copy(os.Stderr, logbuf)
}
}
// Test missing files // Test missing files
func TestFSMissing(t *testing.T) { func TestFSMissing(t *testing.T) {
ds, logbuf := setupDataStore(config.DataStoreConfig{}) ds, logbuf := setupDataStore(config.DataStoreConfig{})
@@ -351,23 +124,16 @@ func TestFSMissing(t *testing.T) {
sentIds[i] = id sentIds[i] = id
} }
mb, err := ds.MailboxFor(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
// Delete a message file without removing it from index // Delete a message file without removing it from index
msg, err := mb.GetMessage(sentIds[1]) msg, err := ds.GetMessage(mbName, sentIds[1])
assert.Nil(t, err) assert.Nil(t, err)
fmsg := msg.(*FileMessage) fmsg := msg.(*Message)
_ = os.Remove(fmsg.rawPath()) _ = os.Remove(fmsg.rawPath())
msg, err = mb.GetMessage(sentIds[1]) msg, err = ds.GetMessage(mbName, sentIds[1])
assert.Nil(t, err) assert.Nil(t, err)
// Try to read parts of message // Try to read parts of message
_, err = msg.ReadHeader() _, err = msg.Source()
assert.Error(t, err)
_, err = msg.ReadBody()
assert.Error(t, err) assert.Error(t, err)
if t.Failed() { if t.Failed() {
@@ -392,11 +158,7 @@ func TestFSMessageCap(t *testing.T) {
t.Logf("Delivered %q", subj) t.Logf("Delivered %q", subj)
// Check number of messages // Check number of messages
mb, err := ds.MailboxFor(mbName) msgs, err := ds.GetMessages(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err := mb.GetMessages()
if err != nil { if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err) t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
} }
@@ -437,11 +199,7 @@ func TestFSNoMessageCap(t *testing.T) {
t.Logf("Delivered %q", subj) t.Logf("Delivered %q", subj)
// Check number of messages // Check number of messages
mb, err := ds.MailboxFor(mbName) msgs, err := ds.GetMessages(mbName)
if err != nil {
t.Fatalf("Failed to MailboxFor(%q): %v", mbName, err)
}
msgs, err := mb.GetMessages()
if err != nil { if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mbName, err) t.Fatalf("Failed to GetMessages for %q: %v", mbName, err)
} }
@@ -467,9 +225,7 @@ func TestGetLatestMessage(t *testing.T) {
mbName := "james" mbName := "james"
// Test empty mailbox // Test empty mailbox
mb, err := ds.MailboxFor(mbName) msg, err := ds.GetMessage(mbName, "latest")
assert.Nil(t, err)
msg, err := mb.GetMessage("latest")
assert.Nil(t, msg) assert.Nil(t, msg)
assert.Error(t, err) assert.Error(t, err)
@@ -480,23 +236,19 @@ func TestGetLatestMessage(t *testing.T) {
id2, _ := deliverMessage(ds, mbName, "test 2", time.Now()) id2, _ := deliverMessage(ds, mbName, "test 2", time.Now())
// Test get the latest message // Test get the latest message
mb, err = ds.MailboxFor(mbName) msg, err = ds.GetMessage(mbName, "latest")
assert.Nil(t, err)
msg, err = mb.GetMessage("latest")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, msg.ID() == id2, "Expected %q to be equal to %q", msg.ID(), id2) assert.True(t, msg.ID() == id2, "Expected %q to be equal to %q", msg.ID(), id2)
// Deliver test message 3 // Deliver test message 3
id3, _ := deliverMessage(ds, mbName, "test 3", time.Now()) id3, _ := deliverMessage(ds, mbName, "test 3", time.Now())
mb, err = ds.MailboxFor(mbName) msg, err = ds.GetMessage(mbName, "latest")
assert.Nil(t, err)
msg, err = mb.GetMessage("latest")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, msg.ID() == id3, "Expected %q to be equal to %q", msg.ID(), id3) assert.True(t, msg.ID() == id3, "Expected %q to be equal to %q", msg.ID(), id3)
// Test wrong id // Test wrong id
_, err = mb.GetMessage("wrongid") _, err = ds.GetMessage(mbName, "wrongid")
assert.Error(t, err) assert.Error(t, err)
if t.Failed() { if t.Failed() {
@@ -508,7 +260,7 @@ func TestGetLatestMessage(t *testing.T) {
} }
// setupDataStore creates a new FileDataStore in a temporary directory // setupDataStore creates a new FileDataStore in a temporary directory
func setupDataStore(cfg config.DataStoreConfig) (*FileDataStore, *bytes.Buffer) { func setupDataStore(cfg config.DataStoreConfig) (*Store, *bytes.Buffer) {
path, err := ioutil.TempDir("", "inbucket") path, err := ioutil.TempDir("", "inbucket")
if err != nil { if err != nil {
panic(err) panic(err)
@@ -519,45 +271,34 @@ func setupDataStore(cfg config.DataStoreConfig) (*FileDataStore, *bytes.Buffer)
log.SetOutput(buf) log.SetOutput(buf)
cfg.Path = path cfg.Path = path
return NewFileDataStore(cfg).(*FileDataStore), buf return New(cfg).(*Store), buf
} }
// deliverMessage creates and delivers a message to the specific mailbox, returning // deliverMessage creates and delivers a message to the specific mailbox, returning
// the size of the generated message. // the size of the generated message.
func deliverMessage(ds *FileDataStore, mbName string, subject string, func deliverMessage(ds *Store, mbName string, subject string, date time.Time) (string, int64) {
date time.Time) (id string, size int64) { // Build message for delivery
// Build fake SMTP message for delivery meta := message.Metadata{
testMsg := make([]byte, 0, 300) Mailbox: mbName,
testMsg = append(testMsg, []byte("To: somebody@host\r\n")...) To: []*mail.Address{{Name: "", Address: "somebody@host"}},
testMsg = append(testMsg, []byte("From: somebodyelse@host\r\n")...) From: &mail.Address{Name: "", Address: "somebodyelse@host"},
testMsg = append(testMsg, []byte(fmt.Sprintf("Subject: %s\r\n", subject))...) Subject: subject,
testMsg = append(testMsg, []byte("\r\n")...) Date: date,
testMsg = append(testMsg, []byte("Test Body\r\n")...) }
testMsg := fmt.Sprintf("To: %s\r\nFrom: %s\r\nSubject: %s\r\n\r\nTest Body\r\n",
mb, err := ds.MailboxFor(mbName) meta.To[0].Address, meta.From.Address, subject)
delivery := &message.Delivery{
Meta: meta,
Reader: ioutil.NopCloser(strings.NewReader(testMsg)),
}
id, err := ds.AddMessage(delivery)
if err != nil { if err != nil {
panic(err) panic(err)
} }
// Create message object
id = generateID(date)
msg, err := mb.NewMessage()
if err != nil {
panic(err)
}
fmsg := msg.(*FileMessage)
fmsg.Fdate = date
fmsg.Fid = id
if err = msg.Append(testMsg); err != nil {
panic(err)
}
if err = msg.Close(); err != nil {
panic(err)
}
return id, int64(len(testMsg)) return id, int64(len(testMsg))
} }
func teardownDataStore(ds *FileDataStore) { func teardownDataStore(ds *Store) {
if err := os.RemoveAll(ds.path); err != nil { if err := os.RemoveAll(ds.path); err != nil {
panic(err) panic(err)
} }

233
pkg/storage/file/mbox.go Normal file
View File

@@ -0,0 +1,233 @@
package file
import (
"bufio"
"encoding/gob"
"fmt"
"io"
"os"
"path/filepath"
"sync"
"github.com/jhillyerd/inbucket/pkg/log"
"github.com/jhillyerd/inbucket/pkg/storage"
)
// mbox manages the mail for a specific user and correlates to a particular directory on disk.
// mbox methods are not thread safe, mbox.RWMutex must be held prior to calling.
type mbox struct {
*sync.RWMutex
store *Store
name string
dirName string
path string
indexLoaded bool
indexPath string
messages []*Message
}
// getMessages scans the mailbox directory for .gob files and decodes them into
// a slice of Message objects.
func (mb *mbox) getMessages() ([]storage.Message, error) {
if !mb.indexLoaded {
if err := mb.readIndex(); err != nil {
return nil, err
}
}
messages := make([]storage.Message, len(mb.messages))
for i, m := range mb.messages {
messages[i] = m
}
return messages, nil
}
// getMessage decodes a single message by ID and returns a Message object.
func (mb *mbox) getMessage(id string) (storage.Message, error) {
if !mb.indexLoaded {
if err := mb.readIndex(); err != nil {
return nil, err
}
}
if id == "latest" && len(mb.messages) != 0 {
return mb.messages[len(mb.messages)-1], nil
}
for _, m := range mb.messages {
if m.Fid == id {
return m, nil
}
}
return nil, storage.ErrNotExist
}
// removeMessage deletes the message off disk and removes it from the index.
func (mb *mbox) removeMessage(id string) error {
if !mb.indexLoaded {
if err := mb.readIndex(); err != nil {
return err
}
}
var msg *Message
for i, m := range mb.messages {
if id == m.ID() {
msg = m
// Slice around message we are deleting
mb.messages = append(mb.messages[:i], mb.messages[i+1:]...)
break
}
}
if msg == nil {
return storage.ErrNotExist
}
if err := mb.writeIndex(); err != nil {
return err
}
if len(mb.messages) == 0 {
// This was the last message, thus writeIndex() has removed the entire
// directory; we don't need to delete the raw file.
return nil
}
// There are still messages in the index
log.Tracef("Deleting %v", msg.rawPath())
return os.Remove(msg.rawPath())
}
// purge deletes all messages in this mailbox.
func (mb *mbox) purge() error {
mb.messages = mb.messages[:0]
return mb.writeIndex()
}
// readIndex loads the mailbox index data from disk
func (mb *mbox) readIndex() error {
// Clear message slice, open index
mb.messages = mb.messages[:0]
// Check if index exists
if _, err := os.Stat(mb.indexPath); err != nil {
// Does not exist, but that's not an error in our world
log.Tracef("Index %v does not exist (yet)", mb.indexPath)
mb.indexLoaded = true
return nil
}
file, err := os.Open(mb.indexPath)
if err != nil {
return err
}
defer func() {
if err := file.Close(); err != nil {
log.Errorf("Failed to close %q: %v", mb.indexPath, err)
}
}()
// Decode gob data
dec := gob.NewDecoder(bufio.NewReader(file))
name := ""
if err = dec.Decode(&name); err != nil {
return fmt.Errorf("Corrupt mailbox %q: %v", mb.indexPath, err)
}
mb.name = name
for {
// Load messages until EOF
msg := &Message{}
if err = dec.Decode(msg); err != nil {
if err == io.EOF {
break
}
return fmt.Errorf("Corrupt mailbox %q: %v", mb.indexPath, err)
}
msg.mailbox = mb
mb.messages = append(mb.messages, msg)
}
mb.indexLoaded = true
return nil
}
// writeIndex overwrites the index on disk with the current mailbox data
func (mb *mbox) writeIndex() error {
// Lock for writing
if len(mb.messages) > 0 {
// Ensure mailbox directory exists
if err := mb.createDir(); err != nil {
return err
}
// Open index for writing
file, err := os.Create(mb.indexPath)
if err != nil {
return err
}
writer := bufio.NewWriter(file)
// Write each message and then flush
enc := gob.NewEncoder(writer)
if err = enc.Encode(mb.name); err != nil {
_ = file.Close()
return err
}
for _, m := range mb.messages {
if err = enc.Encode(m); err != nil {
_ = file.Close()
return err
}
}
if err := writer.Flush(); err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
log.Errorf("Failed to close %q: %v", mb.indexPath, err)
return err
}
} else {
// No messages, delete index+maildir
log.Tracef("Removing mailbox %v", mb.path)
return mb.removeDir()
}
return nil
}
// createDir checks for the presence of the path for this mailbox, creates it if needed
func (mb *mbox) createDir() error {
if _, err := os.Stat(mb.path); err != nil {
if err := os.MkdirAll(mb.path, 0770); err != nil {
log.Errorf("Failed to create directory %v, %v", mb.path, err)
return err
}
}
return nil
}
// removeDir removes the mailbox, plus empty higher level directories
func (mb *mbox) removeDir() error {
// remove mailbox dir, including index file
if err := os.RemoveAll(mb.path); err != nil {
return err
}
// remove parents if empty
dir := filepath.Dir(mb.path)
if removeDirIfEmpty(dir) {
removeDirIfEmpty(filepath.Dir(dir))
}
return nil
}
// removeDirIfEmpty will remove the specified directory if it contains no files or directories.
// Returns true if dir was removed.
func removeDirIfEmpty(path string) (removed bool) {
f, err := os.Open(path)
if err != nil {
return false
}
files, err := f.Readdirnames(0)
_ = f.Close()
if err != nil {
return false
}
if len(files) > 0 {
// Dir not empty
return false
}
log.Tracef("Removing dir %v", path)
err = os.Remove(path)
if err != nil {
log.Errorf("Failed to remove %q: %v", path, err)
return false
}
return true
}

View File

@@ -1,4 +1,4 @@
package datastore package storage
import ( import (
"strconv" "strconv"

View File

@@ -1,4 +1,4 @@
package datastore_test package storage_test
import ( import (
"testing" "testing"
@@ -7,7 +7,7 @@ import (
) )
func TestHashLock(t *testing.T) { func TestHashLock(t *testing.T) {
hl := &datastore.HashLock{} hl := &storage.HashLock{}
// Invalid hashes // Invalid hashes
testCases := []struct { testCases := []struct {

View File

@@ -1,4 +1,4 @@
package datastore package storage
import ( import (
"container/list" "container/list"
@@ -47,15 +47,17 @@ func init() {
type RetentionScanner struct { type RetentionScanner struct {
globalShutdown chan bool // Closes when Inbucket needs to shut down globalShutdown chan bool // Closes when Inbucket needs to shut down
retentionShutdown chan bool // Closed after the scanner has shut down retentionShutdown chan bool // Closed after the scanner has shut down
ds DataStore ds Store
retentionPeriod time.Duration retentionPeriod time.Duration
retentionSleep time.Duration retentionSleep time.Duration
} }
// NewRetentionScanner launches a go-routine that scans for expired // NewRetentionScanner configures a new RententionScanner.
// messages, following the configured interval func NewRetentionScanner(
func NewRetentionScanner(ds DataStore, shutdownChannel chan bool) *RetentionScanner { cfg config.DataStoreConfig,
cfg := config.GetDataStoreConfig() ds Store,
shutdownChannel chan bool,
) *RetentionScanner {
rs := &RetentionScanner{ rs := &RetentionScanner{
globalShutdown: shutdownChannel, globalShutdown: shutdownChannel,
retentionShutdown: make(chan bool), retentionShutdown: make(chan bool),
@@ -97,7 +99,7 @@ retentionLoop:
} }
// Kickoff scan // Kickoff scan
start = time.Now() start = time.Now()
if err := rs.doScan(); err != nil { if err := rs.DoScan(); err != nil {
log.Errorf("Error during retention scan: %v", err) log.Errorf("Error during retention scan: %v", err)
} }
// Check for global shutdown // Check for global shutdown
@@ -111,28 +113,17 @@ retentionLoop:
close(rs.retentionShutdown) close(rs.retentionShutdown)
} }
// doScan does a single pass of all mailboxes looking for messages that can be purged // DoScan does a single pass of all mailboxes looking for messages that can be purged.
func (rs *RetentionScanner) doScan() error { func (rs *RetentionScanner) DoScan() error {
log.Tracef("Starting retention scan") log.Tracef("Starting retention scan")
cutoff := time.Now().Add(-1 * rs.retentionPeriod) cutoff := time.Now().Add(-1 * rs.retentionPeriod)
mboxes, err := rs.ds.AllMailboxes()
if err != nil {
return err
}
retained := 0 retained := 0
// Loop over all mailboxes // Loop over all mailboxes.
for _, mb := range mboxes { err := rs.ds.VisitMailboxes(func(messages []Message) bool {
messages, err := mb.GetMessages()
if err != nil {
return err
}
// Loop over all messages in mailbox
for _, msg := range messages { for _, msg := range messages {
if msg.Date().Before(cutoff) { if msg.Date().Before(cutoff) {
log.Tracef("Purging expired message %v", msg.ID()) log.Tracef("Purging expired message %v/%v", msg.Mailbox(), msg.ID())
err = msg.Delete() if err := rs.ds.RemoveMessage(msg.Mailbox(), msg.ID()); err != nil {
if err != nil {
// Log but don't abort
log.Errorf("Failed to purge message %v: %v", msg.ID(), err) log.Errorf("Failed to purge message %v: %v", msg.ID(), err)
} else { } else {
expRetentionDeletesTotal.Add(1) expRetentionDeletesTotal.Add(1)
@@ -141,14 +132,17 @@ func (rs *RetentionScanner) doScan() error {
retained++ retained++
} }
} }
// Sleep after completing a mailbox
select { select {
case <-rs.globalShutdown: case <-rs.globalShutdown:
log.Tracef("Retention scan aborted due to shutdown") log.Tracef("Retention scan aborted due to shutdown")
return nil return false
case <-time.After(rs.retentionSleep): case <-time.After(rs.retentionSleep):
// Reduce disk thrashing // Reduce disk thrashing
} }
return true
})
if err != nil {
return err
} }
// Update metrics // Update metrics
setRetentionScanCompleted(time.Now()) setRetentionScanCompleted(time.Now())
@@ -156,7 +150,7 @@ func (rs *RetentionScanner) doScan() error {
return nil return nil
} }
// Join does not retun until the retention scanner has shut down // Join does not return until the retention scanner has shut down.
func (rs *RetentionScanner) Join() { func (rs *RetentionScanner) Join() {
if rs.retentionShutdown != nil { if rs.retentionShutdown != nil {
<-rs.retentionShutdown <-rs.retentionShutdown

View File

@@ -1,67 +1,62 @@
package datastore package storage_test
import ( import (
"fmt" "fmt"
"testing" "testing"
"time" "time"
"github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/storage"
"github.com/jhillyerd/inbucket/pkg/test"
) )
func TestDoRetentionScan(t *testing.T) { func TestDoRetentionScan(t *testing.T) {
// Create mock objects ds := test.NewStore()
mds := &MockDataStore{}
mb1 := &MockMailbox{}
mb2 := &MockMailbox{}
mb3 := &MockMailbox{}
// Mockup some different aged messages (num is in hours) // Mockup some different aged messages (num is in hours)
new1 := mockMessage(0) new1 := stubMessage("mb1", 0)
new2 := mockMessage(1) new2 := stubMessage("mb2", 1)
new3 := mockMessage(2) new3 := stubMessage("mb3", 2)
old1 := mockMessage(4) old1 := stubMessage("mb1", 4)
old2 := mockMessage(12) old2 := stubMessage("mb1", 12)
old3 := mockMessage(24) old3 := stubMessage("mb2", 24)
ds.AddMessage(new1)
// First it should ask for all mailboxes ds.AddMessage(old1)
mds.On("AllMailboxes").Return([]Mailbox{mb1, mb2, mb3}, nil) ds.AddMessage(old2)
ds.AddMessage(old3)
// Then for all messages on each box ds.AddMessage(new2)
mb1.On("GetMessages").Return([]Message{new1, old1, old2}, nil) ds.AddMessage(new3)
mb2.On("GetMessages").Return([]Message{old3, new2}, nil)
mb3.On("GetMessages").Return([]Message{new3}, nil)
// Test 4 hour retention // Test 4 hour retention
rs := &RetentionScanner{ cfg := config.DataStoreConfig{
ds: mds, RetentionMinutes: 239,
retentionPeriod: 4*time.Hour - time.Minute, RetentionSleep: 0,
retentionSleep: 0,
} }
if err := rs.doScan(); err != nil { shutdownChan := make(chan bool)
rs := storage.NewRetentionScanner(cfg, ds, shutdownChan)
if err := rs.DoScan(); err != nil {
t.Error(err) t.Error(err)
} }
// Check our assertions
mds.AssertExpectations(t)
mb1.AssertExpectations(t)
mb2.AssertExpectations(t)
mb3.AssertExpectations(t)
// Delete should not have been called on new messages // Delete should not have been called on new messages
new1.AssertNotCalled(t, "Delete") for _, m := range []storage.Message{new1, new2, new3} {
new2.AssertNotCalled(t, "Delete") if ds.MessageDeleted(m) {
new3.AssertNotCalled(t, "Delete") t.Errorf("Expected %v to be present, was deleted", m.ID())
}
}
// Delete should have been called once on old messages // Delete should have been called once on old messages
old1.AssertNumberOfCalls(t, "Delete", 1) for _, m := range []storage.Message{old1, old2, old3} {
old2.AssertNumberOfCalls(t, "Delete", 1) if !ds.MessageDeleted(m) {
old3.AssertNumberOfCalls(t, "Delete", 1) t.Errorf("Expected %v to be deleted, was present", m.ID())
}
}
} }
// Make a MockMessage of a specific age // stubMessage creates a message stub of a specific age
func mockMessage(ageHours int) *MockMessage { func stubMessage(mailbox string, ageHours int) storage.Message {
msg := &MockMessage{} return &message.Delivery{
msg.On("ID").Return(fmt.Sprintf("MSG[age=%vh]", ageHours)) Meta: message.Metadata{
msg.On("Date").Return(time.Now().Add(time.Duration(ageHours*-1) * time.Hour)) Mailbox: mailbox,
msg.On("Delete").Return(nil) ID: fmt.Sprintf("MSG[age=%vh]", ageHours),
return msg Date: time.Now().Add(time.Duration(ageHours*-1) * time.Hour),
},
}
} }

40
pkg/storage/storage.go Normal file
View File

@@ -0,0 +1,40 @@
// Package storage contains implementation independent datastore logic
package storage
import (
"errors"
"io"
"net/mail"
"time"
)
var (
// ErrNotExist indicates the requested message does not exist.
ErrNotExist = errors.New("message does not exist")
// ErrNotWritable indicates the message is closed; no longer writable
ErrNotWritable = errors.New("Message not writable")
)
// Store is the interface Inbucket uses to interact with storage implementations.
type Store interface {
// AddMessage stores the message, message ID and Size will be ignored.
AddMessage(message Message) (id string, err error)
GetMessage(mailbox, id string) (Message, error)
GetMessages(mailbox string) ([]Message, error)
PurgeMessages(mailbox string) error
RemoveMessage(mailbox, id string) error
VisitMailboxes(f func([]Message) (cont bool)) error
}
// Message represents a message to be stored, or returned from a storage implementation.
type Message interface {
Mailbox() string
ID() string
From() *mail.Address
To() []*mail.Address
Date() time.Time
Subject() string
Source() (io.ReadCloser, error)
Size() int64
}

View File

@@ -1,163 +0,0 @@
package datastore
import (
"io"
"net/mail"
"sync"
"time"
"github.com/jhillyerd/enmime"
"github.com/stretchr/testify/mock"
)
// MockDataStore is a shared mock for unit testing
type MockDataStore struct {
mock.Mock
}
// MailboxFor mock function
func (m *MockDataStore) MailboxFor(name string) (Mailbox, error) {
args := m.Called(name)
return args.Get(0).(Mailbox), args.Error(1)
}
// AllMailboxes mock function
func (m *MockDataStore) AllMailboxes() ([]Mailbox, error) {
args := m.Called()
return args.Get(0).([]Mailbox), args.Error(1)
}
// LockFor mock function returns a new RWMutex, never errors.
func (m *MockDataStore) LockFor(name string) (*sync.RWMutex, error) {
return &sync.RWMutex{}, nil
}
// MockMailbox is a shared mock for unit testing
type MockMailbox struct {
mock.Mock
}
// GetMessages mock function
func (m *MockMailbox) GetMessages() ([]Message, error) {
args := m.Called()
return args.Get(0).([]Message), args.Error(1)
}
// GetMessage mock function
func (m *MockMailbox) GetMessage(id string) (Message, error) {
args := m.Called(id)
return args.Get(0).(Message), args.Error(1)
}
// Purge mock function
func (m *MockMailbox) Purge() error {
args := m.Called()
return args.Error(0)
}
// NewMessage mock function
func (m *MockMailbox) NewMessage() (Message, error) {
args := m.Called()
return args.Get(0).(Message), args.Error(1)
}
// Name mock function
func (m *MockMailbox) Name() string {
args := m.Called()
return args.String(0)
}
// String mock function
func (m *MockMailbox) String() string {
args := m.Called()
return args.String(0)
}
// MockMessage is a shared mock for unit testing
type MockMessage struct {
mock.Mock
}
// ID mock function
func (m *MockMessage) ID() string {
args := m.Called()
return args.String(0)
}
// From mock function
func (m *MockMessage) From() string {
args := m.Called()
return args.String(0)
}
// To mock function
func (m *MockMessage) To() []string {
args := m.Called()
return args.Get(0).([]string)
}
// Date mock function
func (m *MockMessage) Date() time.Time {
args := m.Called()
return args.Get(0).(time.Time)
}
// Subject mock function
func (m *MockMessage) Subject() string {
args := m.Called()
return args.String(0)
}
// ReadHeader mock function
func (m *MockMessage) ReadHeader() (msg *mail.Message, err error) {
args := m.Called()
return args.Get(0).(*mail.Message), args.Error(1)
}
// ReadBody mock function
func (m *MockMessage) ReadBody() (body *enmime.Envelope, err error) {
args := m.Called()
return args.Get(0).(*enmime.Envelope), args.Error(1)
}
// ReadRaw mock function
func (m *MockMessage) ReadRaw() (raw *string, err error) {
args := m.Called()
return args.Get(0).(*string), args.Error(1)
}
// RawReader mock function
func (m *MockMessage) RawReader() (reader io.ReadCloser, err error) {
args := m.Called()
return args.Get(0).(io.ReadCloser), args.Error(1)
}
// Size mock function
func (m *MockMessage) Size() int64 {
args := m.Called()
return int64(args.Int(0))
}
// Append mock function
func (m *MockMessage) Append(data []byte) error {
// []byte arg seems to mess up testify/mock
return nil
}
// Close mock function
func (m *MockMessage) Close() error {
args := m.Called()
return args.Error(0)
}
// Delete mock function
func (m *MockMessage) Delete() error {
args := m.Called()
return args.Error(0)
}
// String mock function
func (m *MockMessage) String() string {
args := m.Called()
return args.String(0)
}

View File

@@ -1,46 +1,12 @@
package stringutil package stringutil
import ( import (
"bytes"
"crypto/sha1" "crypto/sha1"
"fmt" "fmt"
"io" "io"
"strings" "net/mail"
) )
// ParseMailboxName takes a localPart string (ex: "user+ext" without "@domain")
// and returns just the mailbox name (ex: "user"). Returns an error if
// localPart contains invalid characters; it won't accept any that must be
// quoted according to RFC3696.
func ParseMailboxName(localPart string) (result string, err error) {
if localPart == "" {
return "", fmt.Errorf("Mailbox name cannot be empty")
}
result = strings.ToLower(localPart)
invalid := make([]byte, 0, 10)
for i := 0; i < len(result); i++ {
c := result[i]
switch {
case 'a' <= c && c <= 'z':
case '0' <= c && c <= '9':
case bytes.IndexByte([]byte("!#$%&'*+-=/?^_`.{|}~"), c) >= 0:
default:
invalid = append(invalid, c)
}
}
if len(invalid) > 0 {
return "", fmt.Errorf("Mailbox name contained invalid character(s): %q", invalid)
}
if idx := strings.Index(result, "+"); idx > -1 {
result = result[0:idx]
}
return result, nil
}
// HashMailboxName accepts a mailbox name and hashes it. filestore uses this as // HashMailboxName accepts a mailbox name and hashes it. filestore uses this as
// the directory to house the mailbox // the directory to house the mailbox
func HashMailboxName(mailbox string) string { func HashMailboxName(mailbox string) string {
@@ -52,175 +18,13 @@ func HashMailboxName(mailbox string) string {
return fmt.Sprintf("%x", h.Sum(nil)) return fmt.Sprintf("%x", h.Sum(nil))
} }
// ValidateDomainPart returns true if the domain part complies to RFC3696, RFC1035 // StringAddressList converts a list of addresses to a list of strings
func ValidateDomainPart(domain string) bool { func StringAddressList(addrs []*mail.Address) []string {
if len(domain) == 0 { s := make([]string, len(addrs))
return false for i, a := range addrs {
} if a != nil {
if len(domain) > 255 { s[i] = a.String()
return false
}
if domain[len(domain)-1] != '.' {
domain += "."
}
prev := '.'
labelLen := 0
hasAlphaNum := false
for _, c := range domain {
switch {
case ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
('0' <= c && c <= '9') || c == '_':
// Must contain some of these to be a valid label
hasAlphaNum = true
labelLen++
case c == '-':
if prev == '.' {
// Cannot lead with hyphen
return false
}
case c == '.':
if prev == '.' || prev == '-' {
// Cannot end with hyphen or double-dot
return false
}
if labelLen > 63 {
return false
}
if !hasAlphaNum {
return false
}
labelLen = 0
hasAlphaNum = false
default:
// Unknown character
return false
} }
prev = c
} }
return s
return true
}
// ParseEmailAddress unescapes an email address, and splits the local part from the domain part.
// An error is returned if the local or domain parts fail validation following the guidelines
// in RFC3696.
func ParseEmailAddress(address string) (local string, domain string, err error) {
if address == "" {
return "", "", fmt.Errorf("Empty address")
}
if len(address) > 320 {
return "", "", fmt.Errorf("Address exceeds 320 characters")
}
if address[0] == '@' {
return "", "", fmt.Errorf("Address cannot start with @ symbol")
}
if address[0] == '.' {
return "", "", fmt.Errorf("Address cannot start with a period")
}
// Loop over address parsing out local part
buf := new(bytes.Buffer)
prev := byte('.')
inCharQuote := false
inStringQuote := false
LOOP:
for i := 0; i < len(address); i++ {
c := address[i]
switch {
case ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'):
// Letters are OK
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
case '0' <= c && c <= '9':
// Numbers are OK
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
case bytes.IndexByte([]byte("!#$%&'*+-/=?^_`{|}~"), c) >= 0:
// These specials can be used unquoted
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
case c == '.':
// A single period is OK
if prev == '.' {
// Sequence of periods is not permitted
return "", "", fmt.Errorf("Sequence of periods is not permitted")
}
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
case c == '\\':
inCharQuote = true
case c == '"':
if inCharQuote {
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
} else if inStringQuote {
inStringQuote = false
} else {
if i == 0 {
inStringQuote = true
} else {
return "", "", fmt.Errorf("Quoted string can only begin at start of address")
}
}
case c == '@':
if inCharQuote || inStringQuote {
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
} else {
// End of local-part
if i > 128 {
return "", "", fmt.Errorf("Local part must not exceed 128 characters")
}
if prev == '.' {
return "", "", fmt.Errorf("Local part cannot end with a period")
}
domain = address[i+1:]
break LOOP
}
case c > 127:
return "", "", fmt.Errorf("Characters outside of US-ASCII range not permitted")
default:
if inCharQuote || inStringQuote {
err = buf.WriteByte(c)
if err != nil {
return
}
inCharQuote = false
} else {
return "", "", fmt.Errorf("Character %q must be quoted", c)
}
}
prev = c
}
if inCharQuote {
return "", "", fmt.Errorf("Cannot end address with unterminated quoted-pair")
}
if inStringQuote {
return "", "", fmt.Errorf("Cannot end address with unterminated string quote")
}
if !ValidateDomainPart(domain) {
return "", "", fmt.Errorf("Domain part validation failed")
}
return buf.String(), domain, nil
} }

View File

@@ -1,215 +1,33 @@
package stringutil package stringutil_test
import ( import (
"strings" "net/mail"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/jhillyerd/inbucket/pkg/stringutil"
) )
func TestParseMailboxName(t *testing.T) {
var validTable = []struct {
input string
expect string
}{
{"mailbox", "mailbox"},
{"user123", "user123"},
{"MailBOX", "mailbox"},
{"First.Last", "first.last"},
{"user+label", "user"},
{"chars!#$%", "chars!#$%"},
{"chars&'*-", "chars&'*-"},
{"chars=/?^", "chars=/?^"},
{"chars_`.{", "chars_`.{"},
{"chars|}~", "chars|}~"},
}
for _, tt := range validTable {
if result, err := ParseMailboxName(tt.input); err != nil {
t.Errorf("Error while parsing %q: %v", tt.input, err)
} else {
if result != tt.expect {
t.Errorf("Parsing %q, expected %q, got %q", tt.input, tt.expect, result)
}
}
}
var invalidTable = []struct {
input, msg string
}{
{"", "Empty mailbox name is not permitted"},
{"user@host", "@ symbol not permitted"},
{"first last", "Space not permitted"},
{"first\"last", "Double quote not permitted"},
{"first\nlast", "Control chars not permitted"},
}
for _, tt := range invalidTable {
if _, err := ParseMailboxName(tt.input); err == nil {
t.Errorf("Didn't get an error while parsing %q: %v", tt.input, tt.msg)
}
}
}
func TestHashMailboxName(t *testing.T) { func TestHashMailboxName(t *testing.T) {
assert.Equal(t, HashMailboxName("mail"), "1d6e1cf70ec6f9ab28d3ea4b27a49a77654d370e") want := "1d6e1cf70ec6f9ab28d3ea4b27a49a77654d370e"
} got := stringutil.HashMailboxName("mail")
if got != want {
func TestValidateDomain(t *testing.T) { t.Errorf("Got %q, want %q", got, want)
assert.False(t, ValidateDomainPart(strings.Repeat("a", 256)),
"Max domain length is 255")
assert.False(t, ValidateDomainPart(strings.Repeat("a", 64)+".com"),
"Max label length is 63")
assert.True(t, ValidateDomainPart(strings.Repeat("a", 63)+".com"),
"Should allow 63 char label")
var testTable = []struct {
input string
expect bool
msg string
}{
{"", false, "Empty domain is not valid"},
{"hostname", true, "Just a hostname is valid"},
{"github.com", true, "Two labels should be just fine"},
{"my-domain.com", true, "Hyphen is allowed mid-label"},
{"_domainkey.foo.com", true, "Underscores are allowed"},
{"bar.com.", true, "Must be able to end with a dot"},
{"ABC.6DBS.com", true, "Mixed case is OK"},
{"mail.123.com", true, "Number only label valid"},
{"123.com", true, "Number only label valid"},
{"google..com", false, "Double dot not valid"},
{".foo.com", false, "Cannot start with a dot"},
{"google\r.com", false, "Special chars not allowed"},
{"foo.-bar.com", false, "Label cannot start with hyphen"},
{"foo-.bar.com", false, "Label cannot end with hyphen"},
}
for _, tt := range testTable {
if ValidateDomainPart(tt.input) != tt.expect {
t.Errorf("Expected %v for %q: %s", tt.expect, tt.input, tt.msg)
}
} }
} }
func TestValidateLocal(t *testing.T) { func TestStringAddressList(t *testing.T) {
var testTable = []struct { input := []*mail.Address{
input string {Name: "Fred B. Fish", Address: "fred@fish.org"},
expect bool {Name: "User", Address: "user@domain.org"},
msg string
}{
{"", false, "Empty local is not valid"},
{"a", true, "Single letter should be fine"},
{strings.Repeat("a", 128), true, "Valid up to 128 characters"},
{strings.Repeat("a", 129), false, "Only valid up to 128 characters"},
{"FirstLast", true, "Mixed case permitted"},
{"user123", true, "Numbers permitted"},
{"a!#$%&'*+-/=?^_`{|}~", true, "Any of !#$%&'*+-/=?^_`{|}~ are permitted"},
{"first.last", true, "Embedded period is permitted"},
{"first..last", false, "Sequence of periods is not allowed"},
{".user", false, "Cannot lead with a period"},
{"user.", false, "Cannot end with a period"},
{"james@mail", false, "Unquoted @ not permitted"},
{"first last", false, "Unquoted space not permitted"},
{"tricky\\. ", false, "Unquoted space not permitted"},
{"no,commas", false, "Unquoted comma not allowed"},
{"t[es]t", false, "Unquoted square brackets not allowed"},
{"james\\", false, "Cannot end with backslash quote"},
{"james\\@mail", true, "Quoted @ permitted"},
{"quoted\\ space", true, "Quoted space permitted"},
{"no\\,commas", true, "Quoted comma is OK"},
{"t\\[es\\]t", true, "Quoted brackets are OK"},
{"user\\name", true, "Should be able to quote a-z"},
{"USER\\NAME", true, "Should be able to quote A-Z"},
{"user\\1", true, "Should be able to quote a digit"},
{"one\\$\\|", true, "Should be able to quote plain specials"},
{"return\\\r", true, "Should be able to quote ASCII control chars"},
{"high\\\x80", false, "Should not accept > 7-bit quoted chars"},
{"quote\\\"", true, "Quoted double quote is permitted"},
{"\"james\"", true, "Quoted a-z is permitted"},
{"\"first last\"", true, "Quoted space is permitted"},
{"\"quoted@sign\"", true, "Quoted @ is allowed"},
{"\"qp\\\"quote\"", true, "Quoted quote within quoted string is OK"},
{"\"unterminated", false, "Quoted string must be terminated"},
{"\"unterminated\\\"", false, "Quoted string must be terminated"},
{"embed\"quote\"string", false, "Embedded quoted string is illegal"},
{"user+mailbox", true, "RFC3696 test case should be valid"},
{"customer/department=shipping", true, "RFC3696 test case should be valid"},
{"$A12345", true, "RFC3696 test case should be valid"},
{"!def!xyz%abc", true, "RFC3696 test case should be valid"},
{"_somename", true, "RFC3696 test case should be valid"},
} }
want := []string{`"Fred B. Fish" <fred@fish.org>`, `"User" <user@domain.org>`}
for _, tt := range testTable { output := stringutil.StringAddressList(input)
_, _, err := ParseEmailAddress(tt.input + "@domain.com") if len(output) != len(want) {
if (err != nil) == tt.expect { t.Fatalf("Got %v strings, want: %v", len(output), len(want))
if err != nil { }
t.Logf("Got error: %s", err) for i, got := range output {
} if got != want[i] {
t.Errorf("Expected %v for %q: %s", tt.expect, tt.input, tt.msg) t.Errorf("Got %q, want: %q", got, want[i])
}
}
}
func TestParseEmailAddress(t *testing.T) {
// Test some good email addresses
var testTable = []struct {
input, local, domain string
}{
{"root@localhost", "root", "localhost"},
{"FirstLast@domain.local", "FirstLast", "domain.local"},
{"route66@prodigy.net", "route66", "prodigy.net"},
{"lorbit!user@uucp", "lorbit!user", "uucp"},
{"user+spam@gmail.com", "user+spam", "gmail.com"},
{"first.last@domain.local", "first.last", "domain.local"},
{"first\\ last@_key.domain.com", "first last", "_key.domain.com"},
{"first\\\"last@a.b.c", "first\"last", "a.b.c"},
{"user\\@internal@myhost.ca", "user@internal", "myhost.ca"},
{"\"first last@evil\"@top-secret.gov", "first last@evil", "top-secret.gov"},
{"\"line\nfeed\"@linenoise.co.uk", "line\nfeed", "linenoise.co.uk"},
{"user+mailbox@host", "user+mailbox", "host"},
{"customer/department=shipping@host", "customer/department=shipping", "host"},
{"$A12345@host", "$A12345", "host"},
{"!def!xyz%abc@host", "!def!xyz%abc", "host"},
{"_somename@host", "_somename", "host"},
}
for _, tt := range testTable {
local, domain, err := ParseEmailAddress(tt.input)
if err != nil {
t.Errorf("Error when parsing %q: %s", tt.input, err)
} else {
if tt.local != local {
t.Errorf("When parsing %q, expected local %q, got %q instead",
tt.input, tt.local, local)
}
if tt.domain != domain {
t.Errorf("When parsing %q, expected domain %q, got %q instead",
tt.input, tt.domain, domain)
}
}
}
// Check that validations fail correctly
var badTable = []struct {
input, msg string
}{
{"", "Empty address not permitted"},
{"user", "Missing domain part"},
{"@host", "Missing local part"},
{"user\\@host", "Missing domain part"},
{"\"user@host\"", "Missing domain part"},
{"\"user@host", "Unterminated quoted string"},
{"first last@host", "Unquoted space"},
{"user@bad!domain", "Invalid domain"},
{".user@host", "Can't lead with a ."},
{"user.@host", "Can't end local with a dot"},
{"user@bad domain", "No spaces in domain permitted"},
}
for _, tt := range badTable {
if _, _, err := ParseEmailAddress(tt.input); err == nil {
t.Errorf("Did not get expected error when parsing %q: %s", tt.input, tt.msg)
} }
} }
} }

59
pkg/test/manager.go Normal file
View File

@@ -0,0 +1,59 @@
package test
import (
"errors"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/policy"
"github.com/jhillyerd/inbucket/pkg/storage"
)
// ManagerStub is a test stub for message.Manager
type ManagerStub struct {
message.Manager
mailboxes map[string][]*message.Message
}
// NewManager creates a new ManagerStub.
func NewManager() *ManagerStub {
return &ManagerStub{
mailboxes: make(map[string][]*message.Message),
}
}
// AddMessage adds a message to the specified mailbox.
func (m *ManagerStub) AddMessage(mailbox string, msg *message.Message) {
messages := m.mailboxes[mailbox]
m.mailboxes[mailbox] = append(messages, msg)
}
// GetMessage gets a message by ID from the specified mailbox.
func (m *ManagerStub) GetMessage(mailbox, id string) (*message.Message, error) {
if mailbox == "messageerr" {
return nil, errors.New("internal error")
}
for _, msg := range m.mailboxes[mailbox] {
if msg.ID == id {
return msg, nil
}
}
return nil, storage.ErrNotExist
}
// GetMetadata gets all the metadata for the specified mailbox.
func (m *ManagerStub) GetMetadata(mailbox string) ([]*message.Metadata, error) {
if mailbox == "messageserr" {
return nil, errors.New("internal error")
}
messages := m.mailboxes[mailbox]
metas := make([]*message.Metadata, len(messages))
for i, msg := range messages {
metas[i] = &msg.Metadata
}
return metas, nil
}
// MailboxForAddress invokes policy.ParseMailboxName.
func (m *ManagerStub) MailboxForAddress(address string) (string, error) {
return policy.ParseMailboxName(address)
}

88
pkg/test/storage.go Normal file
View File

@@ -0,0 +1,88 @@
package test
import (
"errors"
"github.com/jhillyerd/inbucket/pkg/storage"
)
// StoreStub stubs storage.Store for testing.
type StoreStub struct {
storage.Store
mailboxes map[string][]storage.Message
deleted map[storage.Message]struct{}
}
// NewStore creates a new StoreStub.
func NewStore() *StoreStub {
return &StoreStub{
mailboxes: make(map[string][]storage.Message),
deleted: make(map[storage.Message]struct{}),
}
}
// AddMessage adds a message to the specified mailbox.
func (s *StoreStub) AddMessage(m storage.Message) (id string, err error) {
mb := m.Mailbox()
msgs := s.mailboxes[mb]
s.mailboxes[mb] = append(msgs, m)
return m.ID(), nil
}
// GetMessage gets a message by ID from the specified mailbox.
func (s *StoreStub) GetMessage(mailbox, id string) (storage.Message, error) {
if mailbox == "messageerr" {
return nil, errors.New("internal error")
}
for _, m := range s.mailboxes[mailbox] {
if m.ID() == id {
return m, nil
}
}
return nil, storage.ErrNotExist
}
// GetMessages gets all the messages for the specified mailbox.
func (s *StoreStub) GetMessages(mailbox string) ([]storage.Message, error) {
if mailbox == "messageserr" {
return nil, errors.New("internal error")
}
return s.mailboxes[mailbox], nil
}
// RemoveMessage deletes a message by ID from the specified mailbox.
func (s *StoreStub) RemoveMessage(mailbox, id string) error {
mb, ok := s.mailboxes[mailbox]
if ok {
var msg storage.Message
for i, m := range mb {
if m.ID() == id {
msg = m
s.mailboxes[mailbox] = append(mb[:i], mb[i+1:]...)
break
}
}
if msg != nil {
s.deleted[msg] = struct{}{}
return nil
}
}
return storage.ErrNotExist
}
// VisitMailboxes accepts a function that will be called with the messages in each mailbox while it
// continues to return true.
func (s *StoreStub) VisitMailboxes(f func([]storage.Message) (cont bool)) error {
for _, v := range s.mailboxes {
if !f(v) {
return nil
}
}
return nil
}
// MessageDeleted returns true if the specified message was deleted
func (s *StoreStub) MessageDeleted(m storage.Message) bool {
_, ok := s.deleted[m]
return ok
}

330
pkg/test/storage_suite.go Normal file
View File

@@ -0,0 +1,330 @@
package test
import (
"bytes"
"fmt"
"io/ioutil"
"net/mail"
"strings"
"testing"
"time"
"github.com/jhillyerd/inbucket/pkg/message"
"github.com/jhillyerd/inbucket/pkg/storage"
)
// StoreFactory returns a new store for the test suite.
type StoreFactory func() (store storage.Store, destroy func(), err error)
// StoreSuite runs a set of general tests on the provided Store.
func StoreSuite(t *testing.T, factory StoreFactory) {
testCases := []struct {
name string
test func(*testing.T, storage.Store)
}{
{"metadata", testMetadata},
{"content", testContent},
{"delivery order", testDeliveryOrder},
{"size", testSize},
{"delete", testDelete},
{"purge", testPurge},
{"visit mailboxes", testVisitMailboxes},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
store, destroy, err := factory()
if err != nil {
t.Fatal(err)
}
tc.test(t, store)
destroy()
})
}
}
// testMetadata verifies message metadata is stored and retrieved correctly.
func testMetadata(t *testing.T, store storage.Store) {
mailbox := "testmailbox"
from := &mail.Address{Name: "From Person", Address: "from@person.com"}
to := []*mail.Address{
{Name: "One Person", Address: "one@a.person.com"},
{Name: "Two Person", Address: "two@b.person.com"},
}
date := time.Now()
subject := "fantastic test subject line"
content := "doesn't matter"
delivery := &message.Delivery{
Meta: message.Metadata{
// ID and Size will be determined by the Store.
Mailbox: mailbox,
From: from,
To: to,
Date: date,
Subject: subject,
},
Reader: strings.NewReader(content),
}
id, err := store.AddMessage(delivery)
if err != nil {
t.Fatal(err)
}
if id == "" {
t.Fatal("Expected AddMessage() to return non-empty ID string")
}
// Retrieve and validate the message.
sm, err := store.GetMessage(mailbox, id)
if err != nil {
t.Fatal(err)
}
if sm.Mailbox() != mailbox {
t.Errorf("got mailbox %q, want: %q", sm.Mailbox(), mailbox)
}
if sm.ID() != id {
t.Errorf("got id %q, want: %q", sm.ID(), id)
}
if *sm.From() != *from {
t.Errorf("got from %v, want: %v", sm.From(), from)
}
if len(sm.To()) != len(to) {
t.Errorf("got len(to) = %v, want: %v", len(sm.To()), len(to))
} else {
for i, got := range sm.To() {
if *to[i] != *got {
t.Errorf("got to[%v] %v, want: %v", i, got, to[i])
}
}
}
if !sm.Date().Equal(date) {
t.Errorf("got date %v, want: %v", sm.Date(), date)
}
if sm.Subject() != subject {
t.Errorf("got subject %q, want: %q", sm.Subject(), subject)
}
if sm.Size() != int64(len(content)) {
t.Errorf("got size %v, want: %v", sm.Size(), len(content))
}
}
// testContent generates some binary content and makes sure it is correctly retrieved.
func testContent(t *testing.T, store storage.Store) {
content := make([]byte, 5000)
for i := 0; i < len(content); i++ {
content[i] = byte(i % 256)
}
mailbox := "testmailbox"
from := &mail.Address{Name: "From Person", Address: "from@person.com"}
to := []*mail.Address{
{Name: "One Person", Address: "one@a.person.com"},
}
date := time.Now()
subject := "fantastic test subject line"
delivery := &message.Delivery{
Meta: message.Metadata{
// ID and Size will be determined by the Store.
Mailbox: mailbox,
From: from,
To: to,
Date: date,
Subject: subject,
},
Reader: bytes.NewReader(content),
}
id, err := store.AddMessage(delivery)
if err != nil {
t.Fatal(err)
}
// Get and check.
m, err := store.GetMessage(mailbox, id)
if err != nil {
t.Fatal(err)
}
r, err := m.Source()
if err != nil {
t.Fatal(err)
}
got, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
if len(got) != len(content) {
t.Errorf("Got len(content) == %v, want: %v", len(got), len(content))
}
errors := 0
for i, b := range got {
if b != content[i] {
t.Errorf("Got content[%v] == %v, want: %v", i, b, content[i])
errors++
}
if errors > 5 {
t.Fatalf("Too many content errors, aborting test.")
break
}
}
}
// testDeliveryOrder delivers several messages to the same mailbox, meanwhile querying its contents
// with a new GetMessages call each cycle.
func testDeliveryOrder(t *testing.T, store storage.Store) {
mailbox := "fred"
subjects := []string{"alpha", "bravo", "charlie", "delta", "echo"}
for i, subj := range subjects {
// Check mailbox count.
getAndCountMessages(t, store, mailbox, i)
deliverMessage(t, store, mailbox, subj, time.Now())
}
// Confirm delivery order.
msgs := getAndCountMessages(t, store, mailbox, 5)
for i, want := range subjects {
got := msgs[i].Subject()
if got != want {
t.Errorf("Got subject %q, want %q", got, want)
}
}
}
// testSize verifies message contnet size metadata values.
func testSize(t *testing.T, store storage.Store) {
mailbox := "fred"
subjects := []string{"a", "br", "much longer than the others"}
sentIds := make([]string, len(subjects))
sentSizes := make([]int64, len(subjects))
for i, subj := range subjects {
id, size := deliverMessage(t, store, mailbox, subj, time.Now())
sentIds[i] = id
sentSizes[i] = size
}
for i, id := range sentIds {
msg, err := store.GetMessage(mailbox, id)
if err != nil {
t.Fatal(err)
}
want := sentSizes[i]
got := msg.Size()
if got != want {
t.Errorf("Got size %v, want: %v", got, want)
}
}
}
// testDelete creates and deletes some messages.
func testDelete(t *testing.T, store storage.Store) {
mailbox := "fred"
subjects := []string{"alpha", "bravo", "charlie", "delta", "echo"}
for _, subj := range subjects {
deliverMessage(t, store, mailbox, subj, time.Now())
}
msgs := getAndCountMessages(t, store, mailbox, len(subjects))
// Delete a couple messages.
err := store.RemoveMessage(mailbox, msgs[1].ID())
if err != nil {
t.Fatal(err)
}
err = store.RemoveMessage(mailbox, msgs[3].ID())
if err != nil {
t.Fatal(err)
}
// Confirm deletion.
subjects = []string{"alpha", "charlie", "echo"}
msgs = getAndCountMessages(t, store, mailbox, len(subjects))
for i, want := range subjects {
got := msgs[i].Subject()
if got != want {
t.Errorf("Got subject %q, want %q", got, want)
}
}
// Try appending one more.
deliverMessage(t, store, mailbox, "foxtrot", time.Now())
subjects = []string{"alpha", "charlie", "echo", "foxtrot"}
msgs = getAndCountMessages(t, store, mailbox, len(subjects))
for i, want := range subjects {
got := msgs[i].Subject()
if got != want {
t.Errorf("Got subject %q, want %q", got, want)
}
}
}
// testPurge makes sure mailboxes can be purged.
func testPurge(t *testing.T, store storage.Store) {
mailbox := "fred"
subjects := []string{"alpha", "bravo", "charlie", "delta", "echo"}
for _, subj := range subjects {
deliverMessage(t, store, mailbox, subj, time.Now())
}
getAndCountMessages(t, store, mailbox, len(subjects))
// Purge and verify.
err := store.PurgeMessages(mailbox)
if err != nil {
t.Fatal(err)
}
getAndCountMessages(t, store, mailbox, 0)
}
// testVisitMailboxes creates some mailboxes and confirms the VisitMailboxes method visits all of
// them.
func testVisitMailboxes(t *testing.T, ds storage.Store) {
boxes := []string{"abby", "bill", "christa", "donald", "evelyn"}
for _, name := range boxes {
deliverMessage(t, ds, name, "Old Message", time.Now().Add(-24*time.Hour))
deliverMessage(t, ds, name, "New Message", time.Now())
}
seen := 0
err := ds.VisitMailboxes(func(messages []storage.Message) bool {
seen++
count := len(messages)
if count != 2 {
t.Errorf("got: %v messages, want: 2", count)
}
return true
})
if err != nil {
t.Error(err)
}
if seen != 5 {
t.Errorf("saw %v messages in total, want: 5", seen)
}
}
// deliverMessage creates and delivers a message to the specific mailbox, returning the size of the
// generated message.
func deliverMessage(
t *testing.T,
store storage.Store,
mailbox string,
subject string,
date time.Time,
) (string, int64) {
t.Helper()
meta := message.Metadata{
Mailbox: mailbox,
To: []*mail.Address{{Name: "Some Body", Address: "somebody@host"}},
From: &mail.Address{Name: "Some B. Else", Address: "somebodyelse@host"},
Subject: subject,
Date: date,
}
testMsg := fmt.Sprintf("To: %s\r\nFrom: %s\r\nSubject: %s\r\n\r\nTest Body\r\n",
meta.To[0].Address, meta.From.Address, subject)
delivery := &message.Delivery{
Meta: meta,
Reader: ioutil.NopCloser(strings.NewReader(testMsg)),
}
id, err := store.AddMessage(delivery)
if err != nil {
t.Fatal(err)
}
return id, int64(len(testMsg))
}
// getAndCountMessages is a test helper that expects to receive count messages or fails the test, it
// also checks return error.
func getAndCountMessages(t *testing.T, s storage.Store, mailbox string, count int) []storage.Message {
t.Helper()
msgs, err := s.GetMessages(mailbox)
if err != nil {
t.Fatalf("Failed to GetMessages for %q: %v", mailbox, err)
}
if len(msgs) != count {
t.Errorf("Got %v messages, want: %v", len(msgs), count)
}
return msgs
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/jhillyerd/inbucket/pkg/log" "github.com/jhillyerd/inbucket/pkg/log"
"github.com/jhillyerd/inbucket/pkg/server/web" "github.com/jhillyerd/inbucket/pkg/server/web"
"github.com/jhillyerd/inbucket/pkg/storage" "github.com/jhillyerd/inbucket/pkg/storage"
"github.com/jhillyerd/inbucket/pkg/stringutil"
"github.com/jhillyerd/inbucket/pkg/webui/sanitize" "github.com/jhillyerd/inbucket/pkg/webui/sanitize"
) )
@@ -25,7 +24,7 @@ func MailboxIndex(w http.ResponseWriter, req *http.Request, ctx *web.Context) (e
http.Redirect(w, req, web.Reverse("RootIndex"), http.StatusSeeOther) http.Redirect(w, req, web.Reverse("RootIndex"), http.StatusSeeOther)
return nil return nil
} }
name, err = stringutil.ParseMailboxName(name) name, err = ctx.Manager.MailboxForAddress(name)
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -52,7 +51,7 @@ func MailboxIndex(w http.ResponseWriter, req *http.Request, ctx *web.Context) (e
func MailboxLink(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxLink(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -68,16 +67,11 @@ func MailboxLink(w http.ResponseWriter, req *http.Request, ctx *web.Context) (er
// MailboxList renders a list of messages in a mailbox. Renders a partial // MailboxList renders a list of messages in a mailbox. Renders a partial
func MailboxList(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxList(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name) messages, err := ctx.Manager.GetMetadata(name)
if err != nil {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
messages, err := mb.GetMessages()
if err != nil { if err != nil {
// This doesn't indicate empty, likely an IO error // This doesn't indicate empty, likely an IO error
return fmt.Errorf("Failed to get messages for %v: %v", name, err) return fmt.Errorf("Failed to get messages for %v: %v", name, err)
@@ -95,17 +89,12 @@ func MailboxList(w http.ResponseWriter, req *http.Request, ctx *web.Context) (er
func MailboxShow(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxShow(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name) msg, err := ctx.Manager.GetMessage(name, id)
if err != nil { if err == storage.ErrNotExist {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
msg, err := mb.GetMessage(id)
if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
@@ -113,10 +102,7 @@ func MailboxShow(w http.ResponseWriter, req *http.Request, ctx *web.Context) (er
// This doesn't indicate empty, likely an IO error // This doesn't indicate empty, likely an IO error
return fmt.Errorf("GetMessage(%q) failed: %v", id, err) return fmt.Errorf("GetMessage(%q) failed: %v", id, err)
} }
mime, err := msg.ReadBody() mime := msg.Envelope
if err != nil {
return fmt.Errorf("ReadBody(%q) failed: %v", id, err)
}
body := template.HTML(web.TextToHTML(mime.Text)) body := template.HTML(web.TextToHTML(mime.Text))
htmlAvailable := mime.HTML != "" htmlAvailable := mime.HTML != ""
var htmlBody template.HTML var htmlBody template.HTML
@@ -144,36 +130,27 @@ func MailboxShow(w http.ResponseWriter, req *http.Request, ctx *web.Context) (er
func MailboxHTML(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxHTML(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name) msg, err := ctx.Manager.GetMessage(name, id)
if err != nil { if err == storage.ErrNotExist {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
message, err := mb.GetMessage(id)
if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
if err != nil { if err != nil {
// This doesn't indicate missing, likely an IO error // This doesn't indicate empty, likely an IO error
return fmt.Errorf("GetMessage(%q) failed: %v", id, err) return fmt.Errorf("GetMessage(%q) failed: %v", id, err)
} }
mime, err := message.ReadBody() mime := msg.Envelope
if err != nil {
return fmt.Errorf("ReadBody(%q) failed: %v", id, err)
}
// Render partial template // Render partial template
w.Header().Set("Content-Type", "text/html; charset=UTF-8") w.Header().Set("Content-Type", "text/html; charset=UTF-8")
return web.RenderPartial("mailbox/_html.html", w, map[string]interface{}{ return web.RenderPartial("mailbox/_html.html", w, map[string]interface{}{
"ctx": ctx, "ctx": ctx,
"name": name, "name": name,
"message": message, "message": msg,
// TODO It is not really safe to render, need to sanitize, issue #5 "body": template.HTML(mime.HTML),
"body": template.HTML(mime.HTML),
}) })
} }
@@ -181,34 +158,23 @@ func MailboxHTML(w http.ResponseWriter, req *http.Request, ctx *web.Context) (er
func MailboxSource(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxSource(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
return err return err
} }
mb, err := ctx.DataStore.MailboxFor(name) r, err := ctx.Manager.SourceReader(name, id)
if err != nil { if err == storage.ErrNotExist {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
message, err := mb.GetMessage(id)
if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
if err != nil { if err != nil {
// This doesn't indicate missing, likely an IO error // This doesn't indicate missing, likely an IO error
return fmt.Errorf("GetMessage(%q) failed: %v", id, err) return fmt.Errorf("SourceReader(%q) failed: %v", id, err)
}
raw, err := message.ReadRaw()
if err != nil {
return fmt.Errorf("ReadRaw(%q) failed: %v", id, err)
} }
// Output message source // Output message source
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
if _, err := io.WriteString(w, *raw); err != nil { _, err = io.Copy(w, r)
return err return err
}
return nil
} }
// MailboxDownloadAttach sends the attachment to the client; disposition: // MailboxDownloadAttach sends the attachment to the client; disposition:
@@ -216,7 +182,7 @@ func MailboxSource(w http.ResponseWriter, req *http.Request, ctx *web.Context) (
func MailboxDownloadAttach(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxDownloadAttach(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
id := ctx.Vars["id"] id := ctx.Vars["id"]
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -231,24 +197,16 @@ func MailboxDownloadAttach(w http.ResponseWriter, req *http.Request, ctx *web.Co
http.Redirect(w, req, web.Reverse("RootIndex"), http.StatusSeeOther) http.Redirect(w, req, web.Reverse("RootIndex"), http.StatusSeeOther)
return nil return nil
} }
mb, err := ctx.DataStore.MailboxFor(name) msg, err := ctx.Manager.GetMessage(name, id)
if err != nil { if err == storage.ErrNotExist {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
message, err := mb.GetMessage(id)
if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
if err != nil { if err != nil {
// This doesn't indicate missing, likely an IO error // This doesn't indicate empty, likely an IO error
return fmt.Errorf("GetMessage(%q) failed: %v", id, err) return fmt.Errorf("GetMessage(%q) failed: %v", id, err)
} }
body, err := message.ReadBody() body := msg.Envelope
if err != nil {
return err
}
if int(num) >= len(body.Attachments) { if int(num) >= len(body.Attachments) {
ctx.Session.AddFlash("Attachment number too high", "errors") ctx.Session.AddFlash("Attachment number too high", "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -259,16 +217,14 @@ func MailboxDownloadAttach(w http.ResponseWriter, req *http.Request, ctx *web.Co
// Output attachment // Output attachment
w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", "attachment") w.Header().Set("Content-Disposition", "attachment")
if _, err := io.Copy(w, part); err != nil { _, err = io.Copy(w, part)
return err return err
}
return nil
} }
// MailboxViewAttach sends the attachment to the client for online viewing // MailboxViewAttach sends the attachment to the client for online viewing
func MailboxViewAttach(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) { func MailboxViewAttach(w http.ResponseWriter, req *http.Request, ctx *web.Context) (err error) {
// Don't have to validate these aren't empty, Gorilla returns 404 // Don't have to validate these aren't empty, Gorilla returns 404
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -284,24 +240,16 @@ func MailboxViewAttach(w http.ResponseWriter, req *http.Request, ctx *web.Contex
http.Redirect(w, req, web.Reverse("RootIndex"), http.StatusSeeOther) http.Redirect(w, req, web.Reverse("RootIndex"), http.StatusSeeOther)
return nil return nil
} }
mb, err := ctx.DataStore.MailboxFor(name) msg, err := ctx.Manager.GetMessage(name, id)
if err != nil { if err == storage.ErrNotExist {
// This doesn't indicate not found, likely an IO error
return fmt.Errorf("Failed to get mailbox for %q: %v", name, err)
}
message, err := mb.GetMessage(id)
if err == datastore.ErrNotExist {
http.NotFound(w, req) http.NotFound(w, req)
return nil return nil
} }
if err != nil { if err != nil {
// This doesn't indicate missing, likely an IO error // This doesn't indicate empty, likely an IO error
return fmt.Errorf("GetMessage(%q) failed: %v", id, err) return fmt.Errorf("GetMessage(%q) failed: %v", id, err)
} }
body, err := message.ReadBody() body := msg.Envelope
if err != nil {
return err
}
if int(num) >= len(body.Attachments) { if int(num) >= len(body.Attachments) {
ctx.Session.AddFlash("Attachment number too high", "errors") ctx.Session.AddFlash("Attachment number too high", "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)
@@ -311,8 +259,6 @@ func MailboxViewAttach(w http.ResponseWriter, req *http.Request, ctx *web.Contex
part := body.Attachments[num] part := body.Attachments[num]
// Output attachment // Output attachment
w.Header().Set("Content-Type", part.ContentType) w.Header().Set("Content-Type", part.ContentType)
if _, err := io.Copy(w, part); err != nil { _, err = io.Copy(w, part)
return err return err
}
return nil
} }

View File

@@ -8,7 +8,6 @@ import (
"github.com/jhillyerd/inbucket/pkg/config" "github.com/jhillyerd/inbucket/pkg/config"
"github.com/jhillyerd/inbucket/pkg/server/web" "github.com/jhillyerd/inbucket/pkg/server/web"
"github.com/jhillyerd/inbucket/pkg/stringutil"
) )
// RootIndex serves the Inbucket landing page // RootIndex serves the Inbucket landing page
@@ -58,7 +57,7 @@ func RootMonitorMailbox(w http.ResponseWriter, req *http.Request, ctx *web.Conte
http.Redirect(w, req, web.Reverse("RootIndex"), http.StatusSeeOther) http.Redirect(w, req, web.Reverse("RootIndex"), http.StatusSeeOther)
return nil return nil
} }
name, err := stringutil.ParseMailboxName(ctx.Vars["name"]) name, err := ctx.Manager.MailboxForAddress(ctx.Vars["name"])
if err != nil { if err != nil {
ctx.Session.AddFlash(err.Error(), "errors") ctx.Session.AddFlash(err.Error(), "errors")
_ = ctx.Session.Save(req, w) _ = ctx.Session.Save(req, w)