From 0a51641a30f8f8c7f439a5e4329447e20a48dabf Mon Sep 17 00:00:00 2001 From: James Hillyerd Date: Sat, 17 Feb 2024 19:25:01 -0800 Subject: [PATCH] feature: Context support for REST client (#496) * rest/client: add WithContext methods Signed-off-by: James Hillyerd * cmd/client: pass context Signed-off-by: James Hillyerd --------- Signed-off-by: James Hillyerd --- cmd/client/list.go | 6 +- cmd/client/match.go | 24 +++++--- cmd/client/mbox.go | 22 +++++--- pkg/rest/client/apiv1_client.go | 99 ++++++++++++++++++++++++++++----- pkg/rest/client/rest.go | 11 ++-- pkg/rest/client/rest_test.go | 9 ++- 6 files changed, 130 insertions(+), 41 deletions(-) diff --git a/cmd/client/list.go b/cmd/client/list.go index b7502a9..717e193 100644 --- a/cmd/client/list.go +++ b/cmd/client/list.go @@ -28,18 +28,20 @@ func (*listCmd) Usage() string { func (l *listCmd) SetFlags(f *flag.FlagSet) {} func (l *listCmd) Execute( - _ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { mailbox := f.Arg(0) if mailbox == "" { return usage("mailbox required") } + // Setup rest client c, err := client.New(baseURL()) if err != nil { return fatal("Couldn't build client", err) } + // Get list - headers, err := c.ListMailbox(mailbox) + headers, err := c.ListMailboxWithContext(ctx, mailbox) if err != nil { return fatal("REST call failed", err) } diff --git a/cmd/client/match.go b/cmd/client/match.go index 2f8a6ca..f447d0c 100644 --- a/cmd/client/match.go +++ b/cmd/client/match.go @@ -15,7 +15,7 @@ import ( type matchCmd struct { output string - outFunc func(headers []*client.MessageHeader) error + outFunc func(ctx context.Context, headers []*client.MessageHeader) error delete bool // match criteria from regexFlag @@ -51,11 +51,12 @@ func (m *matchCmd) SetFlags(f *flag.FlagSet) { } func (m *matchCmd) Execute( - _ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { mailbox := f.Arg(0) if mailbox == "" { return usage("mailbox required") } + // Select output function switch m.output { case "id": @@ -67,16 +68,19 @@ func (m *matchCmd) Execute( default: return usage("unknown output type: " + m.output) } + // Setup REST client c, err := client.New(baseURL()) if err != nil { return fatal("Couldn't build client", err) } + // Get list - headers, err := c.ListMailbox(mailbox) + headers, err := c.ListMailboxWithContext(ctx, mailbox) if err != nil { return fatal("List REST call failed", err) } + // Find matches matches := make([]*client.MessageHeader, 0, len(headers)) for _, h := range headers { @@ -84,24 +88,28 @@ func (m *matchCmd) Execute( matches = append(matches, h) } } + // Return error status if no matches if len(matches) == 0 { return subcommands.ExitFailure } + // Output matches - err = m.outFunc(matches) + err = m.outFunc(ctx, matches) if err != nil { return fatal("Error", err) } + + // Optionally, delete matches if m.delete { - // Delete matches for _, h := range matches { - err = h.Delete() + err = h.DeleteWithContext(ctx) if err != nil { return fatal("Delete REST call failed", err) } } } + return subcommands.ExitSuccess } @@ -148,14 +156,14 @@ func (m *matchCmd) match(header *client.MessageHeader) bool { return true } -func outputID(headers []*client.MessageHeader) error { +func outputID(_ context.Context, headers []*client.MessageHeader) error { for _, h := range headers { fmt.Println(h.ID) } return nil } -func outputJSON(headers []*client.MessageHeader) error { +func outputJSON(_ context.Context, headers []*client.MessageHeader) error { jsonEncoder := json.NewEncoder(os.Stdout) jsonEncoder.SetEscapeHTML(false) jsonEncoder.SetIndent("", " ") diff --git a/cmd/client/mbox.go b/cmd/client/mbox.go index c8cc849..6e8db2f 100644 --- a/cmd/client/mbox.go +++ b/cmd/client/mbox.go @@ -33,42 +33,46 @@ func (m *mboxCmd) SetFlags(f *flag.FlagSet) { } func (m *mboxCmd) Execute( - _ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { mailbox := f.Arg(0) if mailbox == "" { return usage("mailbox required") } + // Setup REST client c, err := client.New(baseURL()) if err != nil { return fatal("Couldn't build client", err) } + // Get list - headers, err := c.ListMailbox(mailbox) + headers, err := c.ListMailboxWithContext(ctx, mailbox) if err != nil { return fatal("List REST call failed", err) } - err = outputMbox(headers) + err = outputMbox(ctx, headers) if err != nil { return fatal("Error", err) } + + // Optionally, delete retrieved messages if m.delete { - // Delete matches for _, h := range headers { - err = h.Delete() + err = h.DeleteWithContext(ctx) if err != nil { return fatal("Delete REST call failed", err) } } } + return subcommands.ExitSuccess } -// outputMbox renders messages in mbox format -// also used by match subcommand -func outputMbox(headers []*client.MessageHeader) error { +// outputMbox renders messages in mbox format. +// It is also used by match subcommand. +func outputMbox(ctx context.Context, headers []*client.MessageHeader) error { for _, h := range headers { - source, err := h.GetSource() + source, err := h.GetSourceWithContext(ctx) if err != nil { return fmt.Errorf("get source REST failed: %v", err) } diff --git a/pkg/rest/client/apiv1_client.go b/pkg/rest/client/apiv1_client.go index 6889d34..4650cf8 100644 --- a/pkg/rest/client/apiv1_client.go +++ b/pkg/rest/client/apiv1_client.go @@ -3,6 +3,7 @@ package client import ( "bytes" + "context" "fmt" "net/http" "net/url" @@ -41,33 +42,56 @@ func New(baseURL string, opts ...Option) (*Client, error) { } // ListMailbox returns a list of messages for the requested mailbox -func (c *Client) ListMailbox(name string) (headers []*MessageHeader, err error) { +func (c *Client) ListMailbox(name string) ([]*MessageHeader, error) { + return c.ListMailboxWithContext(context.Background(), name) +} + +// ListMailboxWithContext returns a list of messages for the requested mailbox +func (c *Client) ListMailboxWithContext(ctx context.Context, name string) ([]*MessageHeader, error) { uri := "/api/v1/mailbox/" + url.QueryEscape(name) - err = c.doJSON("GET", uri, &headers) + headers := make([]*MessageHeader, 0, 32) + + err := c.doJSON(ctx, "GET", uri, &headers) if err != nil { return nil, err } + + // Add Client ref to each MessageHeader for convenience funcs. for _, h := range headers { h.client = c } - return + + return headers, nil } // GetMessage returns the message details given a mailbox name and message ID. func (c *Client) GetMessage(name, id string) (message *Message, err error) { + return c.GetMessageWithContext(context.Background(), name, id) +} + +// GetMessageWithContext returns the message details given a mailbox name and message ID. +func (c *Client) GetMessageWithContext(ctx context.Context, name, id string) (*Message, error) { uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id - err = c.doJSON("GET", uri, &message) + var message Message + + err := c.doJSON(ctx, "GET", uri, &message) if err != nil { return nil, err } + message.client = c - return + return &message, nil } // MarkSeen marks the specified message as having been read. func (c *Client) MarkSeen(name, id string) error { + return c.MarkSeenWithContext(context.Background(), name, id) +} + +// MarkSeenWithContext marks the specified message as having been read. +func (c *Client) MarkSeenWithContext(ctx context.Context, name, id string) error { uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id - err := c.doJSON("PATCH", uri, nil) + err := c.doJSON(ctx, "PATCH", uri, nil) if err != nil { return err } @@ -76,19 +100,25 @@ func (c *Client) MarkSeen(name, id string) error { // GetMessageSource returns the message source given a mailbox name and message ID. func (c *Client) GetMessageSource(name, id string) (*bytes.Buffer, error) { + return c.GetMessageSourceWithContext(context.Background(), name, id) +} + +// GetMessageSourceWithContext returns the message source given a mailbox name and message ID. +func (c *Client) GetMessageSourceWithContext(ctx context.Context, name, id string) (*bytes.Buffer, error) { uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id + "/source" - resp, err := c.do("GET", uri, nil) + resp, err := c.do(ctx, "GET", uri, nil) if err != nil { return nil, err } - defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status) } + buf := new(bytes.Buffer) _, err = buf.ReadFrom(resp.Body) return buf, err @@ -96,29 +126,43 @@ func (c *Client) GetMessageSource(name, id string) (*bytes.Buffer, error) { // DeleteMessage deletes a single message given the mailbox name and message ID. func (c *Client) DeleteMessage(name, id string) error { + return c.DeleteMessageWithContext(context.Background(), name, id) +} + +// DeleteMessageWithContext deletes a single message given the mailbox name and message ID. +func (c *Client) DeleteMessageWithContext(ctx context.Context, name, id string) error { uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id - resp, err := c.do("DELETE", uri, nil) + resp, err := c.do(ctx, "DELETE", uri, nil) if err != nil { return err } _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status) } + return nil } // PurgeMailbox deletes all messages in the given mailbox func (c *Client) PurgeMailbox(name string) error { + return c.PurgeMailboxWithContext(context.Background(), name) +} + +// PurgeMailboxWithContext deletes all messages in the given mailbox +func (c *Client) PurgeMailboxWithContext(ctx context.Context, name string) error { uri := "/api/v1/mailbox/" + url.QueryEscape(name) - resp, err := c.do("DELETE", uri, nil) + resp, err := c.do(ctx, "DELETE", uri, nil) if err != nil { return err } _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status) } + return nil } @@ -130,17 +174,32 @@ type MessageHeader struct { // GetMessage returns this message with content func (h *MessageHeader) GetMessage() (message *Message, err error) { - return h.client.GetMessage(h.Mailbox, h.ID) + return h.GetMessageWithContext(context.Background()) +} + +// GetMessageWithContext returns this message with content +func (h *MessageHeader) GetMessageWithContext(ctx context.Context) (message *Message, err error) { + return h.client.GetMessageWithContext(ctx, h.Mailbox, h.ID) } // GetSource returns the source for this message func (h *MessageHeader) GetSource() (*bytes.Buffer, error) { - return h.client.GetMessageSource(h.Mailbox, h.ID) + return h.GetSourceWithContext(context.Background()) +} + +// GetSourceWithContext returns the source for this message +func (h *MessageHeader) GetSourceWithContext(ctx context.Context) (*bytes.Buffer, error) { + return h.client.GetMessageSourceWithContext(ctx, h.Mailbox, h.ID) } // Delete deletes this message from the mailbox func (h *MessageHeader) Delete() error { - return h.client.DeleteMessage(h.Mailbox, h.ID) + return h.DeleteWithContext(context.Background()) +} + +// DeleteWithContext deletes this message from the mailbox +func (h *MessageHeader) DeleteWithContext(ctx context.Context) error { + return h.client.DeleteMessageWithContext(ctx, h.Mailbox, h.ID) } // Message represents an Inbucket message including content @@ -151,10 +210,20 @@ type Message struct { // GetSource returns the source for this message func (m *Message) GetSource() (*bytes.Buffer, error) { - return m.client.GetMessageSource(m.Mailbox, m.ID) + return m.GetSourceWithContext(context.Background()) +} + +// GetSourceWithContext returns the source for this message +func (m *Message) GetSourceWithContext(ctx context.Context) (*bytes.Buffer, error) { + return m.client.GetMessageSourceWithContext(ctx, m.Mailbox, m.ID) } // Delete deletes this message from the mailbox func (m *Message) Delete() error { - return m.client.DeleteMessage(m.Mailbox, m.ID) + return m.DeleteWithContext(context.Background()) +} + +// DeleteWithContext deletes this message from the mailbox +func (m *Message) DeleteWithContext(ctx context.Context) error { + return m.client.DeleteMessageWithContext(ctx, m.Mailbox, m.ID) } diff --git a/pkg/rest/client/rest.go b/pkg/rest/client/rest.go index 6fd1402..8de01fa 100644 --- a/pkg/rest/client/rest.go +++ b/pkg/rest/client/rest.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -21,22 +22,24 @@ type restClient struct { } // do performs an HTTP request with this client and returns the response. -func (c *restClient) do(method, uri string, body []byte) (*http.Response, error) { +func (c *restClient) do(ctx context.Context, method, uri string, body []byte) (*http.Response, error) { url := c.baseURL.JoinPath(uri) var r io.Reader if body != nil { r = bytes.NewReader(body) } - req, err := http.NewRequest(method, url.String(), r) + + req, err := http.NewRequestWithContext(ctx, method, url.String(), r) if err != nil { return nil, fmt.Errorf("%s for %q: %v", method, url, err) } + return c.client.Do(req) } // doJSON performs an HTTP request with this client and marshalls the JSON response into v. -func (c *restClient) doJSON(method string, uri string, v interface{}) error { - resp, err := c.do(method, uri, nil) +func (c *restClient) doJSON(ctx context.Context, method string, uri string, v interface{}) error { + resp, err := c.do(ctx, method, uri, nil) if err != nil { return err } diff --git a/pkg/rest/client/rest_test.go b/pkg/rest/client/rest_test.go index fc58161..efa1410 100644 --- a/pkg/rest/client/rest_test.go +++ b/pkg/rest/client/rest_test.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "context" "fmt" "io" "net/http" @@ -78,9 +79,11 @@ func TestDoTable(t *testing.T) { for _, test := range tests { testname := fmt.Sprintf("%s,%s", test.method, test.wantURL) t.Run(testname, func(t *testing.T) { + ctx := context.Background() mth := &mockHTTPClient{} c := &restClient{mth, test.base} - resp, err := c.do(test.method, test.uri, test.wantBody) + + resp, err := c.do(ctx, test.method, test.uri, test.wantBody) require.NoError(t, err) err = resp.Body.Close() require.NoError(t, err) @@ -107,7 +110,7 @@ func TestDoJSON(t *testing.T) { c := &restClient{mth, baseURL} var v map[string]interface{} - err := c.doJSON("GET", "/doget", &v) + err := c.doJSON(context.Background(), "GET", "/doget", &v) if err != nil { t.Fatal(err) } @@ -141,7 +144,7 @@ func TestDoJSONNilV(t *testing.T) { mth := &mockHTTPClient{} c := &restClient{mth, baseURL} - err := c.doJSON("GET", "/doget", nil) + err := c.doJSON(context.Background(), "GET", "/doget", nil) if err != nil { t.Fatal(err) }