mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
Use the external log, spf and systemd packages
The log, spf and systemd packages have been externalized; use them instead of the internal version to avoid having two versions of the same thing.
This commit is contained in:
@@ -15,12 +15,12 @@ import (
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/config"
|
||||
"blitiri.com.ar/go/chasquid/internal/courier"
|
||||
"blitiri.com.ar/go/chasquid/internal/log"
|
||||
"blitiri.com.ar/go/chasquid/internal/maillog"
|
||||
"blitiri.com.ar/go/chasquid/internal/normalize"
|
||||
"blitiri.com.ar/go/chasquid/internal/smtpsrv"
|
||||
"blitiri.com.ar/go/chasquid/internal/systemd"
|
||||
"blitiri.com.ar/go/chasquid/internal/userdb"
|
||||
"blitiri.com.ar/go/log"
|
||||
"blitiri.com.ar/go/systemd"
|
||||
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"net"
|
||||
"net/smtp"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/spf"
|
||||
"blitiri.com.ar/go/chasquid/internal/tlsconst"
|
||||
"blitiri.com.ar/go/spf"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/spf"
|
||||
"blitiri.com.ar/go/spf"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/log"
|
||||
"blitiri.com.ar/go/log"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
)
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
// Package log implements a simple logger.
|
||||
//
|
||||
// It implements an API somewhat similar to "github.com/google/glog" with a
|
||||
// focus towards logging to stderr, which is useful for systemd-based
|
||||
// environments.
|
||||
//
|
||||
// There are command line flags (defined using the flag package) to control
|
||||
// the behaviour of the default logger. By default, it will write to stderr
|
||||
// without timestamps; this is suitable for systemd (or equivalent) logging.
|
||||
package log
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/syslog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Flags that control the default logging.
|
||||
var (
|
||||
vLevel = flag.Int("v", 0, "Verbosity level (1 = debug)")
|
||||
|
||||
logFile = flag.String("logfile", "",
|
||||
"file to log to (enables logtime)")
|
||||
|
||||
logToSyslog = flag.String("logtosyslog", "",
|
||||
"log to syslog, with the given tag")
|
||||
|
||||
logTime = flag.Bool("logtime", false,
|
||||
"include the time when writing the log to stderr")
|
||||
|
||||
alsoLogToStderr = flag.Bool("alsologtostderr", false,
|
||||
"also log to stderr, in addition to the file")
|
||||
)
|
||||
|
||||
// Logging levels.
|
||||
type Level int
|
||||
|
||||
const (
|
||||
Fatal = Level(-2)
|
||||
Error = Level(-1)
|
||||
Info = Level(0)
|
||||
Debug = Level(1)
|
||||
)
|
||||
|
||||
var levelToLetter = map[Level]string{
|
||||
Fatal: "☠",
|
||||
Error: "E",
|
||||
Info: "_",
|
||||
Debug: ".",
|
||||
}
|
||||
|
||||
// A Logger represents a logging object that writes logs to the given writer.
|
||||
type Logger struct {
|
||||
Level Level
|
||||
LogTime bool
|
||||
|
||||
CallerSkip int
|
||||
|
||||
w io.WriteCloser
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func New(w io.WriteCloser) *Logger {
|
||||
return &Logger{
|
||||
w: w,
|
||||
CallerSkip: 0,
|
||||
Level: Info,
|
||||
LogTime: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NewFile(path string) (*Logger, error) {
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := New(f)
|
||||
l.LogTime = true
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func NewSyslog(priority syslog.Priority, tag string) (*Logger, error) {
|
||||
w, err := syslog.New(priority, tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := New(w)
|
||||
l.LogTime = false
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (l *Logger) Close() {
|
||||
l.w.Close()
|
||||
}
|
||||
|
||||
func (l *Logger) V(level Level) bool {
|
||||
return level <= l.Level
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level Level, skip int, format string, a ...interface{}) {
|
||||
if !l.V(level) {
|
||||
return
|
||||
}
|
||||
|
||||
// Message.
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
|
||||
// Caller.
|
||||
_, file, line, ok := runtime.Caller(1 + l.CallerSkip + skip)
|
||||
if !ok {
|
||||
file = "unknown"
|
||||
}
|
||||
fl := fmt.Sprintf("%s:%-4d", filepath.Base(file), line)
|
||||
if len(fl) > 18 {
|
||||
fl = fl[len(fl)-18:]
|
||||
}
|
||||
msg = fmt.Sprintf("%-18s", fl) + " " + msg
|
||||
|
||||
// Level.
|
||||
letter, ok := levelToLetter[level]
|
||||
if !ok {
|
||||
letter = strconv.Itoa(int(level))
|
||||
}
|
||||
msg = letter + " " + msg
|
||||
|
||||
// Time.
|
||||
if l.LogTime {
|
||||
msg = time.Now().Format("20060102 15:04:05.000000 ") + msg
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(msg, "\n") {
|
||||
msg += "\n"
|
||||
}
|
||||
|
||||
l.Lock()
|
||||
l.w.Write([]byte(msg))
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
func (l *Logger) Debugf(format string, a ...interface{}) {
|
||||
l.Log(Debug, 1, format, a...)
|
||||
}
|
||||
|
||||
func (l *Logger) Infof(format string, a ...interface{}) {
|
||||
l.Log(Info, 1, format, a...)
|
||||
}
|
||||
|
||||
func (l *Logger) Errorf(format string, a ...interface{}) error {
|
||||
l.Log(Error, 1, format, a...)
|
||||
return fmt.Errorf(format, a...)
|
||||
}
|
||||
|
||||
func (l *Logger) Fatalf(format string, a ...interface{}) {
|
||||
l.Log(-2, 1, format, a...)
|
||||
// TODO: Log traceback?
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// The default logger, used by the top-level functions below.
|
||||
var Default = &Logger{
|
||||
w: os.Stderr,
|
||||
CallerSkip: 1,
|
||||
Level: Info,
|
||||
LogTime: false,
|
||||
}
|
||||
|
||||
// Init the default logger, based on the command-line flags.
|
||||
// Must be called after flag.Parse().
|
||||
func Init() {
|
||||
var err error
|
||||
|
||||
if *logToSyslog != "" {
|
||||
Default, err = NewSyslog(syslog.LOG_DAEMON|syslog.LOG_INFO, *logToSyslog)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else if *logFile != "" {
|
||||
Default, err = NewFile(*logFile)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
*logTime = true
|
||||
}
|
||||
|
||||
if *alsoLogToStderr && Default.w != os.Stderr {
|
||||
Default.w = multiWriteCloser(Default.w, os.Stderr)
|
||||
}
|
||||
|
||||
Default.CallerSkip = 1
|
||||
Default.Level = Level(*vLevel)
|
||||
Default.LogTime = *logTime
|
||||
}
|
||||
|
||||
func V(level Level) bool {
|
||||
return Default.V(level)
|
||||
}
|
||||
|
||||
func Log(level Level, skip int, format string, a ...interface{}) {
|
||||
Default.Log(level, skip, format, a...)
|
||||
}
|
||||
|
||||
func Debugf(format string, a ...interface{}) {
|
||||
Default.Debugf(format, a...)
|
||||
}
|
||||
|
||||
func Infof(format string, a ...interface{}) {
|
||||
Default.Infof(format, a...)
|
||||
}
|
||||
|
||||
func Errorf(format string, a ...interface{}) error {
|
||||
return Default.Errorf(format, a...)
|
||||
}
|
||||
|
||||
func Fatalf(format string, a ...interface{}) {
|
||||
Default.Fatalf(format, a...)
|
||||
}
|
||||
|
||||
// multiWriteCloser creates a WriteCloser that duplicates its writes and
|
||||
// closes to all the provided writers.
|
||||
func multiWriteCloser(wc ...io.WriteCloser) io.WriteCloser {
|
||||
return mwc(wc)
|
||||
}
|
||||
|
||||
type mwc []io.WriteCloser
|
||||
|
||||
func (m mwc) Write(p []byte) (n int, err error) {
|
||||
for _, w := range m {
|
||||
if n, err = w.Write(p); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (m mwc) Close() error {
|
||||
for _, w := range m {
|
||||
if err := w.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustNewFile(t *testing.T) (string, *Logger) {
|
||||
f, err := ioutil.TempFile("", "log_test-")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
l, err := NewFile(f.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open new log file: %v", err)
|
||||
}
|
||||
|
||||
return f.Name(), l
|
||||
}
|
||||
|
||||
func checkContentsMatch(t *testing.T, name, path, expected string) {
|
||||
content, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
got := string(content)
|
||||
if !regexp.MustCompile(expected).Match(content) {
|
||||
t.Errorf("%s: regexp %q did not match %q",
|
||||
name, expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func testLogger(t *testing.T, fname string, l *Logger) {
|
||||
l.LogTime = false
|
||||
l.Infof("message %d", 1)
|
||||
checkContentsMatch(t, "info-no-time", fname,
|
||||
"^_ log_test.go:.... message 1\n")
|
||||
|
||||
os.Truncate(fname, 0)
|
||||
l.Infof("message %d\n", 1)
|
||||
checkContentsMatch(t, "info-with-newline", fname,
|
||||
"^_ log_test.go:.... message 1\n")
|
||||
|
||||
os.Truncate(fname, 0)
|
||||
l.LogTime = true
|
||||
l.Infof("message %d", 1)
|
||||
checkContentsMatch(t, "info-with-time", fname,
|
||||
`^\d{8} ..:..:..\.\d{6} _ log_test.go:.... message 1\n`)
|
||||
|
||||
os.Truncate(fname, 0)
|
||||
l.LogTime = false
|
||||
l.Errorf("error %d", 1)
|
||||
checkContentsMatch(t, "error", fname, `^E log_test.go:.... error 1\n`)
|
||||
|
||||
if l.V(Debug) {
|
||||
t.Fatalf("Debug level enabled by default (level: %v)", l.Level)
|
||||
}
|
||||
|
||||
os.Truncate(fname, 0)
|
||||
l.LogTime = false
|
||||
l.Debugf("debug %d", 1)
|
||||
checkContentsMatch(t, "debug-no-log", fname, `^$`)
|
||||
|
||||
os.Truncate(fname, 0)
|
||||
l.Level = Debug
|
||||
l.Debugf("debug %d", 1)
|
||||
checkContentsMatch(t, "debug", fname, `^\. log_test.go:.... debug 1\n`)
|
||||
|
||||
if !l.V(Debug) {
|
||||
t.Errorf("l.Level = Debug, but V(Debug) = false")
|
||||
}
|
||||
|
||||
os.Truncate(fname, 0)
|
||||
l.Level = Info
|
||||
l.Log(Debug, 0, "log debug %d", 1)
|
||||
l.Log(Info, 0, "log info %d", 1)
|
||||
checkContentsMatch(t, "log", fname,
|
||||
`^_ log_test.go:.... log info 1\n`)
|
||||
|
||||
os.Truncate(fname, 0)
|
||||
l.Level = Info
|
||||
l.Log(Fatal, 0, "log fatal %d", 1)
|
||||
checkContentsMatch(t, "log", fname,
|
||||
`^☠ log_test.go:.... log fatal 1\n`)
|
||||
}
|
||||
|
||||
func TestBasic(t *testing.T) {
|
||||
fname, l := mustNewFile(t)
|
||||
defer l.Close()
|
||||
defer os.Remove(fname)
|
||||
|
||||
testLogger(t, fname, l)
|
||||
}
|
||||
|
||||
func TestDefaultFile(t *testing.T) {
|
||||
fname, l := mustNewFile(t)
|
||||
l.Close()
|
||||
defer os.Remove(fname)
|
||||
|
||||
*logFile = fname
|
||||
|
||||
Init()
|
||||
|
||||
testLogger(t, fname, Default)
|
||||
}
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/log"
|
||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||
"blitiri.com.ar/go/log"
|
||||
)
|
||||
|
||||
// Global event logs.
|
||||
|
||||
@@ -23,11 +23,11 @@ import (
|
||||
"blitiri.com.ar/go/chasquid/internal/aliases"
|
||||
"blitiri.com.ar/go/chasquid/internal/courier"
|
||||
"blitiri.com.ar/go/chasquid/internal/envelope"
|
||||
"blitiri.com.ar/go/chasquid/internal/log"
|
||||
"blitiri.com.ar/go/chasquid/internal/maillog"
|
||||
"blitiri.com.ar/go/chasquid/internal/protoio"
|
||||
"blitiri.com.ar/go/chasquid/internal/set"
|
||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||
"blitiri.com.ar/go/log"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
|
||||
@@ -28,10 +28,10 @@ import (
|
||||
"blitiri.com.ar/go/chasquid/internal/normalize"
|
||||
"blitiri.com.ar/go/chasquid/internal/queue"
|
||||
"blitiri.com.ar/go/chasquid/internal/set"
|
||||
"blitiri.com.ar/go/chasquid/internal/spf"
|
||||
"blitiri.com.ar/go/chasquid/internal/tlsconst"
|
||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||
"blitiri.com.ar/go/chasquid/internal/userdb"
|
||||
"blitiri.com.ar/go/spf"
|
||||
)
|
||||
|
||||
// Exported variables.
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/domaininfo"
|
||||
"blitiri.com.ar/go/chasquid/internal/spf"
|
||||
"blitiri.com.ar/go/chasquid/internal/testlib"
|
||||
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||
"blitiri.com.ar/go/spf"
|
||||
)
|
||||
|
||||
func TestSecLevel(t *testing.T) {
|
||||
|
||||
@@ -11,11 +11,11 @@ import (
|
||||
"blitiri.com.ar/go/chasquid/internal/aliases"
|
||||
"blitiri.com.ar/go/chasquid/internal/courier"
|
||||
"blitiri.com.ar/go/chasquid/internal/domaininfo"
|
||||
"blitiri.com.ar/go/chasquid/internal/log"
|
||||
"blitiri.com.ar/go/chasquid/internal/maillog"
|
||||
"blitiri.com.ar/go/chasquid/internal/queue"
|
||||
"blitiri.com.ar/go/chasquid/internal/set"
|
||||
"blitiri.com.ar/go/chasquid/internal/userdb"
|
||||
"blitiri.com.ar/go/log"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
|
||||
@@ -1,412 +0,0 @@
|
||||
// Package spf implements SPF (Sender Policy Framework) lookup and validation.
|
||||
//
|
||||
// Supported:
|
||||
// - "all".
|
||||
// - "include".
|
||||
// - "a".
|
||||
// - "mx".
|
||||
// - "ip4".
|
||||
// - "ip6".
|
||||
// - "redirect".
|
||||
//
|
||||
// Not supported (return Neutral if used):
|
||||
// - "exists".
|
||||
// - "exp".
|
||||
// - Macros.
|
||||
//
|
||||
// References:
|
||||
// https://tools.ietf.org/html/rfc7208
|
||||
// https://en.wikipedia.org/wiki/Sender_Policy_Framework
|
||||
package spf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Functions that we can override for testing purposes.
|
||||
var (
|
||||
lookupTXT = net.LookupTXT
|
||||
lookupMX = net.LookupMX
|
||||
lookupIP = net.LookupIP
|
||||
lookupAddr = net.LookupAddr
|
||||
)
|
||||
|
||||
// Results and Errors. Note the values have meaning, we use them in headers.
|
||||
// https://tools.ietf.org/html/rfc7208#section-8
|
||||
type Result string
|
||||
|
||||
var (
|
||||
// https://tools.ietf.org/html/rfc7208#section-8.1
|
||||
// Not able to reach any conclusion.
|
||||
None = Result("none")
|
||||
|
||||
// https://tools.ietf.org/html/rfc7208#section-8.2
|
||||
// No definite assertion (positive or negative).
|
||||
Neutral = Result("neutral")
|
||||
|
||||
// https://tools.ietf.org/html/rfc7208#section-8.3
|
||||
// Client is authorized to inject mail.
|
||||
Pass = Result("pass")
|
||||
|
||||
// https://tools.ietf.org/html/rfc7208#section-8.4
|
||||
// Client is *not* authorized to use the domain
|
||||
Fail = Result("fail")
|
||||
|
||||
// https://tools.ietf.org/html/rfc7208#section-8.5
|
||||
// Not authorized, but unwilling to make a strong policy statement/
|
||||
SoftFail = Result("softfail")
|
||||
|
||||
// https://tools.ietf.org/html/rfc7208#section-8.6
|
||||
// Transient error while performing the check.
|
||||
TempError = Result("temperror")
|
||||
|
||||
// https://tools.ietf.org/html/rfc7208#section-8.7
|
||||
// Records could not be correctly interpreted.
|
||||
PermError = Result("permerror")
|
||||
)
|
||||
|
||||
var QualToResult = map[byte]Result{
|
||||
'+': Pass,
|
||||
'-': Fail,
|
||||
'~': SoftFail,
|
||||
'?': Neutral,
|
||||
}
|
||||
|
||||
// CheckHost function fetches SPF records, parses them, and evaluates them to
|
||||
// determine whether a particular host is or is not permitted to send mail
|
||||
// with a given identity.
|
||||
// Reference: https://tools.ietf.org/html/rfc7208#section-4
|
||||
func CheckHost(ip net.IP, domain string) (Result, error) {
|
||||
r := &resolution{ip, 0, nil}
|
||||
return r.Check(domain)
|
||||
}
|
||||
|
||||
type resolution struct {
|
||||
ip net.IP
|
||||
count uint
|
||||
|
||||
// Result of doing a reverse lookup for ip (so we only do it once).
|
||||
ipNames []string
|
||||
}
|
||||
|
||||
func (r *resolution) Check(domain string) (Result, error) {
|
||||
// Limit the number of resolutions to 10
|
||||
// https://tools.ietf.org/html/rfc7208#section-4.6.4
|
||||
if r.count > 10 {
|
||||
return PermError, fmt.Errorf("lookup limit reached")
|
||||
}
|
||||
r.count++
|
||||
|
||||
txt, err := getDNSRecord(domain)
|
||||
if err != nil {
|
||||
if isTemporary(err) {
|
||||
return TempError, err
|
||||
}
|
||||
// Could not resolve the name, it may be missing the record.
|
||||
// https://tools.ietf.org/html/rfc7208#section-2.6.1
|
||||
return None, err
|
||||
}
|
||||
|
||||
if txt == "" {
|
||||
// No record => None.
|
||||
// https://tools.ietf.org/html/rfc7208#section-4.6
|
||||
return None, nil
|
||||
}
|
||||
|
||||
fields := strings.Fields(txt)
|
||||
|
||||
// redirects must be handled after the rest; instead of having two loops,
|
||||
// we just move them to the end.
|
||||
var newfields, redirects []string
|
||||
for _, field := range fields {
|
||||
if strings.HasPrefix(field, "redirect:") {
|
||||
redirects = append(redirects, field)
|
||||
} else {
|
||||
newfields = append(newfields, field)
|
||||
}
|
||||
}
|
||||
fields = append(newfields, redirects...)
|
||||
|
||||
for _, field := range fields {
|
||||
if strings.HasPrefix(field, "v=") {
|
||||
continue
|
||||
}
|
||||
if r.count > 10 {
|
||||
return PermError, fmt.Errorf("lookup limit reached")
|
||||
}
|
||||
if strings.Contains(field, "%") {
|
||||
return Neutral, fmt.Errorf("macros not supported")
|
||||
}
|
||||
|
||||
// See if we have a qualifier, defaulting to + (pass).
|
||||
// https://tools.ietf.org/html/rfc7208#section-4.6.2
|
||||
result, ok := QualToResult[field[0]]
|
||||
if ok {
|
||||
field = field[1:]
|
||||
} else {
|
||||
result = Pass
|
||||
}
|
||||
|
||||
if field == "all" {
|
||||
// https://tools.ietf.org/html/rfc7208#section-5.1
|
||||
return result, fmt.Errorf("matched 'all'")
|
||||
} else if strings.HasPrefix(field, "include:") {
|
||||
if ok, res, err := r.includeField(result, field); ok {
|
||||
return res, err
|
||||
}
|
||||
} else if strings.HasPrefix(field, "a") {
|
||||
if ok, res, err := r.aField(result, field, domain); ok {
|
||||
return res, err
|
||||
}
|
||||
} else if strings.HasPrefix(field, "mx") {
|
||||
if ok, res, err := r.mxField(result, field, domain); ok {
|
||||
return res, err
|
||||
}
|
||||
} else if strings.HasPrefix(field, "ip4:") || strings.HasPrefix(field, "ip6:") {
|
||||
if ok, res, err := r.ipField(result, field); ok {
|
||||
return res, err
|
||||
}
|
||||
} else if strings.HasPrefix(field, "ptr") {
|
||||
if ok, res, err := r.ptrField(result, field, domain); ok {
|
||||
return res, err
|
||||
}
|
||||
} else if strings.HasPrefix(field, "exists") {
|
||||
return Neutral, fmt.Errorf("'exists' not supported")
|
||||
} else if strings.HasPrefix(field, "exp=") {
|
||||
return Neutral, fmt.Errorf("'exp' not supported")
|
||||
} else if strings.HasPrefix(field, "redirect=") {
|
||||
// https://tools.ietf.org/html/rfc7208#section-6.1
|
||||
result, err := r.Check(field[len("redirect="):])
|
||||
if result == None {
|
||||
result = PermError
|
||||
}
|
||||
return result, err
|
||||
} else {
|
||||
// http://www.openspf.org/SPF_Record_Syntax
|
||||
return PermError, fmt.Errorf("unknown field %q", field)
|
||||
}
|
||||
}
|
||||
|
||||
// Got to the end of the evaluation without a result => Neutral.
|
||||
// https://tools.ietf.org/html/rfc7208#section-4.7
|
||||
return Neutral, nil
|
||||
}
|
||||
|
||||
// getDNSRecord gets TXT records from the given domain, and returns the SPF
|
||||
// (if any). Note that at most one SPF is allowed per a given domain:
|
||||
// https://tools.ietf.org/html/rfc7208#section-3
|
||||
// https://tools.ietf.org/html/rfc7208#section-3.2
|
||||
// https://tools.ietf.org/html/rfc7208#section-4.5
|
||||
func getDNSRecord(domain string) (string, error) {
|
||||
txts, err := lookupTXT(domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, txt := range txts {
|
||||
if strings.HasPrefix(txt, "v=spf1 ") {
|
||||
return txt, nil
|
||||
}
|
||||
|
||||
// An empty record is explicitly allowed:
|
||||
// https://tools.ietf.org/html/rfc7208#section-4.5
|
||||
if txt == "v=spf1" {
|
||||
return txt, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func isTemporary(err error) bool {
|
||||
derr, ok := err.(*net.DNSError)
|
||||
return ok && derr.Temporary()
|
||||
}
|
||||
|
||||
// ipField processes an "ip" field.
|
||||
func (r *resolution) ipField(res Result, field string) (bool, Result, error) {
|
||||
fip := field[4:]
|
||||
if strings.Contains(fip, "/") {
|
||||
_, ipnet, err := net.ParseCIDR(fip)
|
||||
if err != nil {
|
||||
return true, PermError, err
|
||||
}
|
||||
if ipnet.Contains(r.ip) {
|
||||
return true, res, fmt.Errorf("matched %v", ipnet)
|
||||
}
|
||||
} else {
|
||||
ip := net.ParseIP(fip)
|
||||
if ip == nil {
|
||||
return true, PermError, fmt.Errorf("invalid ipX value")
|
||||
}
|
||||
if ip.Equal(r.ip) {
|
||||
return true, res, fmt.Errorf("matched %v", ip)
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
// ptrField processes a "ptr" field.
|
||||
func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, error) {
|
||||
// Extract the domain if the field is in the form "ptr:domain"
|
||||
if len(field) >= 4 {
|
||||
domain = field[4:]
|
||||
|
||||
}
|
||||
|
||||
if r.ipNames == nil {
|
||||
r.count++
|
||||
n, err := lookupAddr(r.ip.String())
|
||||
if err != nil {
|
||||
// https://tools.ietf.org/html/rfc7208#section-5
|
||||
if isTemporary(err) {
|
||||
return true, TempError, err
|
||||
}
|
||||
return false, "", err
|
||||
}
|
||||
r.ipNames = n
|
||||
}
|
||||
|
||||
for _, n := range r.ipNames {
|
||||
if strings.HasSuffix(n, domain+".") {
|
||||
return true, res, fmt.Errorf("matched ptr:%s", domain)
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
// includeField processes an "include" field.
|
||||
func (r *resolution) includeField(res Result, field string) (bool, Result, error) {
|
||||
// https://tools.ietf.org/html/rfc7208#section-5.2
|
||||
incdomain := field[len("include:"):]
|
||||
ir, err := r.Check(incdomain)
|
||||
switch ir {
|
||||
case Pass:
|
||||
return true, res, err
|
||||
case Fail, SoftFail, Neutral:
|
||||
return false, ir, err
|
||||
case TempError:
|
||||
return true, TempError, err
|
||||
case PermError, None:
|
||||
return true, PermError, err
|
||||
}
|
||||
|
||||
return false, "", fmt.Errorf("This should never be reached")
|
||||
|
||||
}
|
||||
|
||||
func ipMatch(ip, tomatch net.IP, mask int) (bool, error) {
|
||||
if mask >= 0 {
|
||||
_, ipnet, err := net.ParseCIDR(fmt.Sprintf("%s/%d", tomatch.String(), mask))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if ipnet.Contains(ip) {
|
||||
return true, fmt.Errorf("%v", ipnet)
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
if ip.Equal(tomatch) {
|
||||
return true, fmt.Errorf("%v", tomatch)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
var aRegexp = regexp.MustCompile("a(:([^/]+))?(/(.+))?")
|
||||
var mxRegexp = regexp.MustCompile("mx(:([^/]+))?(/(.+))?")
|
||||
|
||||
func domainAndMask(re *regexp.Regexp, field, domain string) (string, int, error) {
|
||||
var err error
|
||||
mask := -1
|
||||
if groups := re.FindStringSubmatch(field); groups != nil {
|
||||
if groups[2] != "" {
|
||||
domain = groups[2]
|
||||
}
|
||||
if groups[4] != "" {
|
||||
mask, err = strconv.Atoi(groups[4])
|
||||
if err != nil {
|
||||
return "", -1, fmt.Errorf("error parsing mask")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return domain, mask, nil
|
||||
}
|
||||
|
||||
// aField processes an "a" field.
|
||||
func (r *resolution) aField(res Result, field, domain string) (bool, Result, error) {
|
||||
// https://tools.ietf.org/html/rfc7208#section-5.3
|
||||
domain, mask, err := domainAndMask(aRegexp, field, domain)
|
||||
if err != nil {
|
||||
return true, PermError, err
|
||||
}
|
||||
|
||||
r.count++
|
||||
ips, err := lookupIP(domain)
|
||||
if err != nil {
|
||||
// https://tools.ietf.org/html/rfc7208#section-5
|
||||
if isTemporary(err) {
|
||||
return true, TempError, err
|
||||
}
|
||||
return false, "", err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
ok, err := ipMatch(r.ip, ip, mask)
|
||||
if ok {
|
||||
return true, res, fmt.Errorf("matched 'a' (%v)", err)
|
||||
} else if err != nil {
|
||||
return true, PermError, err
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
// mxField processes an "mx" field.
|
||||
func (r *resolution) mxField(res Result, field, domain string) (bool, Result, error) {
|
||||
// https://tools.ietf.org/html/rfc7208#section-5.4
|
||||
domain, mask, err := domainAndMask(mxRegexp, field, domain)
|
||||
if err != nil {
|
||||
return true, PermError, err
|
||||
}
|
||||
|
||||
r.count++
|
||||
mxs, err := lookupMX(domain)
|
||||
if err != nil {
|
||||
// https://tools.ietf.org/html/rfc7208#section-5
|
||||
if isTemporary(err) {
|
||||
return true, TempError, err
|
||||
}
|
||||
return false, "", err
|
||||
}
|
||||
mxips := []net.IP{}
|
||||
for _, mx := range mxs {
|
||||
r.count++
|
||||
ips, err := lookupIP(mx.Host)
|
||||
if err != nil {
|
||||
// https://tools.ietf.org/html/rfc7208#section-5
|
||||
if isTemporary(err) {
|
||||
return true, TempError, err
|
||||
}
|
||||
return false, "", err
|
||||
}
|
||||
mxips = append(mxips, ips...)
|
||||
}
|
||||
for _, ip := range mxips {
|
||||
ok, err := ipMatch(r.ip, ip, mask)
|
||||
if ok {
|
||||
return true, res, fmt.Errorf("matched 'mx' (%v)", err)
|
||||
} else if err != nil {
|
||||
return true, PermError, err
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
package spf
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var txtResults = map[string][]string{}
|
||||
var txtErrors = map[string]error{}
|
||||
|
||||
func LookupTXT(domain string) (txts []string, err error) {
|
||||
return txtResults[domain], txtErrors[domain]
|
||||
}
|
||||
|
||||
var mxResults = map[string][]*net.MX{}
|
||||
var mxErrors = map[string]error{}
|
||||
|
||||
func LookupMX(domain string) (mxs []*net.MX, err error) {
|
||||
return mxResults[domain], mxErrors[domain]
|
||||
}
|
||||
|
||||
var ipResults = map[string][]net.IP{}
|
||||
var ipErrors = map[string]error{}
|
||||
|
||||
func LookupIP(host string) (ips []net.IP, err error) {
|
||||
return ipResults[host], ipErrors[host]
|
||||
}
|
||||
|
||||
var addrResults = map[string][]string{}
|
||||
var addrErrors = map[string]error{}
|
||||
|
||||
func LookupAddr(host string) (addrs []string, err error) {
|
||||
return addrResults[host], addrErrors[host]
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
lookupTXT = LookupTXT
|
||||
lookupMX = LookupMX
|
||||
lookupIP = LookupIP
|
||||
lookupAddr = LookupAddr
|
||||
|
||||
flag.Parse()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
var ip1110 = net.ParseIP("1.1.1.0")
|
||||
var ip1111 = net.ParseIP("1.1.1.1")
|
||||
var ip6666 = net.ParseIP("2001:db8::68")
|
||||
var ip6660 = net.ParseIP("2001:db8::0")
|
||||
|
||||
func TestBasic(t *testing.T) {
|
||||
cases := []struct {
|
||||
txt string
|
||||
res Result
|
||||
}{
|
||||
{"", None},
|
||||
{"blah", None},
|
||||
{"v=spf1", Neutral},
|
||||
{"v=spf1 ", Neutral},
|
||||
{"v=spf1 -", PermError},
|
||||
{"v=spf1 all", Pass},
|
||||
{"v=spf1 +all", Pass},
|
||||
{"v=spf1 -all ", Fail},
|
||||
{"v=spf1 ~all", SoftFail},
|
||||
{"v=spf1 ?all", Neutral},
|
||||
{"v=spf1 a ~all", SoftFail},
|
||||
{"v=spf1 a/24", Neutral},
|
||||
{"v=spf1 a:d1110/24", Pass},
|
||||
{"v=spf1 a:d1110", Neutral},
|
||||
{"v=spf1 a:d1111", Pass},
|
||||
{"v=spf1 a:nothing/24", Neutral},
|
||||
{"v=spf1 mx", Neutral},
|
||||
{"v=spf1 mx/24", Neutral},
|
||||
{"v=spf1 mx:a/montoto ~all", PermError},
|
||||
{"v=spf1 mx:d1110/24 ~all", Pass},
|
||||
{"v=spf1 ip4:1.2.3.4 ~all", SoftFail},
|
||||
{"v=spf1 ip6:12 ~all", PermError},
|
||||
{"v=spf1 ip4:1.1.1.1 -all", Pass},
|
||||
{"v=spf1 ptr -all", Pass},
|
||||
{"v=spf1 ptr:d1111 -all", Pass},
|
||||
{"v=spf1 ptr:lalala -all", Pass},
|
||||
{"v=spf1 blah", PermError},
|
||||
}
|
||||
|
||||
ipResults["d1111"] = []net.IP{ip1111}
|
||||
ipResults["d1110"] = []net.IP{ip1110}
|
||||
mxResults["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}}
|
||||
addrResults["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."}
|
||||
|
||||
for _, c := range cases {
|
||||
txtResults["domain"] = []string{c.txt}
|
||||
res, err := CheckHost(ip1111, "domain")
|
||||
if (res == TempError || res == PermError) && (err == nil) {
|
||||
t.Errorf("%q: expected error, got nil", c.txt)
|
||||
}
|
||||
if res != c.res {
|
||||
t.Errorf("%q: expected %q, got %q", c.txt, c.res, res)
|
||||
t.Logf("%q: error: %v", c.txt, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPv6(t *testing.T) {
|
||||
cases := []struct {
|
||||
txt string
|
||||
res Result
|
||||
}{
|
||||
{"v=spf1 all", Pass},
|
||||
{"v=spf1 a ~all", SoftFail},
|
||||
{"v=spf1 a/24", Neutral},
|
||||
{"v=spf1 a:d6660/24", Pass},
|
||||
{"v=spf1 a:d6660", Neutral},
|
||||
{"v=spf1 a:d6666", Pass},
|
||||
{"v=spf1 a:nothing/24", Neutral},
|
||||
{"v=spf1 mx:d6660/24 ~all", Pass},
|
||||
{"v=spf1 ip6:2001:db8::68 ~all", Pass},
|
||||
{"v=spf1 ip6:2001:db8::1/24 ~all", Pass},
|
||||
{"v=spf1 ip6:2001:db8::1/100 ~all", Pass},
|
||||
{"v=spf1 ptr -all", Pass},
|
||||
{"v=spf1 ptr:d6666 -all", Pass},
|
||||
{"v=spf1 ptr:sonlas6 -all", Pass},
|
||||
}
|
||||
|
||||
ipResults["d6666"] = []net.IP{ip6666}
|
||||
ipResults["d6660"] = []net.IP{ip6660}
|
||||
mxResults["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}}
|
||||
addrResults["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."}
|
||||
|
||||
for _, c := range cases {
|
||||
txtResults["domain"] = []string{c.txt}
|
||||
res, err := CheckHost(ip6666, "domain")
|
||||
if (res == TempError || res == PermError) && (err == nil) {
|
||||
t.Errorf("%q: expected error, got nil", c.txt)
|
||||
}
|
||||
if res != c.res {
|
||||
t.Errorf("%q: expected %q, got %q", c.txt, c.res, res)
|
||||
t.Logf("%q: error: %v", c.txt, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotSupported(t *testing.T) {
|
||||
cases := []string{
|
||||
"v=spf1 exists:blah -all",
|
||||
"v=spf1 exp=blah -all",
|
||||
"v=spf1 a:%{o} -all",
|
||||
}
|
||||
|
||||
for _, txt := range cases {
|
||||
txtResults["domain"] = []string{txt}
|
||||
res, err := CheckHost(ip1111, "domain")
|
||||
if res != Neutral {
|
||||
t.Errorf("%q: expected neutral, got %v", txt, res)
|
||||
t.Logf("%q: error: %v", txt, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecursion(t *testing.T) {
|
||||
txtResults["domain"] = []string{"v=spf1 include:domain ~all"}
|
||||
|
||||
res, err := CheckHost(ip1111, "domain")
|
||||
if res != PermError {
|
||||
t.Errorf("expected permerror, got %v (%v)", res, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoRecord(t *testing.T) {
|
||||
txtResults["d1"] = []string{""}
|
||||
txtResults["d2"] = []string{"loco", "v=spf2"}
|
||||
txtErrors["nospf"] = fmt.Errorf("no such domain")
|
||||
|
||||
for _, domain := range []string{"d1", "d2", "d3", "nospf"} {
|
||||
res, err := CheckHost(ip1111, domain)
|
||||
if res != None {
|
||||
t.Errorf("expected none, got %v (%v)", res, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
// Package systemd implements utility functions to interact with systemd.
|
||||
package systemd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var (
|
||||
// Error to return when $LISTEN_PID does not refer to us.
|
||||
ErrPIDMismatch = errors.New("$LISTEN_PID != our PID")
|
||||
|
||||
// First FD for listeners.
|
||||
// It's 3 by definition, but using a variable simplifies testing.
|
||||
firstFD = 3
|
||||
)
|
||||
|
||||
// Listeners creates a slice net.Listener from the file descriptors passed
|
||||
// by systemd, via the LISTEN_FDS environment variable.
|
||||
// See sd_listen_fds(3) and sd_listen_fds_with_names(3) for more details.
|
||||
func Listeners() (map[string][]net.Listener, error) {
|
||||
pidStr := os.Getenv("LISTEN_PID")
|
||||
nfdsStr := os.Getenv("LISTEN_FDS")
|
||||
fdNamesStr := os.Getenv("LISTEN_FDNAMES")
|
||||
fdNames := strings.Split(fdNamesStr, ":")
|
||||
|
||||
// Nothing to do if the variables are not set.
|
||||
if pidStr == "" || nfdsStr == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pid, err := strconv.Atoi(pidStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error converting $LISTEN_PID=%q: %v", pidStr, err)
|
||||
} else if pid != os.Getpid() {
|
||||
return nil, ErrPIDMismatch
|
||||
}
|
||||
|
||||
nfds, err := strconv.Atoi(os.Getenv("LISTEN_FDS"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error reading $LISTEN_FDS=%q: %v", nfdsStr, err)
|
||||
}
|
||||
|
||||
// We should have as many names as we have descriptors.
|
||||
// Note that if we have no descriptors, fdNames will be [""] (due to how
|
||||
// strings.Split works), so we consider that special case.
|
||||
if nfds > 0 && (fdNamesStr == "" || len(fdNames) != nfds) {
|
||||
return nil, fmt.Errorf(
|
||||
"Incorrect LISTEN_FDNAMES, have you set FileDescriptorName?")
|
||||
}
|
||||
|
||||
listeners := map[string][]net.Listener{}
|
||||
|
||||
for i := 0; i < nfds; i++ {
|
||||
fd := firstFD + i
|
||||
// We don't want childs to inherit these file descriptors.
|
||||
syscall.CloseOnExec(fd)
|
||||
|
||||
name := fdNames[i]
|
||||
|
||||
sysName := fmt.Sprintf("[systemd-fd-%d-%v]", fd, name)
|
||||
lis, err := net.FileListener(os.NewFile(uintptr(fd), sysName))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"Error making listener out of fd %d: %v", fd, err)
|
||||
}
|
||||
|
||||
listeners[name] = append(listeners[name], lis)
|
||||
}
|
||||
|
||||
// Remove them from the environment, to prevent accidental reuse (by
|
||||
// us or children processes).
|
||||
os.Unsetenv("LISTEN_PID")
|
||||
os.Unsetenv("LISTEN_FDS")
|
||||
os.Unsetenv("LISTEN_FDNAMES")
|
||||
|
||||
return listeners, nil
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
package systemd
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setenv(pid, fds string, names ...string) {
|
||||
os.Setenv("LISTEN_PID", pid)
|
||||
os.Setenv("LISTEN_FDS", fds)
|
||||
os.Setenv("LISTEN_FDNAMES", strings.Join(names, ":"))
|
||||
}
|
||||
|
||||
func TestEmptyEnvironment(t *testing.T) {
|
||||
cases := []struct{ pid, fds string }{
|
||||
{"", ""},
|
||||
{"123", ""},
|
||||
{"", "4"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
setenv(c.pid, c.fds)
|
||||
|
||||
if ls, err := Listeners(); ls != nil || err != nil {
|
||||
t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q", c.pid, c.fds)
|
||||
t.Errorf("Unexpected result: %v // %v", ls, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadEnvironment(t *testing.T) {
|
||||
// Create a listener so we have something to reference.
|
||||
l := newListener(t)
|
||||
firstFD = listenerFd(t, l)
|
||||
|
||||
ourPID := strconv.Itoa(os.Getpid())
|
||||
cases := []struct {
|
||||
pid, fds string
|
||||
names []string
|
||||
}{
|
||||
{"a", "1", []string{"name"}}, // Invalid PID.
|
||||
{ourPID, "a", []string{"name"}}, // Invalid number of fds.
|
||||
{"1", "1", []string{"name"}}, // PID != ourselves.
|
||||
{ourPID, "1", []string{"name1", "name2"}}, // Too many names.
|
||||
{ourPID, "1", []string{}}, // Not enough names.
|
||||
}
|
||||
for _, c := range cases {
|
||||
setenv(c.pid, c.fds, c.names...)
|
||||
|
||||
if ls, err := Listeners(); err == nil {
|
||||
t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q LISTEN_FDNAMES=%q", c.pid, c.fds, c.names)
|
||||
t.Errorf("Unexpected result: %v // %v", ls, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrongPID(t *testing.T) {
|
||||
// Find a pid != us. 1 should always work in practice.
|
||||
pid := 1
|
||||
for pid == os.Getpid() {
|
||||
pid = rand.Int()
|
||||
}
|
||||
|
||||
setenv(strconv.Itoa(pid), "4")
|
||||
if _, err := Listeners(); err != ErrPIDMismatch {
|
||||
t.Errorf("Did not fail with PID mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoFDs(t *testing.T) {
|
||||
setenv(strconv.Itoa(os.Getpid()), "0")
|
||||
if ls, err := Listeners(); len(ls) != 0 || err != nil {
|
||||
t.Errorf("Got a non-empty result: %v // %v", ls, err)
|
||||
}
|
||||
}
|
||||
|
||||
// newListener creates a TCP listener.
|
||||
func newListener(t *testing.T) *net.TCPListener {
|
||||
addr := &net.TCPAddr{
|
||||
Port: 0,
|
||||
}
|
||||
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create TCP listener: %v", err)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// listenerFd returns a file descriptor for the listener.
|
||||
// Note it is a NEW file descriptor, not the original one.
|
||||
func listenerFd(t *testing.T, l *net.TCPListener) int {
|
||||
f, err := l.File()
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get TCP listener file: %v", err)
|
||||
}
|
||||
|
||||
return int(f.Fd())
|
||||
}
|
||||
|
||||
func sameAddr(a, b net.Addr) bool {
|
||||
return a.Network() == b.Network() && a.String() == b.String()
|
||||
}
|
||||
|
||||
func TestOneSocket(t *testing.T) {
|
||||
l := newListener(t)
|
||||
firstFD = listenerFd(t, l)
|
||||
|
||||
setenv(strconv.Itoa(os.Getpid()), "1", "name")
|
||||
|
||||
lsMap, err := Listeners()
|
||||
if err != nil || len(lsMap) != 1 {
|
||||
t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
|
||||
}
|
||||
|
||||
ls := lsMap["name"]
|
||||
|
||||
if !sameAddr(ls[0].Addr(), l.Addr()) {
|
||||
t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
|
||||
l.Addr(), ls[0].Addr())
|
||||
}
|
||||
|
||||
if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" {
|
||||
t.Errorf("Failed to reset the environment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManySockets(t *testing.T) {
|
||||
// Create two contiguous listeners.
|
||||
// The test environment does not guarantee us that they are contiguous, so
|
||||
// keep going until they are.
|
||||
var l0, l1 *net.TCPListener
|
||||
var f0, f1 int = -1, -3
|
||||
|
||||
for f0+1 != f1 {
|
||||
// We have to be careful with the order of these operations, because
|
||||
// listenerFd will create *new* file descriptors.
|
||||
l0 = newListener(t)
|
||||
l1 = newListener(t)
|
||||
f0 = listenerFd(t, l0)
|
||||
f1 = listenerFd(t, l1)
|
||||
t.Logf("Looping for FDs: %d %d", f0, f1)
|
||||
}
|
||||
|
||||
firstFD = f0
|
||||
|
||||
setenv(strconv.Itoa(os.Getpid()), "2", "name1", "name2")
|
||||
|
||||
lsMap, err := Listeners()
|
||||
if err != nil || len(lsMap) != 2 {
|
||||
t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
|
||||
}
|
||||
|
||||
ls := []net.Listener{
|
||||
lsMap["name1"][0],
|
||||
lsMap["name2"][0],
|
||||
}
|
||||
|
||||
if !sameAddr(ls[0].Addr(), l0.Addr()) {
|
||||
t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
|
||||
l0.Addr(), ls[0].Addr())
|
||||
}
|
||||
|
||||
if !sameAddr(ls[1].Addr(), l1.Addr()) {
|
||||
t.Errorf("Listener 1 address mismatch, expected %#v, got %#v",
|
||||
l1.Addr(), ls[1].Addr())
|
||||
}
|
||||
|
||||
if os.Getenv("LISTEN_PID") != "" ||
|
||||
os.Getenv("LISTEN_FDS") != "" ||
|
||||
os.Getenv("LISTEN_FDNAMES") != "" {
|
||||
t.Errorf("Failed to reset the environment")
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/log"
|
||||
"blitiri.com.ar/go/log"
|
||||
|
||||
nettrace "golang.org/x/net/trace"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user