Add optional timeout to get PROXY header

The library user can define a maximum time to wait
for the PROXY protocol header, before failing out to
normal connection.

We can assume that a proxy in front of the service will
send the PROXY header immediatelly.

This solves the issue of clients getting block when
getting the RemoteAddr() for an incoming connection that
does not send any data. That is the case of http.Serve on
go < 1.6 as described in https://github.com/armon/go-proxyproto/issues/1
This commit is contained in:
Hector Rivas Gandara
2016-07-12 18:19:20 +01:00
parent 609d6338d3
commit 49fdb5cfab
3 changed files with 98 additions and 16 deletions

View File

@@ -28,7 +28,7 @@ Using the library is very simple:
list, err := net.Listen("tcp", "...") list, err := net.Listen("tcp", "...")
// Wrap listener in a proxyproto listener // Wrap listener in a proxyproto listener
proxyList := &proxyproto.Listener{list} proxyList := &proxyproto.Listener{Listener: list}
conn, err :=proxyList.Accept() conn, err :=proxyList.Accept()
... ...

View File

@@ -24,8 +24,12 @@ var (
// whose connections may be using the HAProxy Proxy Protocol (version 1). // whose connections may be using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, the RemoteAddr() will return // If the connection is using the protocol, the RemoteAddr() will return
// the correct client address. // the correct client address.
//
// Optionally define ProxyHeaderTimeout to set a maximum time to
// receive the Proxy Protocol Header. Zero means no timeout.
type Listener struct { type Listener struct {
Listener net.Listener Listener net.Listener
ProxyHeaderTimeout time.Duration
} }
// Conn is used to wrap and underlying connection which // Conn is used to wrap and underlying connection which
@@ -37,6 +41,7 @@ type Conn struct {
dstAddr *net.TCPAddr dstAddr *net.TCPAddr
srcAddr *net.TCPAddr srcAddr *net.TCPAddr
once sync.Once once sync.Once
proxyHeaderTimeout time.Duration
} }
// Accept waits for and returns the next connection to the listener. // Accept waits for and returns the next connection to the listener.
@@ -46,7 +51,7 @@ func (p *Listener) Accept() (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(conn), nil return NewConn(conn, p.ProxyHeaderTimeout), nil
} }
// Close closes the underlying listener. // Close closes the underlying listener.
@@ -61,10 +66,11 @@ func (p *Listener) Addr() net.Addr {
// NewConn is used to wrap a net.Conn that may be speaking // NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn // the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn) *Conn { func NewConn(conn net.Conn, timeout time.Duration) *Conn {
pConn := &Conn{ pConn := &Conn{
bufReader: bufio.NewReader(conn), bufReader: bufio.NewReader(conn),
conn: conn, conn: conn,
proxyHeaderTimeout: timeout,
} }
return pConn return pConn
} }
@@ -125,12 +131,23 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
} }
func (p *Conn) checkPrefix() error { func (p *Conn) checkPrefix() error {
if p.proxyHeaderTimeout != 0 {
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
p.conn.SetReadDeadline(readDeadLine)
defer p.conn.SetReadDeadline(time.Time{})
}
// Incrementally check each byte of the prefix // Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ { for i := 1; i <= prefixLen; i++ {
inp, err := p.bufReader.Peek(i) inp, err := p.bufReader.Peek(i)
if err != nil { if err != nil {
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return nil
} else {
return err return err
} }
}
// Check for a prefix mis-match, quit early // Check for a prefix mis-match, quit early
if !bytes.Equal(inp, prefix[:i]) { if !bytes.Equal(inp, prefix[:i]) {

View File

@@ -5,6 +5,7 @@ import (
"io" "io"
"net" "net"
"testing" "testing"
"time"
) )
func TestPassthrough(t *testing.T) { func TestPassthrough(t *testing.T) {
@@ -13,7 +14,7 @@ func TestPassthrough(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
pl := &Listener{l} pl := &Listener{Listener: l}
go func() { go func() {
conn, err := net.Dial("tcp", pl.Addr().String()) conn, err := net.Dial("tcp", pl.Addr().String())
@@ -53,13 +54,77 @@ func TestPassthrough(t *testing.T) {
} }
} }
func TestTimeout(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
clientWriteDelay := 200 * time.Millisecond
proxyHeaderTimeout := 50 * time.Millisecond
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Do not send data for a while
time.Sleep(clientWriteDelay)
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Check the remote addr is the original 127.0.0.1
remoteAddrStartTime := time.Now()
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
remoteAddrDuration := time.Since(remoteAddrStartTime)
// Check RemoteAddr() call did timeout
if remoteAddrDuration >= clientWriteDelay {
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
}
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_ipv4(t *testing.T) { func TestParse_ipv4(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
pl := &Listener{l} pl := &Listener{Listener: l}
go func() { go func() {
conn, err := net.Dial("tcp", pl.Addr().String()) conn, err := net.Dial("tcp", pl.Addr().String())
@@ -118,7 +183,7 @@ func TestParse_ipv6(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
pl := &Listener{l} pl := &Listener{Listener: l}
go func() { go func() {
conn, err := net.Dial("tcp", pl.Addr().String()) conn, err := net.Dial("tcp", pl.Addr().String())
@@ -177,7 +242,7 @@ func TestParse_BadHeader(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
pl := &Listener{l} pl := &Listener{Listener: l}
go func() { go func() {
conn, err := net.Dial("tcp", pl.Addr().String()) conn, err := net.Dial("tcp", pl.Addr().String())