1
0
mirror of https://blitiri.com.ar/repos/chasquid synced 2025-12-18 14:47:03 +00:00

domaininfo: Reload periodically

This patch makes chasquid reload domaininfo periodically, so it notices
any external changes made to it.

It is in line with what we do for aliases and authentication already,
and makes it possible for external removals an additions to the
domaininfo database to be picked up without a restart.
This commit is contained in:
Alberto Bertogli
2018-05-20 12:21:15 +01:00
parent 2064e9e65d
commit a177fec7c3
5 changed files with 83 additions and 25 deletions

View File

@@ -38,11 +38,22 @@ func New(dir string) (*DB, error) {
} }
l.ev = trace.NewEventLog("DomainInfo", dir) l.ev = trace.NewEventLog("DomainInfo", dir)
err = l.Reload()
if err != nil {
return nil, err
}
return l, nil return l, nil
} }
// Load the database from disk; should be called once at initialization. // Reload the database from disk.
func (db *DB) Load() error { func (db *DB) Reload() error {
db.Lock()
defer db.Unlock()
// Clear the map, in case it has data.
db.info = map[string]*Domain{}
ids, err := db.store.ListIDs() ids, err := db.store.ListIDs()
if err != nil { if err != nil {
return err return err

View File

@@ -2,7 +2,6 @@ package domaininfo
import ( import (
"testing" "testing"
"time"
"blitiri.com.ar/go/chasquid/internal/testlib" "blitiri.com.ar/go/chasquid/internal/testlib"
) )
@@ -15,10 +14,6 @@ func TestBasic(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if err := db.Load(); err != nil {
t.Fatal(err)
}
if !db.IncomingSecLevel("d1", SecLevel_PLAIN) { if !db.IncomingSecLevel("d1", SecLevel_PLAIN) {
t.Errorf("new domain as plain not allowed") t.Errorf("new domain as plain not allowed")
} }
@@ -29,24 +24,11 @@ func TestBasic(t *testing.T) {
t.Errorf("decrement to tls-insecure was allowed") t.Errorf("decrement to tls-insecure was allowed")
} }
// Wait until it is written to disk.
for dl := time.Now().Add(30 * time.Second); time.Now().Before(dl); {
d := &Domain{}
ok, _ := db.store.Get("d1", d)
if ok {
break
}
time.Sleep(50 * time.Millisecond)
}
// Check that it was added to the store and a new db sees it. // Check that it was added to the store and a new db sees it.
db2, err := New(dir) db2, err := New(dir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := db2.Load(); err != nil {
t.Fatal(err)
}
if db2.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) { if db2.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
t.Errorf("decrement to tls-insecure was allowed in new DB") t.Errorf("decrement to tls-insecure was allowed in new DB")
} }
@@ -113,3 +95,30 @@ func TestProgressions(t *testing.T) {
} }
} }
} }
func TestErrors(t *testing.T) {
// Non-existent directory.
_, err := New("/doesnotexists")
if err == nil {
t.Error("could create a DB on a non-existent directory")
}
// Corrupt/invalid file.
dir := testlib.MustTempDir(t)
defer testlib.RemoveIfOk(t, dir)
db, err := New(dir)
if err != nil {
t.Fatal(err)
}
if !db.IncomingSecLevel("d1", SecLevel_TLS_SECURE) {
t.Errorf("increment to tls-secure not allowed")
}
testlib.Rewrite(t, dir+"/s:d1", "invalid-text-protobuf-contents")
err = db.Reload()
if err == nil {
t.Errorf("no error when reloading db with invalid file")
}
}

View File

@@ -140,11 +140,6 @@ func (s *Server) InitDomainInfo(dir string) *domaininfo.DB {
log.Fatalf("Error opening domain info database: %v", err) log.Fatalf("Error opening domain info database: %v", err)
} }
err = s.dinfo.Load()
if err != nil {
log.Fatalf("Error loading domain info database: %v", err)
}
return s.dinfo return s.dinfo
} }
@@ -176,6 +171,11 @@ func (s *Server) periodicallyReload() {
if err != nil { if err != nil {
log.Errorf("Error reloading authenticators: %v", err) log.Errorf("Error reloading authenticators: %v", err)
} }
err = s.dinfo.Reload()
if err != nil {
log.Errorf("Error reloading domaininfo: %v", err)
}
} }
} }

View File

@@ -37,3 +37,17 @@ func RemoveIfOk(t *testing.T, dir string) {
os.RemoveAll(dir) os.RemoveAll(dir)
} }
} }
func Rewrite(t *testing.T, path, contents string) error {
// Safeguard, to make sure we only mess with test files.
if !strings.Contains(path, "testlib_") {
panic("invalid/dangerous path")
}
err := ioutil.WriteFile(path, []byte(contents), 0600)
if err != nil {
t.Errorf("failed to rewrite file: %v", err)
}
return err
}

View File

@@ -52,3 +52,27 @@ func TestLeaveDirOnError(t *testing.T) {
// Remove the directory for real this time. // Remove the directory for real this time.
RemoveIfOk(t, dir) RemoveIfOk(t, dir)
} }
func TestRewriteSafeguard(t *testing.T) {
myt := &testing.T{}
defer func() {
if r := recover(); r != nil {
t.Logf("recovered: %v", r)
} else {
t.Fatalf("check did not panic as expected")
}
}()
Rewrite(myt, "/something", "test")
}
func TestRewrite(t *testing.T) {
dir := MustTempDir(t)
defer RemoveIfOk(t, dir)
myt := &testing.T{}
Rewrite(myt, dir+"/file", "hola")
if myt.Failed() {
t.Errorf("basic rewrite failed")
}
}