1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-17 14:37:02 +00:00
Files
go-chasquid-smtp/test/util/minidns.go
Alberto Bertogli ed38945fca courier: Use DNSError.IsNotFound to identify NXDOMAIN
When resolving MX records, we need to distinguish between "no such
domain" and other kinds of errors. Before Go 1.13, this was not
possible, so we had a workaround that assumed any permanent error was a
"no such domain", which is not great, but functional.

Now that our minimum supported version is Go 1.15, we can remove the
workaround.

This patch replaces the workaround with proper logic using
DNSError.IsNotFound to identify NXDOMAIN results when resolving MX
records.

This requires adjusting a few tests that used to work in environments
where resolving unknown domains (used for testing) returned a permanent
error, and that no longer work now. Instead of relying on this
environmental property, we make the affected tests use our own DNS
server, which should make them more hermetic and reproducible.
2021-10-08 23:11:29 +01:00

311 lines
6.5 KiB
Go

// +build ignore
// minidns is a trivial DNS server used for testing.
//
// It takes an "answers" file which contains lines with the following format:
//
// <domain> <type> <value>
//
// For example:
//
// blah A 1.2.3.4
// blah MX mx1
//
// Supported types: A, AAAA, MX, TXT.
//
// It's only meant to be used for testing, so it's not robust, performant, or
// standards compliant.
//
package main
import (
"bufio"
"encoding/binary"
"flag"
"fmt"
"net"
"os"
"regexp"
"strings"
"sync"
"blitiri.com.ar/go/log"
"golang.org/x/net/dns/dnsmessage"
)
// Command-line flags.
var (
	// addr is the address to listen on. The same value is handed to both the
	// UDP and the TCP listeners, so the help text mentions both.
	addr = flag.String("addr", ":53", "address to listen to (UDP and TCP)")

	// zonesPath is the path of the answers file; "-" selects standard input.
	zonesPath = flag.String("zones", "", "file with the zones")
)
// main parses the flags, loads the zones file (or standard input when the
// path is "-"), and then serves DNS over both UDP and TCP on the same
// address. It blocks until both listeners have returned.
func main() {
	flag.Parse()

	srv := &miniDNS{
		answers: map[string][]dnsmessage.Resource{},
	}

	if *zonesPath == "" {
		log.Fatalf("-zones must be given")
	}

	var zonesFile *os.File
	if *zonesPath == "-" {
		zonesFile = os.Stdin
	} else {
		var err error
		zonesFile, err = os.Open(*zonesPath)
		if err != nil {
			log.Fatalf("error opening %v: %v", *zonesPath, err)
		}
	}

	srv.loadZones(zonesFile)

	// One count per listener goroutine: each of them calls Done exactly once,
	// so the counter must start at 2 or the second Done would panic.
	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		srv.listenAndServeUDP(*addr)
	}()

	go func() {
		defer wg.Done()
		srv.listenAndServeTCP(*addr)
	}()

	wg.Wait()
}
// miniDNS holds the records served by the test DNS server.
type miniDNS struct {
// Domain -> Answers.
// Keys are domain names with a trailing dot, as stored by loadZones and as
// looked up by handle via the question name's String() form.
// Each entry holds every record loaded for that domain, of any type;
// handle returns only the ones whose type matches the query.
// Not great, but does the trick.
answers map[string][]dnsmessage.Resource
}
// listenAndServeUDP listens for DNS queries on the given UDP address and
// answers each of them using handle.
//
// Failing to listen, or to pack a reply, terminates the process. Malformed
// packets and packets that do not carry exactly one question are logged and
// dropped without a reply.
func (m *miniDNS) listenAndServeUDP(addr string) {
	conn, err := net.ListenPacket("udp", addr)
	if err != nil {
		log.Fatalf("error listening UDP %q: %v", addr, err)
	}

	log.Infof("listening on %v", conn.LocalAddr())

	// 64 KiB is enough to hold the largest possible UDP datagram.
	buf := make([]byte, 64*1024)
	for {
		// raddr is the peer that sent the packet; it is deliberately not
		// called addr, to avoid shadowing the listen address parameter.
		n, raddr, err := conn.ReadFrom(buf)
		if err != nil {
			log.Infof("error reading from udp: %v", err)
			continue
		}

		msg := &dnsmessage.Message{}
		if err := msg.Unpack(buf[:n]); err != nil {
			// Do not trust the contents of a message that failed to unpack:
			// drop it instead of attempting to answer it.
			log.Infof("%v error unpacking message: %v", raddr, err)
			continue
		}

		if lq := len(msg.Questions); lq != 1 {
			log.Infof("%v/%-5d dropping packet with %d questions",
				raddr, msg.ID, lq)
			continue
		}

		q := msg.Questions[0]
		log.Infof("%v/%-5d Q: %s %s %s",
			raddr, msg.ID, q.Name, q.Type, q.Class)

		reply := m.handle(msg)
		rbuf, err := reply.Pack()
		if err != nil {
			log.Fatalf("error packing reply: %v", err)
		}

		if _, err := conn.WriteTo(rbuf, raddr); err != nil {
			log.Infof("%v/%-5d error writing reply: %v", raddr, msg.ID, err)
		}
	}
}
// listenAndServeTCP listens for DNS queries on the given TCP address.
//
// Connections are served one at a time, and each connection carries a single
// query: the reply is written and the connection is closed. Failing to listen
// terminates the process.
func (m *miniDNS) listenAndServeTCP(addr string) {
	ls, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("error listening TCP %q: %v", addr, err)
	}

	// Log the address we actually bound to, like the UDP listener does.
	log.Infof("listening on %v", ls.Addr())

	for {
		conn, err := ls.Accept()
		if err != nil {
			log.Infof("error accepting: %v", err)
			continue
		}

		m.serveTCPConn(conn)
	}
}

