1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-17 17:47:03 +00:00

Make use of pkg context

- Use context inside of servers for shutdown
- Remove unnecessary localShutdown related code
This commit is contained in:
James Hillyerd
2017-01-15 21:49:04 -08:00
parent 0e02061c4a
commit a222b7c428
4 changed files with 30 additions and 44 deletions

View File

@@ -2,6 +2,7 @@
package httpd package httpd
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -60,7 +61,7 @@ func Initialize(cfg config.WebConfig, ds smtpd.DataStore, shutdownChan chan bool
} }
// Start begins listening for HTTP requests // Start begins listening for HTTP requests
func Start() { func Start(ctx context.Context) {
addr := fmt.Sprintf("%v:%v", webConfig.IP4address, webConfig.IP4port) addr := fmt.Sprintf("%v:%v", webConfig.IP4address, webConfig.IP4port)
server = &http.Server{ server = &http.Server{
Addr: addr, Addr: addr,
@@ -80,11 +81,11 @@ func Start() {
} }
// Listener go routine // Listener go routine
go serve() go serve(ctx)
// Wait for shutdown // Wait for shutdown
select { select {
case _ = <-globalShutdown: case _ = <-ctx.Done():
log.Tracef("HTTP server shutting down on request") log.Tracef("HTTP server shutting down on request")
} }
@@ -95,12 +96,12 @@ func Start() {
} }
// serve begins serving HTTP requests // serve begins serving HTTP requests
func serve() { func serve(ctx context.Context) {
// server.Serve blocks until we close the listener // server.Serve blocks until we close the listener
err := server.Serve(listener) err := server.Serve(listener)
select { select {
case _ = <-globalShutdown: case _ = <-ctx.Done():
// Nop // Nop
default: default:
log.Errorf("HTTP server failed: %v", err) log.Errorf("HTTP server failed: %v", err)

View File

@@ -2,6 +2,7 @@
package main package main
import ( import (
"context"
"expvar" "expvar"
"flag" "flag"
"fmt" "fmt"
@@ -52,6 +53,9 @@ func main() {
return return
} }
// Root context
rootCtx, rootCancel := context.WithCancel(context.Background())
// Load & Parse config // Load & Parse config
if flag.NArg() != 1 { if flag.NArg() != 1 {
flag.Usage() flag.Usage()
@@ -98,16 +102,16 @@ func main() {
httpd.Initialize(config.GetWebConfig(), ds, shutdownChan) httpd.Initialize(config.GetWebConfig(), ds, shutdownChan)
webui.SetupRoutes(httpd.Router) webui.SetupRoutes(httpd.Router)
rest.SetupRoutes(httpd.Router) rest.SetupRoutes(httpd.Router)
go httpd.Start() go httpd.Start(rootCtx)
// Start POP3 server // Start POP3 server
// TODO pass datastore // TODO pass datastore
pop3Server = pop3d.New(shutdownChan) pop3Server = pop3d.New(shutdownChan)
go pop3Server.Start() go pop3Server.Start(rootCtx)
// Startup SMTP server // Startup SMTP server
smtpServer = smtpd.NewServer(config.GetSMTPConfig(), ds, shutdownChan) smtpServer = smtpd.NewServer(config.GetSMTPConfig(), ds, shutdownChan)
go smtpServer.Start() go smtpServer.Start(rootCtx)
// Loop forever waiting for signals or shutdown channel // Loop forever waiting for signals or shutdown channel
signalLoop: signalLoop:
@@ -128,6 +132,7 @@ signalLoop:
close(shutdownChan) close(shutdownChan)
} }
case _ = <-shutdownChan: case _ = <-shutdownChan:
rootCancel()
break signalLoop break signalLoop
} }
} }

View File

@@ -1,6 +1,7 @@
package pop3d package pop3d
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@@ -18,7 +19,6 @@ type Server struct {
dataStore smtpd.DataStore dataStore smtpd.DataStore
listener net.Listener listener net.Listener
globalShutdown chan bool globalShutdown chan bool
localShutdown chan bool
waitgroup *sync.WaitGroup waitgroup *sync.WaitGroup
} }
@@ -35,20 +35,17 @@ func New(shutdownChan chan bool) *Server {
dataStore: ds, dataStore: ds,
maxIdleSeconds: cfg.MaxIdleSeconds, maxIdleSeconds: cfg.MaxIdleSeconds,
globalShutdown: shutdownChan, globalShutdown: shutdownChan,
localShutdown: make(chan bool),
waitgroup: new(sync.WaitGroup), waitgroup: new(sync.WaitGroup),
} }
} }
// Start the server and listen for connections // Start the server and listen for connections
func (s *Server) Start() { func (s *Server) Start(ctx context.Context) {
cfg := config.GetPOP3Config() cfg := config.GetPOP3Config()
addr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%v:%v", addr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%v:%v",
cfg.IP4address, cfg.IP4port)) cfg.IP4address, cfg.IP4port))
if err != nil { if err != nil {
log.Errorf("POP3 Failed to build tcp4 address: %v", err) log.Errorf("POP3 Failed to build tcp4 address: %v", err)
// serve() never called, so we do local shutdown here
close(s.localShutdown)
s.emergencyShutdown() s.emergencyShutdown()
return return
} }
@@ -57,18 +54,16 @@ func (s *Server) Start() {
s.listener, err = net.ListenTCP("tcp4", addr) s.listener, err = net.ListenTCP("tcp4", addr)
if err != nil { if err != nil {
log.Errorf("POP3 failed to start tcp4 listener: %v", err) log.Errorf("POP3 failed to start tcp4 listener: %v", err)
// serve() never called, so we do local shutdown here
close(s.localShutdown)
s.emergencyShutdown() s.emergencyShutdown()
return return
} }
// Listener go routine // Listener go routine
go s.serve() go s.serve(ctx)
// Wait for shutdown // Wait for shutdown
select { select {
case _ = <-s.globalShutdown: case _ = <-ctx.Done():
} }
log.Tracef("POP3 shutdown requested, connections will be drained") log.Tracef("POP3 shutdown requested, connections will be drained")
@@ -79,7 +74,7 @@ func (s *Server) Start() {
} }
// serve is the listen/accept loop // serve is the listen/accept loop
func (s *Server) serve() { func (s *Server) serve(ctx context.Context) {
// Handle incoming connections // Handle incoming connections
var tempDelay time.Duration var tempDelay time.Duration
for sid := 1; ; sid++ { for sid := 1; ; sid++ {
@@ -100,11 +95,11 @@ func (s *Server) serve() {
} else { } else {
// Permanent error // Permanent error
select { select {
case _ = <-s.globalShutdown: case <-ctx.Done():
close(s.localShutdown) // POP3 is shutting down
return return
default: default:
close(s.localShutdown) // Something went wrong
s.emergencyShutdown() s.emergencyShutdown()
return return
} }
@@ -128,10 +123,6 @@ func (s *Server) emergencyShutdown() {
// Drain causes the caller to block until all active POP3 sessions have finished // Drain causes the caller to block until all active POP3 sessions have finished
func (s *Server) Drain() { func (s *Server) Drain() {
// Wait for listener to exit
select {
case _ = <-s.localShutdown:
}
// Wait for sessions to close // Wait for sessions to close
s.waitgroup.Wait() s.waitgroup.Wait()
log.Tracef("POP3 connections have drained") log.Tracef("POP3 connections have drained")

View File

@@ -2,6 +2,7 @@ package smtpd
import ( import (
"container/list" "container/list"
"context"
"expvar" "expvar"
"fmt" "fmt"
"net" "net"
@@ -27,9 +28,6 @@ type Server struct {
// globalShutdown is the signal Inbucket needs to shut down // globalShutdown is the signal Inbucket needs to shut down
globalShutdown chan bool globalShutdown chan bool
// localShutdown indicates this component has completed shutting down
localShutdown chan bool
// waitgroup tracks individual sessions // waitgroup tracks individual sessions
waitgroup *sync.WaitGroup waitgroup *sync.WaitGroup
} }
@@ -67,19 +65,16 @@ func NewServer(cfg config.SMTPConfig, ds DataStore, globalShutdown chan bool) *S
domainNoStore: strings.ToLower(cfg.DomainNoStore), domainNoStore: strings.ToLower(cfg.DomainNoStore),
waitgroup: new(sync.WaitGroup), waitgroup: new(sync.WaitGroup),
globalShutdown: globalShutdown, globalShutdown: globalShutdown,
localShutdown: make(chan bool),
} }
} }
// Start the listener and handle incoming connections // Start the listener and handle incoming connections
func (s *Server) Start() { func (s *Server) Start(ctx context.Context) {
cfg := config.GetSMTPConfig() cfg := config.GetSMTPConfig()
addr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%v:%v", addr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%v:%v",
cfg.IP4address, cfg.IP4port)) cfg.IP4address, cfg.IP4port))
if err != nil { if err != nil {
log.Errorf("Failed to build tcp4 address: %v", err) log.Errorf("Failed to build tcp4 address: %v", err)
// serve() never called, so we do local shutdown here
close(s.localShutdown)
s.emergencyShutdown() s.emergencyShutdown()
return return
} }
@@ -88,8 +83,6 @@ func (s *Server) Start() {
s.listener, err = net.ListenTCP("tcp4", addr) s.listener, err = net.ListenTCP("tcp4", addr)
if err != nil { if err != nil {
log.Errorf("SMTP failed to start tcp4 listener: %v", err) log.Errorf("SMTP failed to start tcp4 listener: %v", err)
// serve() never called, so we do local shutdown here
close(s.localShutdown)
s.emergencyShutdown() s.emergencyShutdown()
return return
} }
@@ -104,11 +97,11 @@ func (s *Server) Start() {
StartRetentionScanner(s.dataStore, s.globalShutdown) StartRetentionScanner(s.dataStore, s.globalShutdown)
// Listener go routine // Listener go routine
go s.serve() go s.serve(ctx)
// Wait for shutdown // Wait for shutdown
select { select {
case _ = <-s.globalShutdown: case <-ctx.Done():
log.Tracef("SMTP shutdown requested, connections will be drained") log.Tracef("SMTP shutdown requested, connections will be drained")
} }
@@ -119,7 +112,7 @@ func (s *Server) Start() {
} }
// serve is the listen/accept loop // serve is the listen/accept loop
func (s *Server) serve() { func (s *Server) serve(ctx context.Context) {
// Handle incoming connections // Handle incoming connections
var tempDelay time.Duration var tempDelay time.Duration
for sessionID := 1; ; sessionID++ { for sessionID := 1; ; sessionID++ {
@@ -141,11 +134,11 @@ func (s *Server) serve() {
} else { } else {
// Permanent error // Permanent error
select { select {
case _ = <-s.globalShutdown: case <-ctx.Done():
close(s.localShutdown) // SMTP is shutting down
return return
default: default:
close(s.localShutdown) // Something went wrong
s.emergencyShutdown() s.emergencyShutdown()
return return
} }
@@ -170,10 +163,6 @@ func (s *Server) emergencyShutdown() {
// Drain causes the caller to block until all active SMTP sessions have finished // Drain causes the caller to block until all active SMTP sessions have finished
func (s *Server) Drain() { func (s *Server) Drain() {
// Wait for listener to exit
select {
case _ = <-s.localShutdown:
}
// Wait for sessions to close // Wait for sessions to close
s.waitgroup.Wait() s.waitgroup.Wait()
log.Tracef("SMTP connections have drained") log.Tracef("SMTP connections have drained")