Merge branch 'master' of https://github.com/armon/go-proxyproto into HEAD
This commit is contained in:
@@ -28,7 +28,7 @@ Using the library is very simple:
|
||||
list, err := net.Listen("tcp", "...")
|
||||
|
||||
// Wrap listener in a proxyproto listener
|
||||
proxyList := &proxyproto.Listener{list}
|
||||
proxyList := &proxyproto.Listener{Listener: list}
|
||||
conn, err :=proxyList.Accept()
|
||||
|
||||
...
|
||||
|
||||
22
protocol.go
22
protocol.go
@@ -24,8 +24,12 @@ var (
|
||||
// whose connections may be using the HAProxy Proxy Protocol (version 1).
|
||||
// If the connection is using the protocol, the RemoteAddr() will return
|
||||
// the correct client address.
|
||||
//
|
||||
// Optionally define ProxyHeaderTimeout to set a maximum time to
|
||||
// receive the Proxy Protocol Header. Zero means no timeout.
|
||||
type Listener struct {
|
||||
Listener net.Listener
|
||||
ProxyHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
// Conn is used to wrap and underlying connection which
|
||||
@@ -37,6 +41,7 @@ type Conn struct {
|
||||
dstAddr *net.TCPAddr
|
||||
srcAddr *net.TCPAddr
|
||||
once sync.Once
|
||||
proxyHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
@@ -46,7 +51,7 @@ func (p *Listener) Accept() (net.Conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(conn), nil
|
||||
return NewConn(conn, p.ProxyHeaderTimeout), nil
|
||||
}
|
||||
|
||||
// Close closes the underlying listener.
|
||||
@@ -61,10 +66,11 @@ func (p *Listener) Addr() net.Addr {
|
||||
|
||||
// NewConn is used to wrap a net.Conn that may be speaking
|
||||
// the proxy protocol into a proxyproto.Conn
|
||||
func NewConn(conn net.Conn) *Conn {
|
||||
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
|
||||
pConn := &Conn{
|
||||
bufReader: bufio.NewReader(conn),
|
||||
conn: conn,
|
||||
proxyHeaderTimeout: timeout,
|
||||
}
|
||||
return pConn
|
||||
}
|
||||
@@ -105,6 +111,7 @@ func (p *Conn) RemoteAddr() net.Addr {
|
||||
if err := p.checkPrefix(); err != nil && err != io.EOF {
|
||||
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
|
||||
p.Close()
|
||||
p.bufReader = bufio.NewReader(p.conn)
|
||||
}
|
||||
})
|
||||
if p.srcAddr != nil {
|
||||
@@ -126,12 +133,23 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func (p *Conn) checkPrefix() error {
|
||||
if p.proxyHeaderTimeout != 0 {
|
||||
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
||||
p.conn.SetReadDeadline(readDeadLine)
|
||||
defer p.conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
// Incrementally check each byte of the prefix
|
||||
for i := 1; i <= prefixLen; i++ {
|
||||
inp, err := p.bufReader.Peek(i)
|
||||
|
||||
if err != nil {
|
||||
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check for a prefix mis-match, quit early
|
||||
if !bytes.Equal(inp, prefix[:i]) {
|
||||
|
||||
@@ -2,9 +2,9 @@ package proxyproto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPassthrough(t *testing.T) {
|
||||
@@ -13,7 +13,7 @@ func TestPassthrough(t *testing.T) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
@@ -53,13 +53,77 @@ func TestPassthrough(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
clientWriteDelay := 200 * time.Millisecond
|
||||
proxyHeaderTimeout := 50 * time.Millisecond
|
||||
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Do not send data for a while
|
||||
time.Sleep(clientWriteDelay)
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Check the remote addr is the original 127.0.0.1
|
||||
remoteAddrStartTime := time.Now()
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
remoteAddrDuration := time.Since(remoteAddrStartTime)
|
||||
|
||||
// Check RemoteAddr() call did timeout
|
||||
if remoteAddrDuration >= clientWriteDelay {
|
||||
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
|
||||
}
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv4(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
@@ -118,7 +182,7 @@ func TestParse_ipv6(t *testing.T) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
@@ -177,7 +241,7 @@ func TestParse_BadHeader(t *testing.T) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
pl := &Listener{Listener: l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
@@ -194,7 +258,7 @@ func TestParse_BadHeader(t *testing.T) {
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != io.EOF {
|
||||
if err == nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
Reference in New Issue
Block a user