diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..dd2440d --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.test +*~ diff --git a/README.md b/README.md index be1aed6..113423c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,15 @@ -go-proxyproto -============= +# proxyproto + +This library provides the `proxyproto` package which can be used for servers +listening behind HAProxy of Amazon ELB load balancers. Those load balancers +support the use of a proxy protocol (http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt), +which provides a simple mechansim for the server to get the address of the client +instead of the load balancer. + +This library provides both a net.Listener and net.Conn implementation that +can be used to handle situation in which you may be using the proxy protocol. + +The only caveat is that we check for the "PROXY " prefix to determine if the protocol +is being used. If that string may occur as part of your input, then it is ambiguous +if the protocol is being used and you may have problems. -Golang package to handle HAProxy Proxy Protocol diff --git a/protocol.go b/protocol.go new file mode 100644 index 0000000..5175cd6 --- /dev/null +++ b/protocol.go @@ -0,0 +1,183 @@ +package proxyproto + +import ( + "bufio" + "bytes" + "fmt" + "log" + "net" + "strconv" + "strings" + "sync" + "time" +) + +var ( + // prefix is the string we look for at the start of a connection + // to check if this connection is using the proxy protocol + prefix = []byte("PROXY ") + prefixLen = len(prefix) +) + +// Listener is used to wrap an underlying listener, +// whose connections may be using the HAProxy Proxy Protocol (version 1). +// If the connection is using the protocol, the RemoteAddr() will return +// the correct client address. +type Listener struct { + Listener net.Listener +} + +// Conn is used to wrap and underlying connection which +// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will +// return the address of the client instead of the proxy address. +type Conn struct { + bufReader *bufio.Reader + conn net.Conn + dstAddr *net.TCPAddr + srcAddr *net.TCPAddr + once sync.Once +} + +// Accept waits for and returns the next connection to the listener. +func (p *Listener) Accept() (net.Conn, error) { + // Get the underlying connection + conn, err := p.Listener.Accept() + if err != nil { + return nil, err + } + return NewConn(conn), nil +} + +// Close closes the underlying listener. +func (p *Listener) Close() error { + return p.Listener.Close() +} + +// Addr returns the underlying listener's network address. +func (p *Listener) Addr() net.Addr { + return p.Listener.Addr() +} + +// NewConn is used to wrap a net.Conn that may be speaking +// the proxy protocol into a proxyproto.Conn +func NewConn(conn net.Conn) *Conn { + pConn := &Conn{ + bufReader: bufio.NewReader(conn), + conn: conn, + } + return pConn +} + +func (p *Conn) Read(b []byte) (int, error) { + var err error + p.once.Do(func() { err = p.checkPrefix() }) + if err != nil { + return 0, err + } + return p.bufReader.Read(b) +} + +func (p *Conn) Write(b []byte) (int, error) { + return p.conn.Write(b) +} + +func (p *Conn) Close() error { + return p.conn.Close() +} + +func (p *Conn) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +func (p *Conn) RemoteAddr() net.Addr { + p.once.Do(func() { + if err := p.checkPrefix(); err != nil { + log.Printf("[ERR] Failed to read proxy prefix: %v", err) + } + }) + if p.srcAddr != nil { + return p.srcAddr + } + return p.conn.RemoteAddr() +} + +func (p *Conn) SetDeadline(t time.Time) error { + return p.conn.SetDeadline(t) +} + +func (p *Conn) SetReadDeadline(t time.Time) error { + return p.conn.SetReadDeadline(t) +} + +func (p *Conn) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} + +func (p *Conn) checkPrefix() error { + // Incrementally check each byte of the prefix + for i := 1; i <= prefixLen; i++ { + inp, err := p.bufReader.Peek(i) + if err != nil { + return err + } + + // Check for a prefix mis-match, quit early + if !bytes.Equal(inp, prefix[:i]) { + return nil + } + } + + // Read the header line + header, err := p.bufReader.ReadString('\n') + if err != nil { + p.conn.Close() + return err + } + + // Strip the carriage return and new line + header = header[:len(header)-2] + + // Split on spaces, should be (PROXY ) + parts := strings.Split(header, " ") + if len(parts) != 6 { + p.conn.Close() + return fmt.Errorf("Invalid header line: %s", header) + } + + // Verify the type is known + switch parts[1] { + case "TCP4": + case "TCP6": + default: + p.conn.Close() + return fmt.Errorf("Unhandled address type: %s", parts[1]) + } + + // Parse out the source address + ip := net.ParseIP(parts[2]) + if ip == nil { + p.conn.Close() + return fmt.Errorf("Invalid source ip: %s", parts[2]) + } + port, err := strconv.Atoi(parts[4]) + if err != nil { + p.conn.Close() + return fmt.Errorf("Invalid source port: %s", parts[4]) + } + p.srcAddr = &net.TCPAddr{IP: ip, Port: port} + + // Parse out the destination address + ip = net.ParseIP(parts[3]) + if ip == nil { + p.conn.Close() + return fmt.Errorf("Invalid destination ip: %s", parts[3]) + } + port, err = strconv.Atoi(parts[5]) + if err != nil { + p.conn.Close() + return fmt.Errorf("Invalid destination port: %s", parts[5]) + } + p.dstAddr = &net.TCPAddr{IP: ip, Port: port} + + return nil +} diff --git a/protocol_test.go b/protocol_test.go new file mode 100644 index 0000000..ba70ee9 --- /dev/null +++ b/protocol_test.go @@ -0,0 +1,220 @@ +package proxyproto + +import ( + "bytes" + "io" + "net" + "testing" +) + +func TestPassthrough(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{l} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } +} + +func TestParse_ipv4(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{l} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Write out the header! + header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" + conn.Write([]byte(header)) + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } + + // Check the remote addr + addr := conn.RemoteAddr().(*net.TCPAddr) + if addr.IP.String() != "10.1.1.1" { + t.Fatalf("bad: %v", addr) + } + if addr.Port != 1000 { + t.Fatalf("bad: %v", addr) + } +} + +func TestParse_ipv6(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{l} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Write out the header! + header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n" + conn.Write([]byte(header)) + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } + + // Check the remote addr + addr := conn.RemoteAddr().(*net.TCPAddr) + if addr.IP.String() != "ffff::ffff" { + t.Fatalf("bad: %v", addr) + } + if addr.Port != 1000 { + t.Fatalf("bad: %v", addr) + } +} + +func TestParse_BadHeader(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{l} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Write out the header! + header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n" + conn.Write([]byte(header)) + + conn.Write([]byte("ping")) + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != io.EOF { + t.Fatalf("err: %v", err) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Check the remote addr, should be the local addr + addr := conn.RemoteAddr().(*net.TCPAddr) + if addr.IP.String() != "127.0.0.1" { + t.Fatalf("bad: %v", addr) + } + + // Read should fail + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err == nil { + t.Fatalf("err: %v", err) + } +}