Merge branch 'master' of https://github.com/armon/go-proxyproto into HEAD
This commit is contained in:
@@ -28,7 +28,7 @@ Using the library is very simple:
|
|||||||
list, err := net.Listen("tcp", "...")
|
list, err := net.Listen("tcp", "...")
|
||||||
|
|
||||||
// Wrap listener in a proxyproto listener
|
// Wrap listener in a proxyproto listener
|
||||||
proxyList := &proxyproto.Listener{list}
|
proxyList := &proxyproto.Listener{Listener: list}
|
||||||
conn, err :=proxyList.Accept()
|
conn, err :=proxyList.Accept()
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|||||||
22
protocol.go
22
protocol.go
@@ -24,8 +24,12 @@ var (
|
|||||||
// whose connections may be using the HAProxy Proxy Protocol (version 1).
|
// whose connections may be using the HAProxy Proxy Protocol (version 1).
|
||||||
// If the connection is using the protocol, the RemoteAddr() will return
|
// If the connection is using the protocol, the RemoteAddr() will return
|
||||||
// the correct client address.
|
// the correct client address.
|
||||||
|
//
|
||||||
|
// Optionally define ProxyHeaderTimeout to set a maximum time to
|
||||||
|
// receive the Proxy Protocol Header. Zero means no timeout.
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
Listener net.Listener
|
Listener net.Listener
|
||||||
|
ProxyHeaderTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn is used to wrap and underlying connection which
|
// Conn is used to wrap and underlying connection which
|
||||||
@@ -37,6 +41,7 @@ type Conn struct {
|
|||||||
dstAddr *net.TCPAddr
|
dstAddr *net.TCPAddr
|
||||||
srcAddr *net.TCPAddr
|
srcAddr *net.TCPAddr
|
||||||
once sync.Once
|
once sync.Once
|
||||||
|
proxyHeaderTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept waits for and returns the next connection to the listener.
|
// Accept waits for and returns the next connection to the listener.
|
||||||
@@ -46,7 +51,7 @@ func (p *Listener) Accept() (net.Conn, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewConn(conn), nil
|
return NewConn(conn, p.ProxyHeaderTimeout), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the underlying listener.
|
// Close closes the underlying listener.
|
||||||
@@ -61,10 +66,11 @@ func (p *Listener) Addr() net.Addr {
|
|||||||
|
|
||||||
// NewConn is used to wrap a net.Conn that may be speaking
|
// NewConn is used to wrap a net.Conn that may be speaking
|
||||||
// the proxy protocol into a proxyproto.Conn
|
// the proxy protocol into a proxyproto.Conn
|
||||||
func NewConn(conn net.Conn) *Conn {
|
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
|
||||||
pConn := &Conn{
|
pConn := &Conn{
|
||||||
bufReader: bufio.NewReader(conn),
|
bufReader: bufio.NewReader(conn),
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
proxyHeaderTimeout: timeout,
|
||||||
}
|
}
|
||||||
return pConn
|
return pConn
|
||||||
}
|
}
|
||||||
@@ -105,6 +111,7 @@ func (p *Conn) RemoteAddr() net.Addr {
|
|||||||
if err := p.checkPrefix(); err != nil && err != io.EOF {
|
if err := p.checkPrefix(); err != nil && err != io.EOF {
|
||||||
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
|
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
|
||||||
p.Close()
|
p.Close()
|
||||||
|
p.bufReader = bufio.NewReader(p.conn)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if p.srcAddr != nil {
|
if p.srcAddr != nil {
|
||||||
@@ -126,12 +133,23 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Conn) checkPrefix() error {
|
func (p *Conn) checkPrefix() error {
|
||||||
|
if p.proxyHeaderTimeout != 0 {
|
||||||
|
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
||||||
|
p.conn.SetReadDeadline(readDeadLine)
|
||||||
|
defer p.conn.SetReadDeadline(time.Time{})
|
||||||
|
}
|
||||||
|
|
||||||
// Incrementally check each byte of the prefix
|
// Incrementally check each byte of the prefix
|
||||||
for i := 1; i <= prefixLen; i++ {
|
for i := 1; i <= prefixLen; i++ {
|
||||||
inp, err := p.bufReader.Peek(i)
|
inp, err := p.bufReader.Peek(i)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check for a prefix mis-match, quit early
|
// Check for a prefix mis-match, quit early
|
||||||
if !bytes.Equal(inp, prefix[:i]) {
|
if !bytes.Equal(inp, prefix[:i]) {
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package proxyproto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPassthrough(t *testing.T) {
|
func TestPassthrough(t *testing.T) {
|
||||||
@@ -13,7 +13,7 @@ func TestPassthrough(t *testing.T) {
|
|||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pl := &Listener{l}
|
pl := &Listener{Listener: l}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||||
@@ -53,13 +53,77 @@ func TestPassthrough(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimeout(t *testing.T) {
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientWriteDelay := 200 * time.Millisecond
|
||||||
|
proxyHeaderTimeout := 50 * time.Millisecond
|
||||||
|
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Do not send data for a while
|
||||||
|
time.Sleep(clientWriteDelay)
|
||||||
|
|
||||||
|
conn.Write([]byte("ping"))
|
||||||
|
recv := make([]byte, 4)
|
||||||
|
_, err = conn.Read(recv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(recv, []byte("pong")) {
|
||||||
|
t.Fatalf("bad: %v", recv)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := pl.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Check the remote addr is the original 127.0.0.1
|
||||||
|
remoteAddrStartTime := time.Now()
|
||||||
|
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||||
|
if addr.IP.String() != "127.0.0.1" {
|
||||||
|
t.Fatalf("bad: %v", addr)
|
||||||
|
}
|
||||||
|
remoteAddrDuration := time.Since(remoteAddrStartTime)
|
||||||
|
|
||||||
|
// Check RemoteAddr() call did timeout
|
||||||
|
if remoteAddrDuration >= clientWriteDelay {
|
||||||
|
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
recv := make([]byte, 4)
|
||||||
|
_, err = conn.Read(recv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(recv, []byte("ping")) {
|
||||||
|
t.Fatalf("bad: %v", recv)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParse_ipv4(t *testing.T) {
|
func TestParse_ipv4(t *testing.T) {
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pl := &Listener{l}
|
pl := &Listener{Listener: l}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||||
@@ -118,7 +182,7 @@ func TestParse_ipv6(t *testing.T) {
|
|||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pl := &Listener{l}
|
pl := &Listener{Listener: l}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||||
@@ -177,7 +241,7 @@ func TestParse_BadHeader(t *testing.T) {
|
|||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pl := &Listener{l}
|
pl := &Listener{Listener: l}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||||
@@ -194,7 +258,7 @@ func TestParse_BadHeader(t *testing.T) {
|
|||||||
|
|
||||||
recv := make([]byte, 4)
|
recv := make([]byte, 4)
|
||||||
_, err = conn.Read(recv)
|
_, err = conn.Read(recv)
|
||||||
if err != io.EOF {
|
if err == nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
Reference in New Issue
Block a user