diff --git a/smtpd/handler.go b/smtpd/handler.go index 5f5d757..8f787a0 100644 --- a/smtpd/handler.go +++ b/smtpd/handler.go @@ -24,6 +24,8 @@ const ( QUIT // Close session ) +const STAMP_FMT = "Mon, 02 Jan 2006 15:04:05 -0700 (MST)" + func (s State) String() string { switch s { case GREET: @@ -59,15 +61,16 @@ var commands = map[string]bool{ } type Session struct { - server *Server - id int - conn net.Conn - remoteHost string - sendError error - state State - reader *bufio.Reader - from string - recipients *list.List + server *Server + id int + conn net.Conn + remoteDomain string + remoteHost string + sendError error + state State + reader *bufio.Reader + from string + recipients *list.List } func NewSession(server *Server, id int, conn net.Conn) *Session { @@ -196,9 +199,21 @@ func (s *Server) startSession(id int, conn net.Conn) { func (ss *Session) greetHandler(cmd string, arg string) { switch cmd { case "HELO": + domain, err := parseHelloArgument(arg) + if err != nil { + ss.send("501 Domain/address argument required for HELO") + return + } + ss.remoteDomain = domain ss.send("250 Great, let's get this show on the road") ss.enterState(READY) case "EHLO": + domain, err := parseHelloArgument(arg) + if err != nil { + ss.send("501 Domain/address argument required for EHLO") + return + } + ss.remoteDomain = domain ss.send("250-Great, let's get this show on the road") ss.send("250-8BITMIME") ss.send(fmt.Sprintf("250 SIZE %v", ss.server.maxMessageBytes)) @@ -208,6 +223,17 @@ func (ss *Session) greetHandler(cmd string, arg string) { } } +func parseHelloArgument(arg string) (string, error) { + domain := arg + if idx := strings.IndexRune(arg, ' '); idx >= 0 { + domain = arg[:idx] + } + if domain == "" { + return "", fmt.Errorf("Invalid domain") + } + return domain, nil +} + // READY state -> waiting for MAIL func (ss *Session) readyHandler(cmd string, arg string) { if cmd == "MAIL" { @@ -305,11 +331,12 @@ func (ss *Session) mailHandler(cmd string, arg string) { // DATA func (ss *Session) dataHandler() { - msgSize := 0 - + // Timestamp for Received header + stamp := time.Now().Format(STAMP_FMT) // Get a Mailbox and a new Message for each recipient mailboxes := make([]Mailbox, ss.recipients.Len()) messages := make([]Message, ss.recipients.Len()) + msgSize := 0 if ss.server.storeMessages { i := 0 for e := ss.recipients.Front(); e != nil; e = e.Next() { @@ -332,6 +359,11 @@ func (ss *Session) dataHandler() { } mailboxes[i] = mb messages[i] = mb.NewMessage() + + // Generate Received header + recd := fmt.Sprintf("Received: from %s ([%s]) by %s\r\n for <%s>; %s\r\n", + ss.remoteDomain, ss.remoteHost, ss.server.domain, recip, stamp) + messages[i].Append([]byte(recd)) } else { log.LogTrace("Not storing message for %q", recip) } diff --git a/smtpd/handler_test.go b/smtpd/handler_test.go index 17ea35c..edf8639 100644 --- a/smtpd/handler_test.go +++ b/smtpd/handler_test.go @@ -33,6 +33,8 @@ func TestGreetState(t *testing.T) { // Test out some mangled HELOs script = []scriptStep{ + {"HELO", 501}, + {"EHLO", 501}, {"HELLO", 500}, {"HELL", 500}, {"hello", 500}, @@ -43,9 +45,6 @@ func TestGreetState(t *testing.T) { } // Valid HELOs - if err := playSession(t, server, []scriptStep{{"HELO", 250}}); err != nil { - t.Error(err) - } if err := playSession(t, server, []scriptStep{{"HELO mydomain", 250}}); err != nil { t.Error(err) } @@ -55,6 +54,23 @@ func TestGreetState(t *testing.T) { if err := playSession(t, server, []scriptStep{{"HelO mydom.com", 250}}); err != nil { t.Error(err) } + if err := playSession(t, server, []scriptStep{{"helo 127.0.0.1", 250}}); err != nil { + t.Error(err) + } + + // Valid EHLOs + if err := playSession(t, server, []scriptStep{{"EHLO mydomain", 250}}); err != nil { + t.Error(err) + } + if err := playSession(t, server, []scriptStep{{"EHLO mydom.com", 250}}); err != nil { + t.Error(err) + } + if err := playSession(t, server, []scriptStep{{"EhlO mydom.com", 250}}); err != nil { + t.Error(err) + } + if err := playSession(t, server, []scriptStep{{"ehlo 127.0.0.1", 250}}); err != nil { + t.Error(err) + } if t.Failed() { // Wait for handler to finish logging @@ -317,7 +333,7 @@ func playScriptAgainst(t *testing.T, c *textproto.Conn, script []scriptStep) err } c.StartResponse(id) - code, msg, err := c.ReadCodeLine(step.expect) + code, msg, err := c.ReadResponse(step.expect) if err != nil { err = fmt.Errorf("Step %d, sent %q, expected %v, got %v: %q", i, step.send, step.expect, code, msg)