mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-19 14:57:04 +00:00
auth: Implement an Authenticator type
This patch implements an Authenticator type, which connections use to do authentication and user existence checks. It simplifies the abstractions (the server doesn't need to know about userdb, or keep track of domain-userdb maps), and lays the foundation for other types of authentication backends which will come in later patches.
This commit is contained in:
@@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -58,36 +59,212 @@ func TestAuthenticate(t *testing.T) {
|
||||
db := userdb.New("/dev/null")
|
||||
db.AddUser("user", "password")
|
||||
|
||||
a := NewAuthenticator()
|
||||
a.Register("domain", WrapNoErrorBackend(db))
|
||||
|
||||
// Shorten the duration to speed up the test. This should still be long
|
||||
// enough for it to fail if we don't sleep intentionally.
|
||||
a.AuthDuration = 20 * time.Millisecond
|
||||
|
||||
// Test the correct case first
|
||||
check(t, a, "user", "domain", "password", true)
|
||||
|
||||
// Wrong password, but valid user@domain.
|
||||
ts := time.Now()
|
||||
if !Authenticate(db, "user", "password") {
|
||||
t.Errorf("failed valid authentication for user/password")
|
||||
if ok, _ := a.Authenticate("user", "domain", "invalid"); ok {
|
||||
t.Errorf("invalid password, but authentication succeeded")
|
||||
}
|
||||
if time.Since(ts) < AuthenticateTime {
|
||||
t.Errorf("authentication was too fast")
|
||||
if time.Since(ts) < a.AuthDuration {
|
||||
t.Errorf("authentication was too fast (invalid case)")
|
||||
}
|
||||
|
||||
// Incorrect cases.
|
||||
cases := []struct{ user, password string }{
|
||||
{"user", "incorrect"},
|
||||
{"invalid", "p"},
|
||||
// Incorrect cases, where the user@domain do not exist.
|
||||
cases := []struct{ user, domain, password string }{
|
||||
{"user", "unknown", "password"},
|
||||
{"invalid", "domain", "p"},
|
||||
{"invalid", "unknown", "p"},
|
||||
{"user", "", "password"},
|
||||
{"invalid", "", "p"},
|
||||
{"", "domain", "password"},
|
||||
{"", "", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
ts = time.Now()
|
||||
if Authenticate(db, c.user, c.password) {
|
||||
t.Errorf("successful auth on %v", c)
|
||||
}
|
||||
if time.Since(ts) < AuthenticateTime {
|
||||
t.Errorf("authentication was too fast")
|
||||
}
|
||||
}
|
||||
|
||||
// And the special case of a nil userdb.
|
||||
ts = time.Now()
|
||||
if Authenticate(nil, "user", "password") {
|
||||
t.Errorf("successful auth on a nil userdb")
|
||||
}
|
||||
if time.Since(ts) < AuthenticateTime {
|
||||
t.Errorf("authentication was too fast")
|
||||
check(t, a, c.user, c.domain, c.password, false)
|
||||
}
|
||||
}
|
||||
|
||||
func check(t *testing.T, a *Authenticator, user, domain, passwd string, expect bool) {
|
||||
c := fmt.Sprintf("{%s@%s %s}", user, domain, passwd)
|
||||
ts := time.Now()
|
||||
|
||||
ok, err := a.Authenticate(user, domain, passwd)
|
||||
if time.Since(ts) < a.AuthDuration {
|
||||
t.Errorf("auth on %v was too fast", c)
|
||||
}
|
||||
if ok != expect {
|
||||
t.Errorf("auth on %v: got %v, expected %v", c, ok, expect)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("auth on %v: got error %v", c, err)
|
||||
}
|
||||
|
||||
ok, err = a.Exists(user, domain)
|
||||
if ok != expect {
|
||||
t.Errorf("exists on %v: got %v, expected %v", c, ok, expect)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("exists on %v: error %v", c, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaces(t *testing.T) {
|
||||
var _ NoErrorBackend = userdb.New("/dev/null")
|
||||
}
|
||||
|
||||
// Backend implementation for testing.
|
||||
type TestBE struct {
|
||||
users map[string]string
|
||||
reloadCount int
|
||||
nextError error
|
||||
}
|
||||
|
||||
func NewTestBE() *TestBE {
|
||||
return &TestBE{
|
||||
users: map[string]string{},
|
||||
}
|
||||
}
|
||||
func (d *TestBE) add(user, password string) {
|
||||
d.users[user] = password
|
||||
}
|
||||
|
||||
func (d *TestBE) Authenticate(user, password string) (bool, error) {
|
||||
if d.nextError != nil {
|
||||
return false, d.nextError
|
||||
}
|
||||
|
||||
if validP, ok := d.users[user]; ok {
|
||||
return validP == password, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (d *TestBE) Exists(user string) (bool, error) {
|
||||
if d.nextError != nil {
|
||||
return false, d.nextError
|
||||
}
|
||||
_, ok := d.users[user]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (d *TestBE) Reload() error {
|
||||
d.reloadCount++
|
||||
if d.nextError != nil {
|
||||
return d.nextError
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMultipleBackends(t *testing.T) {
|
||||
domain1 := NewTestBE()
|
||||
domain2 := NewTestBE()
|
||||
fallback := NewTestBE()
|
||||
|
||||
a := NewAuthenticator()
|
||||
a.Register("domain1", domain1)
|
||||
a.Register("domain2", domain2)
|
||||
a.Fallback = fallback
|
||||
|
||||
// Shorten the duration to speed up the test. This should still be long
|
||||
// enough for it to fail if we don't sleep intentionally.
|
||||
a.AuthDuration = 20 * time.Millisecond
|
||||
|
||||
domain1.add("user1", "passwd1")
|
||||
domain2.add("user2", "passwd2")
|
||||
fallback.add("user3@fallback", "passwd3")
|
||||
fallback.add("user4@domain1", "passwd4")
|
||||
|
||||
// Successful tests.
|
||||
cases := []struct{ user, domain, password string }{
|
||||
{"user1", "domain1", "passwd1"},
|
||||
{"user2", "domain2", "passwd2"},
|
||||
{"user3", "fallback", "passwd3"},
|
||||
{"user4", "domain1", "passwd4"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
check(t, a, c.user, c.domain, c.password, true)
|
||||
}
|
||||
|
||||
// Unsuccessful tests (users don't exist).
|
||||
cases = []struct{ user, domain, password string }{
|
||||
{"nobody", "domain1", "p"},
|
||||
{"nobody", "domain2", "p"},
|
||||
{"nobody", "fallback", "p"},
|
||||
{"user3", "", "p"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
check(t, a, c.user, c.domain, c.password, false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
be := NewTestBE()
|
||||
be.add("user", "passwd")
|
||||
|
||||
a := NewAuthenticator()
|
||||
a.Register("domain", be)
|
||||
a.AuthDuration = 0
|
||||
|
||||
ok, err := a.Authenticate("user", "domain", "passwd")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("failed auth")
|
||||
}
|
||||
|
||||
expectedErr := fmt.Errorf("test error")
|
||||
be.nextError = expectedErr
|
||||
|
||||
ok, err = a.Authenticate("user", "domain", "passwd")
|
||||
if ok {
|
||||
t.Errorf("authentication succeeded, expected error")
|
||||
}
|
||||
if err != expectedErr {
|
||||
t.Errorf("expected error, got %v", err)
|
||||
}
|
||||
|
||||
ok, err = a.Exists("user", "domain")
|
||||
if ok {
|
||||
t.Errorf("exists succeeded, expected error")
|
||||
}
|
||||
if err != expectedErr {
|
||||
t.Errorf("expected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReload(t *testing.T) {
|
||||
be1 := NewTestBE()
|
||||
be2 := NewTestBE()
|
||||
fallback := NewTestBE()
|
||||
|
||||
a := NewAuthenticator()
|
||||
a.Register("domain1", be1)
|
||||
a.Register("domain2", be2)
|
||||
a.Fallback = fallback
|
||||
|
||||
err := a.Reload()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error reloading: %v", err)
|
||||
}
|
||||
if be1.reloadCount != 1 || be2.reloadCount != 1 || fallback.reloadCount != 1 {
|
||||
t.Errorf("unexpected reload counts: %d %d %d != 1 1 1",
|
||||
be1.reloadCount, be2.reloadCount, fallback.reloadCount)
|
||||
}
|
||||
|
||||
be2.nextError = fmt.Errorf("test error")
|
||||
err = a.Reload()
|
||||
if err == nil {
|
||||
t.Errorf("expected error reloading, got nil")
|
||||
}
|
||||
if be1.reloadCount != 2 || be2.reloadCount != 2 || fallback.reloadCount != 2 {
|
||||
t.Errorf("unexpected reload counts: %d %d %d != 2 2 2",
|
||||
be1.reloadCount, be2.reloadCount, fallback.reloadCount)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user