Initial commit

This commit is contained in:
Armon Dadgar
2014-03-07 16:45:16 -08:00
parent debe0b5149
commit 4e433eb4da
4 changed files with 419 additions and 3 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*.test
*~

View File

@@ -1,4 +1,15 @@
go-proxyproto # proxyproto
=============
This library provides the `proxyproto` package which can be used for servers
listening behind HAProxy of Amazon ELB load balancers. Those load balancers
support the use of a proxy protocol (http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt),
which provides a simple mechansim for the server to get the address of the client
instead of the load balancer.
This library provides both a net.Listener and net.Conn implementation that
can be used to handle situation in which you may be using the proxy protocol.
The only caveat is that we check for the "PROXY " prefix to determine if the protocol
is being used. If that string may occur as part of your input, then it is ambiguous
if the protocol is being used and you may have problems.
Golang package to handle HAProxy Proxy Protocol

183
protocol.go Normal file
View File

@@ -0,0 +1,183 @@
package proxyproto
import (
"bufio"
"bytes"
"fmt"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
)
var (
// prefix is the string we look for at the start of a connection
// to check if this connection is using the proxy protocol
prefix = []byte("PROXY ")
prefixLen = len(prefix)
)
// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
type Listener struct {
Listener net.Listener
}
// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
bufReader *bufio.Reader
conn net.Conn
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
once sync.Once
}
// Accept waits for and returns the next connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}
return NewConn(conn), nil
}
// Close closes the underlying listener.
func (p *Listener) Close() error {
return p.Listener.Close()
}
// Addr returns the underlying listener's network address.
func (p *Listener) Addr() net.Addr {
return p.Listener.Addr()
}
// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn) *Conn {
pConn := &Conn{
bufReader: bufio.NewReader(conn),
conn: conn,
}
return pConn
}
func (p *Conn) Read(b []byte) (int, error) {
var err error
p.once.Do(func() { err = p.checkPrefix() })
if err != nil {
return 0, err
}
return p.bufReader.Read(b)
}
func (p *Conn) Write(b []byte) (int, error) {
return p.conn.Write(b)
}
func (p *Conn) Close() error {
return p.conn.Close()
}
func (p *Conn) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
func (p *Conn) RemoteAddr() net.Addr {
p.once.Do(func() {
if err := p.checkPrefix(); err != nil {
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
}
})
if p.srcAddr != nil {
return p.srcAddr
}
return p.conn.RemoteAddr()
}
func (p *Conn) SetDeadline(t time.Time) error {
return p.conn.SetDeadline(t)
}
func (p *Conn) SetReadDeadline(t time.Time) error {
return p.conn.SetReadDeadline(t)
}
func (p *Conn) SetWriteDeadline(t time.Time) error {
return p.conn.SetWriteDeadline(t)
}
func (p *Conn) checkPrefix() error {
// Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ {
inp, err := p.bufReader.Peek(i)
if err != nil {
return err
}
// Check for a prefix mis-match, quit early
if !bytes.Equal(inp, prefix[:i]) {
return nil
}
}
// Read the header line
header, err := p.bufReader.ReadString('\n')
if err != nil {
p.conn.Close()
return err
}
// Strip the carriage return and new line
header = header[:len(header)-2]
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
parts := strings.Split(header, " ")
if len(parts) != 6 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
}
// Verify the type is known
switch parts[1] {
case "TCP4":
case "TCP6":
default:
p.conn.Close()
return fmt.Errorf("Unhandled address type: %s", parts[1])
}
// Parse out the source address
ip := net.ParseIP(parts[2])
if ip == nil {
p.conn.Close()
return fmt.Errorf("Invalid source ip: %s", parts[2])
}
port, err := strconv.Atoi(parts[4])
if err != nil {
p.conn.Close()
return fmt.Errorf("Invalid source port: %s", parts[4])
}
p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
// Parse out the destination address
ip = net.ParseIP(parts[3])
if ip == nil {
p.conn.Close()
return fmt.Errorf("Invalid destination ip: %s", parts[3])
}
port, err = strconv.Atoi(parts[5])
if err != nil {
p.conn.Close()
return fmt.Errorf("Invalid destination port: %s", parts[5])
}
p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
return nil
}

220
protocol_test.go Normal file
View File

@@ -0,0 +1,220 @@
package proxyproto
import (
"bytes"
"io"
"net"
"testing"
)
func TestPassthrough(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestParse_ipv4(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "10.1.1.1" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
}
func TestParse_ipv6(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "ffff::ffff" {
t.Fatalf("bad: %v", addr)
}
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
}
func TestParse_BadHeader(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
pl := &Listener{l}
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Write out the header!
header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
conn.Write([]byte(header))
conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != io.EOF {
t.Fatalf("err: %v", err)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
// Check the remote addr, should be the local addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
// Read should fail
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err == nil {
t.Fatalf("err: %v", err)
}
}