Add optional timeout to get PROXY header
The library user can define a maximum time to wait for the PROXY protocol header, before failing out to normal connection. We can assume that a proxy in front of the service will send the PROXY header immediatelly. This solves the issue of clients getting block when getting the RemoteAddr() for an incoming connection that does not send any data. That is the case of http.Serve on go < 1.6 as described in https://github.com/armon/go-proxyproto/issues/1
This commit is contained in:
@@ -28,7 +28,7 @@ Using the library is very simple:
|
||||
list, err := net.Listen("tcp", "...")
|
||||
|
||||
// Wrap listener in a proxyproto listener
|
||||
proxyList := &proxyproto.Listener{list}
|
||||
proxyList := &proxyproto.Listener{Listener: list}
|
||||
conn, err :=proxyList.Accept()
|
||||
|
||||
...
|
||||
|
||||
39
protocol.go
39
protocol.go
@@ -24,19 +24,24 @@ var (
|
||||
// whose connections may be using the HAProxy Proxy Protocol (version 1).
|
||||
// If the connection is using the protocol, the RemoteAddr() will return
|
||||
// the correct client address.
|
||||
//
|
||||
// Optionally define ProxyHeaderTimeout to set a maximum time to
|
||||
// receive the Proxy Protocol Header. Zero means no timeout.
|
||||
type Listener struct {
|
||||
Listener net.Listener
|
||||
Listener net.Listener
|
||||
ProxyHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
// Conn is used to wrap and underlying connection which
|
||||
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
|
||||
// return the address of the client instead of the proxy address.
|
||||
type Conn struct {
|
||||
bufReader *bufio.Reader
|
||||
conn net.Conn
|
||||
dstAddr *net.TCPAddr
|
||||
srcAddr *net.TCPAddr
|
||||
once sync.Once
|
||||
bufReader *bufio.Reader
|
||||
conn net.Conn
|
||||
dstAddr *net.TCPAddr
|
||||
srcAddr *net.TCPAddr
|
||||
once sync.Once
|
||||
proxyHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
@@ -46,7 +51,7 @@ func (p *Listener) Accept() (net.Conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(conn), nil
|
||||
return NewConn(conn, p.ProxyHeaderTimeout), nil
|
||||
}
|
||||
|
||||
// Close closes the underlying listener.
|
||||
@@ -61,10 +66,11 @@ func (p *Listener) Addr() net.Addr {
|
||||
|
||||
// NewConn is used to wrap a net.Conn that may be speaking
|
||||
// the proxy protocol into a proxyproto.Conn
|
||||
func NewConn(conn net.Conn) *Conn {
|
||||
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
|
||||
pConn := &Conn{
|
||||
bufReader: bufio.NewReader(conn),
|
||||
conn: conn,
|
||||
bufReader: bufio.NewReader(conn),
|
||||
conn: conn,
|
||||
proxyHeaderTimeout: timeout,
|
||||
}
|
||||
return pConn
|
||||
}
|
||||
@@ -125,11 +131,22 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func (p *Conn) checkPrefix() error {
|
||||
if p.proxyHeaderTimeout != 0 {
|
||||
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
||||
p.conn.SetReadDeadline(readDeadLine)
|
||||
defer p.conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
// Incrementally check each byte of the prefix
|
||||
for i := 1; i <= prefixLen; i++ {
|
||||
inp, err := p.bufReader.Peek(i)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check for a prefix mis-match, quit early
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPassthrough(t *testing.T) {
|
||||
@@ -13,7 +14,7 @@ func TestPassthrough(t *testing.T) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
@@ -53,13 +54,77 @@ func TestPassthrough(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
clientWriteDelay := 200 * time.Millisecond
|
||||
proxyHeaderTimeout := 50 * time.Millisecond
|
||||
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Do not send data for a while
|
||||
time.Sleep(clientWriteDelay)
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Check the remote addr is the original 127.0.0.1
|
||||
remoteAddrStartTime := time.Now()
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
remoteAddrDuration := time.Since(remoteAddrStartTime)
|
||||
|
||||
// Check RemoteAddr() call did timeout
|
||||
if remoteAddrDuration >= clientWriteDelay {
|
||||
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
|
||||
}
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv4(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
@@ -118,7 +183,7 @@ func TestParse_ipv6(t *testing.T) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
@@ -177,7 +242,7 @@ func TestParse_BadHeader(t *testing.T) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
|
||||
Reference in New Issue
Block a user