Initial commit
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.test
|
||||
*~
|
||||
17
README.md
17
README.md
@@ -1,4 +1,15 @@
|
||||
go-proxyproto
|
||||
=============
|
||||
# proxyproto
|
||||
|
||||
This library provides the `proxyproto` package which can be used for servers
|
||||
listening behind HAProxy of Amazon ELB load balancers. Those load balancers
|
||||
support the use of a proxy protocol (http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt),
|
||||
which provides a simple mechansim for the server to get the address of the client
|
||||
instead of the load balancer.
|
||||
|
||||
This library provides both a net.Listener and net.Conn implementation that
|
||||
can be used to handle situation in which you may be using the proxy protocol.
|
||||
|
||||
The only caveat is that we check for the "PROXY " prefix to determine if the protocol
|
||||
is being used. If that string may occur as part of your input, then it is ambiguous
|
||||
if the protocol is being used and you may have problems.
|
||||
|
||||
Golang package to handle HAProxy Proxy Protocol
|
||||
|
||||
183
protocol.go
Normal file
183
protocol.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// prefix is the string we look for at the start of a connection
|
||||
// to check if this connection is using the proxy protocol
|
||||
prefix = []byte("PROXY ")
|
||||
prefixLen = len(prefix)
|
||||
)
|
||||
|
||||
// Listener is used to wrap an underlying listener,
|
||||
// whose connections may be using the HAProxy Proxy Protocol (version 1).
|
||||
// If the connection is using the protocol, the RemoteAddr() will return
|
||||
// the correct client address.
|
||||
type Listener struct {
|
||||
Listener net.Listener
|
||||
}
|
||||
|
||||
// Conn is used to wrap and underlying connection which
|
||||
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
|
||||
// return the address of the client instead of the proxy address.
|
||||
type Conn struct {
|
||||
bufReader *bufio.Reader
|
||||
conn net.Conn
|
||||
dstAddr *net.TCPAddr
|
||||
srcAddr *net.TCPAddr
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (p *Listener) Accept() (net.Conn, error) {
|
||||
// Get the underlying connection
|
||||
conn, err := p.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(conn), nil
|
||||
}
|
||||
|
||||
// Close closes the underlying listener.
|
||||
func (p *Listener) Close() error {
|
||||
return p.Listener.Close()
|
||||
}
|
||||
|
||||
// Addr returns the underlying listener's network address.
|
||||
func (p *Listener) Addr() net.Addr {
|
||||
return p.Listener.Addr()
|
||||
}
|
||||
|
||||
// NewConn is used to wrap a net.Conn that may be speaking
|
||||
// the proxy protocol into a proxyproto.Conn
|
||||
func NewConn(conn net.Conn) *Conn {
|
||||
pConn := &Conn{
|
||||
bufReader: bufio.NewReader(conn),
|
||||
conn: conn,
|
||||
}
|
||||
return pConn
|
||||
}
|
||||
|
||||
func (p *Conn) Read(b []byte) (int, error) {
|
||||
var err error
|
||||
p.once.Do(func() { err = p.checkPrefix() })
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return p.bufReader.Read(b)
|
||||
}
|
||||
|
||||
func (p *Conn) Write(b []byte) (int, error) {
|
||||
return p.conn.Write(b)
|
||||
}
|
||||
|
||||
func (p *Conn) Close() error {
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
func (p *Conn) LocalAddr() net.Addr {
|
||||
return p.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (p *Conn) RemoteAddr() net.Addr {
|
||||
p.once.Do(func() {
|
||||
if err := p.checkPrefix(); err != nil {
|
||||
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
|
||||
}
|
||||
})
|
||||
if p.srcAddr != nil {
|
||||
return p.srcAddr
|
||||
}
|
||||
return p.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (p *Conn) SetDeadline(t time.Time) error {
|
||||
return p.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) SetReadDeadline(t time.Time) error {
|
||||
return p.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return p.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (p *Conn) checkPrefix() error {
|
||||
// Incrementally check each byte of the prefix
|
||||
for i := 1; i <= prefixLen; i++ {
|
||||
inp, err := p.bufReader.Peek(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for a prefix mis-match, quit early
|
||||
if !bytes.Equal(inp, prefix[:i]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Read the header line
|
||||
header, err := p.bufReader.ReadString('\n')
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Strip the carriage return and new line
|
||||
header = header[:len(header)-2]
|
||||
|
||||
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
|
||||
parts := strings.Split(header, " ")
|
||||
if len(parts) != 6 {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid header line: %s", header)
|
||||
}
|
||||
|
||||
// Verify the type is known
|
||||
switch parts[1] {
|
||||
case "TCP4":
|
||||
case "TCP6":
|
||||
default:
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Unhandled address type: %s", parts[1])
|
||||
}
|
||||
|
||||
// Parse out the source address
|
||||
ip := net.ParseIP(parts[2])
|
||||
if ip == nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid source ip: %s", parts[2])
|
||||
}
|
||||
port, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid source port: %s", parts[4])
|
||||
}
|
||||
p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||
|
||||
// Parse out the destination address
|
||||
ip = net.ParseIP(parts[3])
|
||||
if ip == nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid destination ip: %s", parts[3])
|
||||
}
|
||||
port, err = strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
p.conn.Close()
|
||||
return fmt.Errorf("Invalid destination port: %s", parts[5])
|
||||
}
|
||||
p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||
|
||||
return nil
|
||||
}
|
||||
220
protocol_test.go
Normal file
220
protocol_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPassthrough(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv4(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the remote addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "10.1.1.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port != 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_ipv6(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("pong")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(recv, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", recv)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the remote addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "ffff::ffff" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
if addr.Port != 1000 {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_BadHeader(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
pl := &Listener{l}
|
||||
|
||||
go func() {
|
||||
conn, err := net.Dial("tcp", pl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write out the header!
|
||||
header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
|
||||
conn.Write([]byte(header))
|
||||
|
||||
conn.Write([]byte("ping"))
|
||||
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := pl.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Check the remote addr, should be the local addr
|
||||
addr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if addr.IP.String() != "127.0.0.1" {
|
||||
t.Fatalf("bad: %v", addr)
|
||||
}
|
||||
|
||||
// Read should fail
|
||||
recv := make([]byte, 4)
|
||||
_, err = conn.Read(recv)
|
||||
if err == nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user