1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-17 14:37:02 +00:00

domaininfo: Add a Clear method to clear information for a given domain

This patch adds a Clear method to the domaininfo database, which removes
information for the given domain.

This can be used to manually make the server forget about a domain, in
case there are operational reasons to do so.

Today, this is done via chasquid-util (which removes the backing file),
but that is hacky, and this is part of replacing it with a cleaner
implementation.
This commit is contained in:
Alberto Bertogli
2023-07-30 11:34:00 +01:00
parent ac1c849a27
commit 764c09e94d
2 changed files with 106 additions and 10 deletions

View File

@@ -75,7 +75,7 @@ func (db *DB) Reload() error {
return nil return nil
} }
func (db *DB) write(tr *trace.Trace, d *Domain) { func (db *DB) write(tr *trace.Trace, d *Domain) error {
tr = tr.NewChild("DomainInfo.write", d.Name) tr = tr.NewChild("DomainInfo.write", d.Name)
defer tr.Finish() defer tr.Finish()
@@ -85,6 +85,7 @@ func (db *DB) write(tr *trace.Trace, d *Domain) {
} else { } else {
tr.Debugf("saved") tr.Debugf("saved")
} }
return err
} }
// IncomingSecLevel checks an incoming security level for the domain. // IncomingSecLevel checks an incoming security level for the domain.
@@ -158,3 +159,26 @@ func (db *DB) OutgoingSecLevel(tr *trace.Trace, domain string, level SecLevel) b
return true return true
} }
} }
// Clear sets the security level for the given domain to plain.
// This can be used for manual overrides in case there's an operational need
// to do so.
func (db *DB) Clear(tr *trace.Trace, domain string) bool {
tr = tr.NewChild("DomainInfo.SetToPlain", domain)
defer tr.Finish()
db.Lock()
defer db.Unlock()
d, exists := db.info[domain]
if !exists {
tr.Debugf("does not exist")
return false
}
d.IncomingSecLevel = SecLevel_PLAIN
d.OutgoingSecLevel = SecLevel_PLAIN
db.write(tr, d)
tr.Printf("set to plain")
return true
}

View File

@@ -1,6 +1,8 @@
package domaininfo package domaininfo
import ( import (
"errors"
"os"
"testing" "testing"
"blitiri.com.ar/go/chasquid/internal/testlib" "blitiri.com.ar/go/chasquid/internal/testlib"
@@ -17,14 +19,26 @@ func TestBasic(t *testing.T) {
tr := trace.New("test", "basic") tr := trace.New("test", "basic")
defer tr.Finish() defer tr.Finish()
// IncomingSecLevel checks.
if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) { if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) {
t.Errorf("new domain as plain not allowed") t.Errorf("incoming: new domain as plain not allowed")
} }
if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) { if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
t.Errorf("increment to tls-secure not allowed") t.Errorf("incoming: increment to tls-secure not allowed")
} }
if db.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) { if db.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
t.Errorf("decrement to tls-insecure was allowed") t.Errorf("incoming: decrement to tls-insecure was allowed")
}
// OutgoingSecLevel checks.
if !db.OutgoingSecLevel(tr, "d1", SecLevel_PLAIN) {
t.Errorf("outgoing: new domain as plain not allowed")
}
if !db.OutgoingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
t.Errorf("outgoing: increment to tls-secure not allowed")
}
if db.OutgoingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
t.Errorf("outgoing: decrement to tls-insecure was allowed")
} }
// Check that it was added to the store and a new db sees it. // Check that it was added to the store and a new db sees it.
@@ -35,6 +49,24 @@ func TestBasic(t *testing.T) {
if db2.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) { if db2.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
t.Errorf("decrement to tls-insecure was allowed in new DB") t.Errorf("decrement to tls-insecure was allowed in new DB")
} }
// Check that Clear resets the entry back to plain.
ok := db.Clear(tr, "d1")
if !ok {
t.Errorf("Clear(d1) did not find the domain")
}
if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) {
t.Errorf("Clear did not reset the domain back to plain (incoming)")
}
if !db.OutgoingSecLevel(tr, "d1", SecLevel_PLAIN) {
t.Errorf("Clear did not reset the domain back to plain (outgoing)")
}
// Check that Clear returns false if the domain does not exist.
ok = db.Clear(tr, "notexist")
if ok {
t.Errorf("Clear(notexist) returned true")
}
} }
func TestNewDomain(t *testing.T) { func TestNewDomain(t *testing.T) {
@@ -44,7 +76,7 @@ func TestNewDomain(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tr := trace.New("test", "basic") tr := trace.New("test", "newdomain")
defer tr.Finish() defer tr.Finish()
cases := []struct { cases := []struct {
@@ -56,12 +88,15 @@ func TestNewDomain(t *testing.T) {
{"secure", SecLevel_TLS_SECURE}, {"secure", SecLevel_TLS_SECURE},
} }
for _, c := range cases { for _, c := range cases {
if !db.IncomingSecLevel(tr, c.domain, c.level) { // The other tests do an incoming check first, so new domains would get
t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level) // created via that path. We switch the order here to exercise that
} // OutgoingSecLevel also handles new domains successfuly.
if !db.OutgoingSecLevel(tr, c.domain, c.level) { if !db.OutgoingSecLevel(tr, c.domain, c.level) {
t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level) t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level)
} }
if !db.IncomingSecLevel(tr, c.domain, c.level) {
t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level)
}
} }
} }
@@ -72,7 +107,7 @@ func TestProgressions(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tr := trace.New("test", "basic") tr := trace.New("test", "progressions")
defer tr.Finish() defer tr.Finish()
cases := []struct { cases := []struct {
@@ -118,7 +153,7 @@ func TestErrors(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
tr := trace.New("test", "basic") tr := trace.New("test", "errors")
defer tr.Finish() defer tr.Finish()
if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) { if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
@@ -131,4 +166,41 @@ func TestErrors(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("no error when reloading db with invalid file") t.Errorf("no error when reloading db with invalid file")
} }
// Creating a db with an invalid file should also result in an error.
_, err = New(dir)
if err == nil {
t.Errorf("no error when creating db with invalid file")
}
}
func TestDirectoryErrors(t *testing.T) {
dir := testlib.MustTempDir(t)
defer testlib.RemoveIfOk(t, dir)
db, err := New(dir + "/db")
if err != nil {
t.Fatal(err)
}
tr := trace.New("test", "direrrors")
defer tr.Finish()
// We want to cause store.ListIDs to return an error. To do so, we will
// cause Readdir to fail by removing the underlying db directory.
err = os.Remove(dir + "/db")
if err != nil {
t.Fatal(err)
}
err = db.Reload()
if !errors.Is(err, os.ErrNotExist) {
t.Errorf("got %v, expected %v", err, os.ErrNotExist)
}
// We expect write() to also fail to store data in this scenario.
d := Domain{Name: "d1"}
err = db.write(tr, &d)
if !errors.Is(err, os.ErrNotExist) {
t.Errorf("got %v, expected %v", err, os.ErrNotExist)
}
} }