// serveTCPConn reads one query from conn, writes the reply produced by
// handle, and closes the connection on every path.
func (m *miniDNS) serveTCPConn(conn net.Conn) {
	defer conn.Close()

	// Log lines are prefixed with the peer address, matching the UDP path.
	raddr := conn.RemoteAddr()

	msg, err := readTCPMessage(conn)
	if err != nil {
		log.Infof("%v error reading message: %v", raddr, err)
		return
	}

	if lq := len(msg.Questions); lq != 1 {
		log.Infof("%v/%-5d dropping packet with %d questions",
			raddr, msg.ID, lq)
		return
	}

	q := msg.Questions[0]
	log.Infof("%v/%-5d Q: %s %s %s",
		raddr, msg.ID, q.Name, q.Type, q.Class)

	reply := m.handle(msg)
	if err := writeTCPMessage(conn, reply); err != nil {
		log.Infof("error writing reply: %v", err)
	}
}
// readTCPMessage reads a single length-prefixed DNS message from conn and
// unpacks it.
//
// It returns the read error unchanged if the length prefix or the payload
// cannot be read in full, and a wrapped error if the payload does not unpack
// as a DNS message.
func readTCPMessage(conn net.Conn) (*dnsmessage.Message, error) {
	// The message is preceded by its length, as a 2-byte big-endian integer.
	var msgLen uint16
	if err := binary.Read(conn, binary.BigEndian, &msgLen); err != nil {
		return nil, err
	}

	data := make([]byte, msgLen)
	if err := binary.Read(conn, binary.BigEndian, data); err != nil {
		return nil, err
	}

	msg := &dnsmessage.Message{}
	if err := msg.Unpack(data); err != nil {
		// No address prefix here: the only "addr" in scope is the
		// package-level flag pointer, and the caller logs the peer itself.
		return nil, fmt.Errorf("error unpacking message: %w", err)
	}

	return msg, nil
}
// writeTCPMessage packs msg and writes it to conn, preceded by its length as
// a 2-byte big-endian integer.
//
// It returns an error if the message cannot be packed, if the packed message
// does not fit in the 16-bit length prefix, or if the write fails.
func writeTCPMessage(conn net.Conn, msg *dnsmessage.Message) error {
	rbuf, err := msg.Pack()
	if err != nil {
		return fmt.Errorf("error packing reply: %w", err)
	}

	// The length prefix is 16 bits wide; refuse to send a truncated length,
	// which would corrupt the stream for the reader.
	if len(rbuf) > 0xFFFF {
		return fmt.Errorf("error packing reply: message too long (%d bytes)",
			len(rbuf))
	}

	// Send the prefix and the message in a single write.
	out := make([]byte, 2+len(rbuf))
	binary.BigEndian.PutUint16(out, uint16(len(rbuf)))
	copy(out[2:], rbuf)

	_, err = conn.Write(out)
	return err
}
// handle builds the reply for the first question in msg.
//
// When the queried name has no entry in m.answers the reply carries
// RCodeNameError. Otherwise it carries RCodeSuccess plus every stored record
// whose type equals the question's type, which may be none at all.
func (m *miniDNS) handle(msg *dnsmessage.Message) *dnsmessage.Message {
	hdr := dnsmessage.Header{
		ID:       msg.ID,
		Response: true,
		RCode:    dnsmessage.RCodeSuccess,

		// We're authoritative for the zones we're serving.
		// Either this or RecursionAvailable should be set, otherwise some
		// client libraries will complain.
		Authoritative: true,
	}
	reply := &dnsmessage.Message{
		Header:    hdr,
		Questions: msg.Questions,
	}

	question := msg.Questions[0]
	records, known := m.answers[question.Name.String()]
	if !known {
		log.Infof("-> NXERROR")
		reply.Header.RCode = dnsmessage.RCodeNameError
		return reply
	}

	for _, rec := range records {
		if rec.Header.Type != question.Type {
			continue
		}
		log.Infof("-> %s %v", question.Type, rec.Body)
		reply.Answers = append(reply.Answers, rec)
	}

	return reply
}
// loadZones reads the answers file from f and adds its records to m.answers.
//
// Each line has the form "<domain> <type> <value>", with fields separated by
// whitespace; the value is everything after the type, so it may itself
// contain whitespace. Empty lines and lines starting with "#" are skipped.
// Supported types (case-insensitive) are A, AAAA, MX and TXT. Domains and MX
// targets get a trailing dot appended if they lack one, and MX records always
// use preference 10.
//
// Any malformed line, unknown type, or read error terminates the process.
func (m *miniDNS) loadZones(f *os.File) {
	// Compiled once, outside the per-line loop.
	fieldSep := regexp.MustCompile("\\s+")

	scanner := bufio.NewScanner(f)
	lineno := 0
	for scanner.Scan() {
		lineno++
		line := strings.TrimSpace(scanner.Text())
		if strings.HasPrefix(line, "#") || line == "" {
			continue
		}

		// Split in at most 3 fields, so the value keeps its inner spaces.
		vs := fieldSep.Split(line, 3)
		if len(vs) != 3 {
			log.Fatalf("line %d: invalid format", lineno)
		}

		domain, t, value := vs[0], vs[1], vs[2]
		if !strings.HasSuffix(domain, ".") {
			domain += "."
		}

		var body dnsmessage.ResourceBody
		var qType dnsmessage.Type

		switch strings.ToLower(t) {
		case "a":
			qType = dnsmessage.TypeA
			ip := net.ParseIP(value).To4()
			if ip == nil {
				log.Fatalf("line %d: invalid IP %q", lineno, value)
			}
			a := &dnsmessage.AResource{}
			copy(a.A[:], ip[:4])
			body = a
		case "aaaa":
			qType = dnsmessage.TypeAAAA
			ip := net.ParseIP(value).To16()
			if ip == nil {
				log.Fatalf("line %d: invalid IP %q", lineno, value)
			}
			aaaa := &dnsmessage.AAAAResource{}
			copy(aaaa.AAAA[:], ip[:16])
			body = aaaa
		case "mx":
			qType = dnsmessage.TypeMX
			// Make the target fully qualified. This must check the suffix:
			// checking the prefix would turn "mx1." into "mx1..".
			if !strings.HasSuffix(value, ".") {
				value += "."
			}
			body = &dnsmessage.MXResource{
				Pref: 10,
				MX:   dnsmessage.MustNewName(value),
			}
		case "txt":
			qType = dnsmessage.TypeTXT
			body = &dnsmessage.TXTResource{
				TXT: []string{value},
			}
		default:
			log.Fatalf("line %d: unknown type %q", lineno, t)
		}

		answer := dnsmessage.Resource{
			Header: dnsmessage.ResourceHeader{
				Name:  dnsmessage.MustNewName(domain),
				Type:  qType,
				Class: dnsmessage.ClassINET,
			},
			Body: body,
		}
		m.answers[domain] = append(m.answers[domain], answer)
	}

	if err := scanner.Err(); err != nil {
		log.Fatalf("error reading zones: %v", err)
	}
}