//go:build !coverage // +build !coverage // minidns is a trivial DNS server used for testing. // // It takes an "answers" file which contains lines with the following format: // // // // For example: // // blah A 1.2.3.4 // blah MX mx1 // // Supported types: A, AAAA, MX, TXT. // // It's only meant to be used for testing, so it's not robust, performant, or // standards compliant. package main import ( "bufio" "encoding/binary" "flag" "fmt" "net" "os" "regexp" "strings" "sync" "blitiri.com.ar/go/log" "golang.org/x/net/dns/dnsmessage" ) var ( addr = flag.String("addr", ":53", "address to listen to (UDP)") zonesPath = flag.String("zones", "", "file with the zones") ) func main() { flag.Parse() srv := &miniDNS{ answers: map[string][]dnsmessage.Resource{}, } if *zonesPath == "" { log.Fatalf("-zones must be given") } var zonesFile *os.File if *zonesPath == "-" { zonesFile = os.Stdin } else { var err error zonesFile, err = os.Open(*zonesPath) if err != nil { log.Fatalf("error opening %v: %v", *zonesPath, err) } } srv.loadZones(zonesFile) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() srv.listenAndServeUDP(*addr) }() go func() { defer wg.Done() srv.listenAndServeTCP(*addr) }() wg.Wait() } type miniDNS struct { // Domain -> Answers. // We always respond the same regardless of the query. // Not great, but does the trick. answers map[string][]dnsmessage.Resource } func (m *miniDNS) listenAndServeUDP(addr string) { conn, err := net.ListenPacket("udp", addr) if err != nil { log.Fatalf("error listening UDP %q: %v", addr, err) } log.Infof("listening on %v", conn.LocalAddr()) buf := make([]byte, 64*1024) for { n, addr, err := conn.ReadFrom(buf) if err != nil { log.Infof("error reading from udp: %v", err) continue } msg := &dnsmessage.Message{} err = msg.Unpack(buf[:n]) if err != nil { log.Infof("%v error unpacking message: %v", addr, err) } if lq := len(msg.Questions); lq != 1 { log.Infof("%v/%-5d dropping packet with %d questions", addr, msg.ID, lq) continue } q := msg.Questions[0] log.Infof("%v/%-5d Q: %s %s %s", addr, msg.ID, q.Name, q.Type, q.Class) reply := m.handle(msg) rbuf, err := reply.Pack() if err != nil { log.Fatalf("error packing reply: %v", err) } _, err = conn.WriteTo(rbuf, addr) if err != nil { log.Infof("%v/%-5d error writing: %v", addr, msg.ID, err) } } } func (m *miniDNS) listenAndServeTCP(addr string) { ls, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("error listening TCP %q: %v", addr, err) } log.Infof("listening on %v", addr) for { conn, err := ls.Accept() if err != nil { log.Infof("error accepting: %v", err) continue } msg, err := readTCPMessage(conn) if err != nil { log.Infof("%v error reading message: %v", addr, err) conn.Close() continue } if lq := len(msg.Questions); lq != 1 { log.Infof("%v/%-5d dropping packet with %d questions", addr, msg.ID, lq) conn.Close() continue } q := msg.Questions[0] log.Infof("%v/%-5d Q: %s %s %s", addr, msg.ID, q.Name, q.Type, q.Class) reply := m.handle(msg) err = writeTCPMessage(conn, reply) if err != nil { log.Infof("error writing reply: %v", err) } conn.Close() } } func readTCPMessage(conn net.Conn) (*dnsmessage.Message, error) { // Read the 2-byte length first, then the message. lenHdr := struct{ Len uint16 }{} err := binary.Read(conn, binary.BigEndian, &lenHdr) if err != nil { return nil, err } data := make([]byte, lenHdr.Len) err = binary.Read(conn, binary.BigEndian, &data) if err != nil { return nil, err } msg := &dnsmessage.Message{} err = msg.Unpack(data) if err != nil { return nil, fmt.Errorf("%v error unpacking message: %v", addr, err) } return msg, nil } func writeTCPMessage(conn net.Conn, msg *dnsmessage.Message) error { rbuf, err := msg.Pack() if err != nil { return fmt.Errorf("error packing reply: %v", err) } lenHdr := struct{ Len uint16 }{Len: uint16(len(rbuf))} err = binary.Write(conn, binary.BigEndian, lenHdr) if err != nil { return err } _, err = conn.Write(rbuf) return err } func (m *miniDNS) handle(msg *dnsmessage.Message) *dnsmessage.Message { reply := &dnsmessage.Message{ Header: dnsmessage.Header{ ID: msg.ID, Response: true, RCode: dnsmessage.RCodeSuccess, // We're authoritative for the zones we're serving. // We should either set this, or RecursionAvailable, otherwise // some client libraries will complain. Authoritative: true, }, Questions: msg.Questions, } q := msg.Questions[0] if answers, ok := m.answers[q.Name.String()]; ok { for _, ans := range answers { if q.Type == ans.Header.Type { log.Infof("-> %s %v", q.Type, ans.Body) reply.Answers = append(reply.Answers, ans) } } } else { log.Infof("-> NXERROR") reply.Header.RCode = dnsmessage.RCodeNameError } return reply } func (m *miniDNS) loadZones(f *os.File) { scanner := bufio.NewScanner(f) lineno := 0 for scanner.Scan() { lineno++ line := strings.TrimSpace(scanner.Text()) if strings.HasPrefix(line, "#") || line == "" { continue } vs := regexp.MustCompile(`\s+`).Split(line, 3) if len(vs) != 3 { log.Fatalf("line %d: invalid format", lineno) } domain, t, value := vs[0], vs[1], vs[2] if !strings.HasSuffix(domain, ".") { domain += "." } var body dnsmessage.ResourceBody var qType dnsmessage.Type switch strings.ToLower(t) { case "a": qType = dnsmessage.TypeA ip := net.ParseIP(value).To4() if ip == nil { log.Fatalf("line %d: invalid IP %q", lineno, value) } a := &dnsmessage.AResource{} copy(a.A[:], ip[:4]) body = a case "aaaa": qType = dnsmessage.TypeAAAA ip := net.ParseIP(value).To16() if ip == nil { log.Fatalf("line %d: invalid IP %q", lineno, value) } aaaa := &dnsmessage.AAAAResource{} copy(aaaa.AAAA[:], ip[:16]) body = aaaa case "mx": qType = dnsmessage.TypeMX if !strings.HasPrefix(value, ".") { value += "." } body = &dnsmessage.MXResource{ Pref: 10, MX: dnsmessage.MustNewName(value), } case "txt": qType = dnsmessage.TypeTXT body = &dnsmessage.TXTResource{ TXT: []string{value}, } default: log.Fatalf("line %d: unknown type %q", lineno, t) } answer := dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName(domain), Type: qType, Class: dnsmessage.ClassINET, }, Body: body, } m.answers[domain] = append(m.answers[domain], answer) } if err := scanner.Err(); err != nil { log.Fatalf("error reading zones: %v", err) } }