mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-17 14:37:02 +00:00
domaininfo: Add a Clear method to clear information for a given domain
This patch adds a Clear method to the domaininfo database, which removes information for the given domain. This can be used to manually make the server forget about a domain, in case there are operational reasons to do so. Today, this is done via chasquid-util (which removes the backing file), but that is hacky, and this is part of replacing it with a cleaner implementation.
This commit is contained in:
@@ -75,7 +75,7 @@ func (db *DB) Reload() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) write(tr *trace.Trace, d *Domain) {
|
||||
func (db *DB) write(tr *trace.Trace, d *Domain) error {
|
||||
tr = tr.NewChild("DomainInfo.write", d.Name)
|
||||
defer tr.Finish()
|
||||
|
||||
@@ -85,6 +85,7 @@ func (db *DB) write(tr *trace.Trace, d *Domain) {
|
||||
} else {
|
||||
tr.Debugf("saved")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// IncomingSecLevel checks an incoming security level for the domain.
|
||||
@@ -158,3 +159,26 @@ func (db *DB) OutgoingSecLevel(tr *trace.Trace, domain string, level SecLevel) b
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Clear sets the security level for the given domain to plain.
|
||||
// This can be used for manual overrides in case there's an operational need
|
||||
// to do so.
|
||||
func (db *DB) Clear(tr *trace.Trace, domain string) bool {
|
||||
tr = tr.NewChild("DomainInfo.SetToPlain", domain)
|
||||
defer tr.Finish()
|
||||
|
||||
db.Lock()
|
||||
defer db.Unlock()
|
||||
|
||||
d, exists := db.info[domain]
|
||||
if !exists {
|
||||
tr.Debugf("does not exist")
|
||||
return false
|
||||
}
|
||||
|
||||
d.IncomingSecLevel = SecLevel_PLAIN
|
||||
d.OutgoingSecLevel = SecLevel_PLAIN
|
||||
db.write(tr, d)
|
||||
tr.Printf("set to plain")
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package domaininfo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"blitiri.com.ar/go/chasquid/internal/testlib"
|
||||
@@ -17,14 +19,26 @@ func TestBasic(t *testing.T) {
|
||||
tr := trace.New("test", "basic")
|
||||
defer tr.Finish()
|
||||
|
||||
// IncomingSecLevel checks.
|
||||
if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) {
|
||||
t.Errorf("new domain as plain not allowed")
|
||||
t.Errorf("incoming: new domain as plain not allowed")
|
||||
}
|
||||
if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
|
||||
t.Errorf("increment to tls-secure not allowed")
|
||||
t.Errorf("incoming: increment to tls-secure not allowed")
|
||||
}
|
||||
if db.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
|
||||
t.Errorf("decrement to tls-insecure was allowed")
|
||||
t.Errorf("incoming: decrement to tls-insecure was allowed")
|
||||
}
|
||||
|
||||
// OutgoingSecLevel checks.
|
||||
if !db.OutgoingSecLevel(tr, "d1", SecLevel_PLAIN) {
|
||||
t.Errorf("outgoing: new domain as plain not allowed")
|
||||
}
|
||||
if !db.OutgoingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
|
||||
t.Errorf("outgoing: increment to tls-secure not allowed")
|
||||
}
|
||||
if db.OutgoingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
|
||||
t.Errorf("outgoing: decrement to tls-insecure was allowed")
|
||||
}
|
||||
|
||||
// Check that it was added to the store and a new db sees it.
|
||||
@@ -35,6 +49,24 @@ func TestBasic(t *testing.T) {
|
||||
if db2.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
|
||||
t.Errorf("decrement to tls-insecure was allowed in new DB")
|
||||
}
|
||||
|
||||
// Check that Clear resets the entry back to plain.
|
||||
ok := db.Clear(tr, "d1")
|
||||
if !ok {
|
||||
t.Errorf("Clear(d1) did not find the domain")
|
||||
}
|
||||
if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) {
|
||||
t.Errorf("Clear did not reset the domain back to plain (incoming)")
|
||||
}
|
||||
if !db.OutgoingSecLevel(tr, "d1", SecLevel_PLAIN) {
|
||||
t.Errorf("Clear did not reset the domain back to plain (outgoing)")
|
||||
}
|
||||
|
||||
// Check that Clear returns false if the domain does not exist.
|
||||
ok = db.Clear(tr, "notexist")
|
||||
if ok {
|
||||
t.Errorf("Clear(notexist) returned true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDomain(t *testing.T) {
|
||||
@@ -44,7 +76,7 @@ func TestNewDomain(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tr := trace.New("test", "basic")
|
||||
tr := trace.New("test", "newdomain")
|
||||
defer tr.Finish()
|
||||
|
||||
cases := []struct {
|
||||
@@ -56,12 +88,15 @@ func TestNewDomain(t *testing.T) {
|
||||
{"secure", SecLevel_TLS_SECURE},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if !db.IncomingSecLevel(tr, c.domain, c.level) {
|
||||
t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level)
|
||||
}
|
||||
// The other tests do an incoming check first, so new domains would get
|
||||
// created via that path. We switch the order here to exercise that
|
||||
// OutgoingSecLevel also handles new domains successfuly.
|
||||
if !db.OutgoingSecLevel(tr, c.domain, c.level) {
|
||||
t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level)
|
||||
}
|
||||
if !db.IncomingSecLevel(tr, c.domain, c.level) {
|
||||
t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +107,7 @@ func TestProgressions(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tr := trace.New("test", "basic")
|
||||
tr := trace.New("test", "progressions")
|
||||
defer tr.Finish()
|
||||
|
||||
cases := []struct {
|
||||
@@ -118,7 +153,7 @@ func TestErrors(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tr := trace.New("test", "basic")
|
||||
tr := trace.New("test", "errors")
|
||||
defer tr.Finish()
|
||||
|
||||
if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
|
||||
@@ -131,4 +166,41 @@ func TestErrors(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Errorf("no error when reloading db with invalid file")
|
||||
}
|
||||
|
||||
// Creating a db with an invalid file should also result in an error.
|
||||
_, err = New(dir)
|
||||
if err == nil {
|
||||
t.Errorf("no error when creating db with invalid file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectoryErrors(t *testing.T) {
|
||||
dir := testlib.MustTempDir(t)
|
||||
defer testlib.RemoveIfOk(t, dir)
|
||||
db, err := New(dir + "/db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tr := trace.New("test", "direrrors")
|
||||
defer tr.Finish()
|
||||
|
||||
// We want to cause store.ListIDs to return an error. To do so, we will
|
||||
// cause Readdir to fail by removing the underlying db directory.
|
||||
err = os.Remove(dir + "/db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Reload()
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Errorf("got %v, expected %v", err, os.ErrNotExist)
|
||||
}
|
||||
|
||||
// We expect write() to also fail to store data in this scenario.
|
||||
d := Domain{Name: "d1"}
|
||||
err = db.write(tr, &d)
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Errorf("got %v, expected %v", err, os.ErrNotExist)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user