Merge branch 'master' of https://github.com/armon/go-proxyproto into HEAD

This commit is contained in:
Dolf Schimmel
2017-04-23 16:46:58 +02:00
3 changed files with 100 additions and 18 deletions

View File

@@ -28,7 +28,7 @@ Using the library is very simple:
list, err := net.Listen("tcp", "...") list, err := net.Listen("tcp", "...")
// Wrap listener in a proxyproto listener // Wrap listener in a proxyproto listener
proxyList := &proxyproto.Listener{list} proxyList := &proxyproto.Listener{Listener: list}
conn, err :=proxyList.Accept() conn, err :=proxyList.Accept()
... ...

View File

@@ -24,19 +24,24 @@ var (
// whose connections may be using the HAProxy Proxy Protocol (version 1). // whose connections may be using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, the RemoteAddr() will return // If the connection is using the protocol, the RemoteAddr() will return
// the correct client address. // the correct client address.
//
// Optionally define ProxyHeaderTimeout to set a maximum time to
// receive the Proxy Protocol Header. Zero means no timeout.
type Listener struct { type Listener struct {
Listener net.Listener Listener net.Listener
ProxyHeaderTimeout time.Duration
} }
// Conn is used to wrap and underlying connection which // Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address. // return the address of the client instead of the proxy address.
type Conn struct { type Conn struct {
bufReader *bufio.Reader bufReader *bufio.Reader
conn net.Conn conn net.Conn
dstAddr *net.TCPAddr dstAddr *net.TCPAddr
srcAddr *net.TCPAddr srcAddr *net.TCPAddr
once sync.Once once sync.Once
proxyHeaderTimeout time.Duration
} }
// Accept waits for and returns the next connection to the listener. // Accept waits for and returns the next connection to the listener.
@@ -46,7 +51,7 @@ func (p *Listener) Accept() (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(conn), nil return NewConn(conn, p.ProxyHeaderTimeout), nil
} }
// Close closes the underlying listener. // Close closes the underlying listener.
@@ -61,10 +66,11 @@ func (p *Listener) Addr() net.Addr {
// NewConn is used to wrap a net.Conn that may be speaking // NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn // the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn) *Conn { func NewConn(conn net.Conn, timeout time.Duration) *Conn {
pConn := &Conn{ pConn := &Conn{
bufReader: bufio.NewReader(conn), bufReader: bufio.NewReader(conn),
conn: conn, conn: conn,
proxyHeaderTimeout: timeout,
} }
return pConn return pConn
} }
@@ -105,6 +111,7 @@ func (p *Conn) RemoteAddr() net.Addr {
if err := p.checkPrefix(); err != nil && err != io.EOF { if err := p.checkPrefix(); err != nil && err != io.EOF {
log.Printf("[ERR] Failed to read proxy prefix: %v", err) log.Printf("[ERR] Failed to read proxy prefix: %v", err)
p.Close() p.Close()
p.bufReader = bufio.NewReader(p.conn)
} }
}) })
if p.srcAddr != nil { if p.srcAddr != nil {
@@ -126,11 +133,22 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
} }
func (p *Conn) checkPrefix() error { func (p *Conn) checkPrefix() error {
if p.proxyHeaderTimeout != 0 {
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
p.conn.SetReadDeadline(readDeadLine)
defer p.conn.SetReadDeadline(time.Time{})
}
// Incrementally check each byte of the prefix // Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ { for i := 1; i <= prefixLen; i++ {
inp, err := p.bufReader.Peek(i) inp, err := p.bufReader.Peek(i)
if err != nil { if err != nil {
return err if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return nil
} else {
return err
}
} }
// Check for a prefix mis-match, quit early // Check for a prefix mis-match, quit early

View File

@@ -2,9 +2,9 @@ package proxyproto
import ( import (
"bytes" "bytes"
"io"
"net" "net"
"testing" "testing"
"time"
) )
func TestPassthrough(t *testing.T) { func TestPassthrough(t *testing.T) {
@@ -13,7 +13,7 @@ func TestPassthrough(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
pl := &Listener{l} pl := &Listener{Listener: l}
go func() { go func() {
conn, err := net.Dial("tcp", pl.Addr().String()) conn, err := net.Dial("tcp", pl.Addr().String())
@@ -53,13 +53,77 @@ func TestPassthrough(t *testing.T) {
} }
} }
func TestTimeout(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
clientWriteDelay := 200 * time.Millisecond
proxyHeaderTimeout := 50 * time.Millisecond
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Do not send data for a while
time.Sleep(clientWriteDelay)
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Check the remote addr is the original 127.0.0.1
remoteAddrStartTime := time.Now()
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
remoteAddrDuration := time.Since(remoteAddrStartTime)
// Check RemoteAddr() call did timeout
if remoteAddrDuration >= clientWriteDelay {
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
}
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_ipv4(t *testing.T) { func TestParse_ipv4(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
pl := &Listener{l} pl := &Listener{Listener: l}
go func() { go func() {
conn, err := net.Dial("tcp", pl.Addr().String()) conn, err := net.Dial("tcp", pl.Addr().String())
@@ -118,7 +182,7 @@ func TestParse_ipv6(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
pl := &Listener{l} pl := &Listener{Listener: l}
go func() { go func() {
conn, err := net.Dial("tcp", pl.Addr().String()) conn, err := net.Dial("tcp", pl.Addr().String())
@@ -177,7 +241,7 @@ func TestParse_BadHeader(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
pl := &Listener{l} pl := &Listener{Listener: l}
go func() { go func() {
conn, err := net.Dial("tcp", pl.Addr().String()) conn, err := net.Dial("tcp", pl.Addr().String())
@@ -194,7 +258,7 @@ func TestParse_BadHeader(t *testing.T) {
recv := make([]byte, 4) recv := make([]byte, 4)
_, err = conn.Read(recv) _, err = conn.Read(recv)
if err != io.EOF { if err == nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
}() }()