mirror of
https://blitiri.com.ar/repos/chasquid
synced 2025-12-18 14:47:03 +00:00
localrpc: Add a package for local RPC over UNIX sockets
This patch adds a new package for doing local lightweight RPC calls over UNIX sockets. This will be used in later patches for communication between chasquid and chasquid-util.
This commit is contained in:
76
internal/localrpc/client_test.go
Normal file
76
internal/localrpc/client_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package localrpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"io/fs"
|
||||||
|
"net"
|
||||||
|
"net/textproto"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewFakeServer(t *testing.T, path, output string) {
|
||||||
|
t.Helper()
|
||||||
|
lis, err := net.Listen("unix", path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Logf("FakeServer %v: accepted ", conn)
|
||||||
|
|
||||||
|
name, inS, err := readRequest(
|
||||||
|
textproto.NewReader(bufio.NewReader(conn)))
|
||||||
|
t.Logf("FakeServer %v: readRequest: %q %q / %v", conn, name, inS, err)
|
||||||
|
|
||||||
|
n, err := conn.Write([]byte(output))
|
||||||
|
t.Logf("FakeServer %v: writeMessage(%q): %d %v",
|
||||||
|
conn, output, n, err)
|
||||||
|
|
||||||
|
t.Logf("FakeServer %v: closing", conn)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBadServer(t *testing.T) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "rpc-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
socketPath := filepath.Join(tmpDir, "rpc.sock")
|
||||||
|
|
||||||
|
// textproto client expects a numeric code, this should cause ReadCodeLine
|
||||||
|
// to fail with textproto.ProtocolError.
|
||||||
|
go NewFakeServer(t, socketPath, "xxx")
|
||||||
|
waitForServer(t, socketPath)
|
||||||
|
|
||||||
|
client := NewClient(socketPath)
|
||||||
|
_, err = client.Call("Echo")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
var protoErr textproto.ProtocolError
|
||||||
|
if !errors.As(err, &protoErr) {
|
||||||
|
t.Errorf("wanted textproto.ProtocolError, got: %v (%T)", err, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBadSocket(t *testing.T) {
|
||||||
|
c := NewClient("/does/not/exist")
|
||||||
|
_, err := c.Call("Echo")
|
||||||
|
|
||||||
|
opErr, ok := err.(*net.OpError)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected net.OpError, got %q (%T)", err, err)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
t.Errorf("wanted ErrNotExist, got: %q (%T)", opErr.Err, opErr.Err)
|
||||||
|
}
|
||||||
|
}
|
||||||
139
internal/localrpc/e2e_test.go
Normal file
139
internal/localrpc/e2e_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package localrpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Echo(tr *trace.Trace, input url.Values) (url.Values, error) {
|
||||||
|
return input, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Hola(tr *trace.Trace, input url.Values) (url.Values, error) {
|
||||||
|
output := url.Values{}
|
||||||
|
output.Set("greeting", "Hola "+input.Get("name"))
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var testErr = errors.New("test error")
|
||||||
|
|
||||||
|
func HolaErr(tr *trace.Trace, input url.Values) (url.Values, error) {
|
||||||
|
return nil, testErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type testServer struct {
|
||||||
|
dir string
|
||||||
|
sock string
|
||||||
|
*Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestServer(t *testing.T) *testServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "rpc-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tsrv := &testServer{
|
||||||
|
dir: tmpDir,
|
||||||
|
sock: tmpDir + "/sock",
|
||||||
|
Server: NewServer(),
|
||||||
|
}
|
||||||
|
|
||||||
|
tsrv.Register("Echo", Echo)
|
||||||
|
tsrv.Register("Hola", Hola)
|
||||||
|
tsrv.Register("HolaErr", HolaErr)
|
||||||
|
go tsrv.ListenAndServe(tsrv.sock)
|
||||||
|
|
||||||
|
waitForServer(t, tsrv.sock)
|
||||||
|
return tsrv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tsrv *testServer) Cleanup() {
|
||||||
|
tsrv.Close()
|
||||||
|
os.RemoveAll(tsrv.dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mkV(args ...string) url.Values {
|
||||||
|
v := url.Values{}
|
||||||
|
for i := 0; i < len(args); i += 2 {
|
||||||
|
v.Set(args[i], args[i+1])
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEndToEnd(t *testing.T) {
|
||||||
|
srv := newTestServer(t)
|
||||||
|
defer srv.Cleanup()
|
||||||
|
|
||||||
|
// Run the client.
|
||||||
|
client := NewClient(srv.sock)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
method string
|
||||||
|
input url.Values
|
||||||
|
output url.Values
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"Echo", nil, mkV(), nil},
|
||||||
|
{"Echo", mkV("msg", "hola"), mkV("msg", "hola"), nil},
|
||||||
|
{"Hola", mkV("name", "marola"), mkV("greeting", "Hola marola"), nil},
|
||||||
|
{"HolaErr", nil, nil, testErr},
|
||||||
|
{"UnknownMethod", nil, nil, errUnknownMethod},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.method, func(t *testing.T) {
|
||||||
|
resp, err := client.CallWithValues(c.method, c.input)
|
||||||
|
if diff := cmp.Diff(c.err, err, transformErrors); diff != "" {
|
||||||
|
t.Errorf("error mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(c.output, resp); diff != "" {
|
||||||
|
t.Errorf("output mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Call too.
|
||||||
|
output, err := client.Call("Hola", "name", "marola")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(mkV("greeting", "Hola marola"), output); diff != "" {
|
||||||
|
t.Errorf("output mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForServer(t *testing.T, path string) {
|
||||||
|
t.Helper()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
conn, err := net.Dial("unix", path)
|
||||||
|
if conn != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatal("server didn't start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow us to compare errors with cmp.Diff by their string content (since the
|
||||||
|
// instances/types don't carry across RPC boundaries).
|
||||||
|
var transformErrors = cmp.Transformer(
|
||||||
|
"error",
|
||||||
|
func(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return "<nil>"
|
||||||
|
}
|
||||||
|
return err.Error()
|
||||||
|
})
|
||||||
193
internal/localrpc/localrpc.go
Normal file
193
internal/localrpc/localrpc.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
// Local RPC package.
|
||||||
|
//
|
||||||
|
// This is a simple RPC package that uses a line-oriented protocol for
|
||||||
|
// encoding and decoding, and Unix sockets for transport. It is meant to be
|
||||||
|
// used for lightweight occassional communication between processes on the
|
||||||
|
// same machine.
|
||||||
|
package localrpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/textproto"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handler is the type of RPC request handlers.
|
||||||
|
type Handler func(tr *trace.Trace, input url.Values) (url.Values, error)
|
||||||
|
|
||||||
|
//
|
||||||
|
// Server
|
||||||
|
//
|
||||||
|
|
||||||
|
// Server represents the RPC server.
|
||||||
|
type Server struct {
|
||||||
|
handlers map[string]Handler
|
||||||
|
lis net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new local RPC server.
|
||||||
|
func NewServer() *Server {
|
||||||
|
return &Server{
|
||||||
|
handlers: make(map[string]Handler),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errUnknownMethod = errors.New("unknown method")
|
||||||
|
|
||||||
|
// Register a handler for the given name.
|
||||||
|
func (s *Server) Register(name string, handler Handler) {
|
||||||
|
s.handlers[name] = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServe starts the server.
|
||||||
|
func (s *Server) ListenAndServe(path string) error {
|
||||||
|
tr := trace.New("LocalRPC.Server", path)
|
||||||
|
defer tr.Finish()
|
||||||
|
|
||||||
|
// Previous instances of the server may have shut down uncleanly, leaving
|
||||||
|
// behind the socket file. Remove it just in case.
|
||||||
|
os.Remove(path)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.lis, err = net.Listen("unix", path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tr.Printf("Listening")
|
||||||
|
for {
|
||||||
|
conn, err := s.lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
tr.Errorf("Accept error: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go s.handleConn(tr, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the server.
|
||||||
|
func (s *Server) Close() error {
|
||||||
|
return s.lis.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleConn(tr *trace.Trace, conn net.Conn) {
|
||||||
|
tr = tr.NewChild("LocalRPC.Handle", conn.RemoteAddr().String())
|
||||||
|
defer tr.Finish()
|
||||||
|
|
||||||
|
// Set a generous deadline to prevent client issues from tying up a server
|
||||||
|
// goroutine indefinitely.
|
||||||
|
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
|
||||||
|
tconn := textproto.NewConn(conn)
|
||||||
|
defer tconn.Close()
|
||||||
|
|
||||||
|
// Read the request.
|
||||||
|
name, inS, err := readRequest(&tconn.Reader)
|
||||||
|
if err != nil {
|
||||||
|
tr.Debugf("error reading request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tr.Debugf("<- %s %s", name, inS)
|
||||||
|
|
||||||
|
// Find the handler.
|
||||||
|
handler, ok := s.handlers[name]
|
||||||
|
if !ok {
|
||||||
|
writeError(tr, tconn, errUnknownMethod)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal the input.
|
||||||
|
inV, err := url.ParseQuery(inS)
|
||||||
|
if err != nil {
|
||||||
|
writeError(tr, tconn, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the handler.
|
||||||
|
outV, err := handler(tr, inV)
|
||||||
|
if err != nil {
|
||||||
|
writeError(tr, tconn, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the response.
|
||||||
|
outS := outV.Encode()
|
||||||
|
tr.Debugf("-> 200 %s", outS)
|
||||||
|
tconn.PrintfLine("200 %s", outS)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readRequest(r *textproto.Reader) (string, string, error) {
|
||||||
|
line, err := r.ReadLine()
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
sp := strings.SplitN(line, " ", 2)
|
||||||
|
if len(sp) == 1 {
|
||||||
|
return sp[0], "", nil
|
||||||
|
}
|
||||||
|
return sp[0], sp[1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeError(tr *trace.Trace, tconn *textproto.Conn, err error) {
|
||||||
|
tr.Errorf("-> 500 %s", err.Error())
|
||||||
|
tconn.PrintfLine("500 %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default server. This is a singleton server that can be used for
|
||||||
|
// convenience.
|
||||||
|
var DefaultServer = NewServer()
|
||||||
|
|
||||||
|
//
|
||||||
|
// Client
|
||||||
|
//
|
||||||
|
|
||||||
|
// Client for the localrpc server.
|
||||||
|
type Client struct {
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new client for the given path.
|
||||||
|
func NewClient(path string) *Client {
|
||||||
|
return &Client{path: path}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallWithValues calls the given method.
|
||||||
|
func (c *Client) CallWithValues(name string, input url.Values) (url.Values, error) {
|
||||||
|
conn, err := textproto.Dial("unix", c.path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
err = conn.PrintfLine("%s %s", name, input.Encode())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
code, msg, err := conn.ReadCodeLine(0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if code != 200 {
|
||||||
|
return nil, errors.New(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return url.ParseQuery(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the given method. The arguments are key-value strings, and must be
|
||||||
|
// provided in pairs.
|
||||||
|
func (c *Client) Call(name string, args ...string) (url.Values, error) {
|
||||||
|
v := url.Values{}
|
||||||
|
for i := 0; i < len(args); i += 2 {
|
||||||
|
v.Set(args[i], args[i+1])
|
||||||
|
}
|
||||||
|
return c.CallWithValues(name, v)
|
||||||
|
}
|
||||||
66
internal/localrpc/server_test.go
Normal file
66
internal/localrpc/server_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package localrpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"net/textproto"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"blitiri.com.ar/go/chasquid/internal/trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestListenError(t *testing.T) {
|
||||||
|
server := NewServer()
|
||||||
|
err := server.ListenAndServe("/dev/null")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("ListenAndServe(/dev/null) = nil, want error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the server can handle a broken client sending a bad request.
|
||||||
|
func TestServerBadRequest(t *testing.T) {
|
||||||
|
server := NewServer()
|
||||||
|
server.Register("Echo", Echo)
|
||||||
|
|
||||||
|
srvConn, cliConn := net.Pipe()
|
||||||
|
defer srvConn.Close()
|
||||||
|
defer cliConn.Close()
|
||||||
|
|
||||||
|
// Client sends an invalid request.
|
||||||
|
go cliConn.Write([]byte("Echo this is an ; invalid ; query\n"))
|
||||||
|
|
||||||
|
// Servers will handle the connection, and should return an error.
|
||||||
|
tr := trace.New("test", "TestBadRequest")
|
||||||
|
defer tr.Finish()
|
||||||
|
go server.handleConn(tr, srvConn)
|
||||||
|
|
||||||
|
// Read the error that the server should have sent.
|
||||||
|
code, msg, err := textproto.NewConn(cliConn).ReadResponse(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ReadResponse error: %q", err)
|
||||||
|
}
|
||||||
|
if code != 500 {
|
||||||
|
t.Errorf("ReadResponse code %d, expected 500", code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(msg, "invalid semicolon separator") {
|
||||||
|
t.Errorf("ReadResponse message %q, does not contain 'invalid semicolon separator'", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShortReadRequest(t *testing.T) {
|
||||||
|
// This request is too short, it does not have any arguments.
|
||||||
|
// This does not happen with the real client, but just in case.
|
||||||
|
buf := bufio.NewReader(bytes.NewReader([]byte("Method\n")))
|
||||||
|
method, args, err := readRequest(textproto.NewReader(buf))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("readRequest error: %v", err)
|
||||||
|
}
|
||||||
|
if method != "Method" {
|
||||||
|
t.Errorf("readRequest method %q, expected 'Method'", method)
|
||||||
|
}
|
||||||
|
if args != "" {
|
||||||
|
t.Errorf("readRequest args %q, expected ''", args)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user