1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-17 14:37:02 +00:00

Implement HAProxy protocol support

This patch implements support for incoming connections wrapped in the
HAProxy protocol v1.

This is useful when running chasquid behind a HAProxy server, as it
needs the original source IP to perform SPF checks.

This patch is a reimplementation of one originally provided by Denys
Vitali in pull request #15, except the logic for the protocol handling
is moved to a new package, and the smtpsrv.Conn handling of the source
IP is simplified.

It is marked as experimental for now, since we want to give it a bit
more exposure just in case the option/api needs adjustment.

Thanks a lot to Denys Vitali (@denysvitali in github) for sending the
original patch for this, and helping test it!
This commit is contained in:
Alberto Bertogli
2020-11-12 22:00:46 +00:00
parent c9d3ba0ca0
commit e79586a014
22 changed files with 389 additions and 24 deletions

View File

@@ -0,0 +1,76 @@
// Package haproxy implements the handshake for the HAProxy client protocol
// version 1, as described in
// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt.
package haproxy
import (
"bufio"
"errors"
"net"
"strconv"
"strings"
)
var (
errInvalidProtoID = errors.New("invalid protocol identifier")
errUnkProtocol = errors.New("unknown protocol")
errInvalidFields = errors.New("invalid number of fields")
errInvalidSrcIP = errors.New("invalid src ip")
errInvalidDstIP = errors.New("invalid dst ip")
errInvalidSrcPort = errors.New("invalid src port")
errInvalidDstPort = errors.New("invalid dst port")
)
// Handshake performs the HAProxy protocol v1 handshake on the given reader,
// which is expected to be backed by a network connection.
// It returns the source and destination addresses, or an error if the
// handshake could not complete.
// Note that any timeouts or limits must be set by the caller on the
// underlying connection, this is helper only to perform the handshake.
func Handshake(r *bufio.Reader) (src, dst net.Addr, err error) {
line, err := r.ReadString('\n')
if err != nil {
return nil, nil, err
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "PROXY" {
return nil, nil, errInvalidProtoID
}
switch fields[1] {
case "TCP4", "TCP6":
// Allowed to continue, nothing to do.
default:
return nil, nil, errUnkProtocol
}
if len(fields) != 6 {
return nil, nil, errInvalidFields
}
srcIP := net.ParseIP(fields[2])
if srcIP == nil {
return nil, nil, errInvalidSrcIP
}
dstIP := net.ParseIP(fields[3])
if dstIP == nil {
return nil, nil, errInvalidDstIP
}
srcPort, err := strconv.ParseUint(fields[4], 10, 16)
if err != nil {
return nil, nil, errInvalidSrcPort
}
dstPort, err := strconv.ParseUint(fields[5], 10, 16)
if err != nil {
return nil, nil, errInvalidDstPort
}
src = &net.TCPAddr{IP: srcIP, Port: int(srcPort)}
dst = &net.TCPAddr{IP: dstIP, Port: int(dstPort)}
return src, dst, nil
}

View File

@@ -0,0 +1,97 @@
package haproxy
import (
"bufio"
"io"
"net"
"strings"
"testing"
)
func TestNoNewline(t *testing.T) {
r := bufio.NewReader(strings.NewReader("PROXY "))
_, _, err := Handshake(r)
if err != io.EOF {
t.Errorf("expected EOF, got %v", err)
}
}
func TestBasic(t *testing.T) {
var (
src4, _ = net.ResolveTCPAddr("tcp", "1.1.1.1:3333")
dst4, _ = net.ResolveTCPAddr("tcp", "2.2.2.2:4444")
src6, _ = net.ResolveTCPAddr("tcp", "[5::5]:7777")
dst6, _ = net.ResolveTCPAddr("tcp", "[6::6]:8888")
)
cases := []struct {
str string
src, dst net.Addr
err error
}{
// Early line errors.
{"", nil, nil, errInvalidProtoID},
{"lalala", nil, nil, errInvalidProtoID},
{"PROXY", nil, nil, errInvalidProtoID},
{"PROXY lalala", nil, nil, errUnkProtocol},
{"PROXY UNKNOWN", nil, nil, errUnkProtocol},
// Number of field errors.
{"PROXY TCP4", nil, nil, errInvalidFields},
{"PROXY TCP4 a", nil, nil, errInvalidFields},
{"PROXY TCP4 a b", nil, nil, errInvalidFields},
{"PROXY TCP4 a b c", nil, nil, errInvalidFields},
// Parsing of ipv4 addresses.
{"PROXY TCP4 a b c d", nil, nil, errInvalidSrcIP},
{"PROXY TCP4 1.1.1.1 b c d",
nil, nil, errInvalidDstIP},
{"PROXY TCP4 1.1.1.1 2.2.2.2 c d",
nil, nil, errInvalidSrcPort},
{"PROXY TCP4 1.1.1.1 2.2.2.2 3333 d",
nil, nil, errInvalidDstPort},
{"PROXY TCP4 1.1.1.1 2.2.2.2 3333 4444",
src4, dst4, nil},
// Parsing of ipv6 addresses.
{"PROXY TCP6 a b c d", nil, nil, errInvalidSrcIP},
{"PROXY TCP6 5::5 b c d",
nil, nil, errInvalidDstIP},
{"PROXY TCP6 5::5 6::6 c d",
nil, nil, errInvalidSrcPort},
{"PROXY TCP6 5::5 6::6 7777 d",
nil, nil, errInvalidDstPort},
{"PROXY TCP6 5::5 6::6 7777 8888",
src6, dst6, nil},
}
for i, c := range cases {
t.Logf("testing %d: %v", i, c.str)
src, dst, err := Handshake(newR(c.str))
if !addrEq(src, c.src) {
t.Errorf("%d: got src %v, expected %v", i, src, c.src)
}
if !addrEq(dst, c.dst) {
t.Errorf("%d: got dst %v, expected %v", i, dst, c.dst)
}
if err != c.err {
t.Errorf("%d: got error %v, expected %v", i, err, c.err)
}
}
}
func newR(s string) *bufio.Reader {
return bufio.NewReader(strings.NewReader(s + "\r\n"))
}
func addrEq(a, b net.Addr) bool {
if a == nil || b == nil {
return a == nil && b == nil
}
ta := a.(*net.TCPAddr)
tb := b.(*net.TCPAddr)
return ta.IP.Equal(tb.IP) && ta.Port == tb.Port
}