1
0
mirror of https://github.com/jhillyerd/inbucket.git synced 2025-12-18 01:57:02 +00:00

feature: Context support for REST client (#496)

* rest/client: add WithContext methods

Signed-off-by: James Hillyerd <james@hillyerd.com>

* cmd/client: pass context

Signed-off-by: James Hillyerd <james@hillyerd.com>

---------

Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
James Hillyerd
2024-02-17 19:25:01 -08:00
committed by GitHub
parent 73203c6bcd
commit 0a51641a30
6 changed files with 130 additions and 41 deletions

View File

@@ -28,18 +28,20 @@ func (*listCmd) Usage() string {
func (l *listCmd) SetFlags(f *flag.FlagSet) {} func (l *listCmd) SetFlags(f *flag.FlagSet) {}
func (l *listCmd) Execute( func (l *listCmd) Execute(
_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
mailbox := f.Arg(0) mailbox := f.Arg(0)
if mailbox == "" { if mailbox == "" {
return usage("mailbox required") return usage("mailbox required")
} }
// Setup rest client // Setup rest client
c, err := client.New(baseURL()) c, err := client.New(baseURL())
if err != nil { if err != nil {
return fatal("Couldn't build client", err) return fatal("Couldn't build client", err)
} }
// Get list // Get list
headers, err := c.ListMailbox(mailbox) headers, err := c.ListMailboxWithContext(ctx, mailbox)
if err != nil { if err != nil {
return fatal("REST call failed", err) return fatal("REST call failed", err)
} }

View File

@@ -15,7 +15,7 @@ import (
type matchCmd struct { type matchCmd struct {
output string output string
outFunc func(headers []*client.MessageHeader) error outFunc func(ctx context.Context, headers []*client.MessageHeader) error
delete bool delete bool
// match criteria // match criteria
from regexFlag from regexFlag
@@ -51,11 +51,12 @@ func (m *matchCmd) SetFlags(f *flag.FlagSet) {
} }
func (m *matchCmd) Execute( func (m *matchCmd) Execute(
_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
mailbox := f.Arg(0) mailbox := f.Arg(0)
if mailbox == "" { if mailbox == "" {
return usage("mailbox required") return usage("mailbox required")
} }
// Select output function // Select output function
switch m.output { switch m.output {
case "id": case "id":
@@ -67,16 +68,19 @@ func (m *matchCmd) Execute(
default: default:
return usage("unknown output type: " + m.output) return usage("unknown output type: " + m.output)
} }
// Setup REST client // Setup REST client
c, err := client.New(baseURL()) c, err := client.New(baseURL())
if err != nil { if err != nil {
return fatal("Couldn't build client", err) return fatal("Couldn't build client", err)
} }
// Get list // Get list
headers, err := c.ListMailbox(mailbox) headers, err := c.ListMailboxWithContext(ctx, mailbox)
if err != nil { if err != nil {
return fatal("List REST call failed", err) return fatal("List REST call failed", err)
} }
// Find matches // Find matches
matches := make([]*client.MessageHeader, 0, len(headers)) matches := make([]*client.MessageHeader, 0, len(headers))
for _, h := range headers { for _, h := range headers {
@@ -84,24 +88,28 @@ func (m *matchCmd) Execute(
matches = append(matches, h) matches = append(matches, h)
} }
} }
// Return error status if no matches // Return error status if no matches
if len(matches) == 0 { if len(matches) == 0 {
return subcommands.ExitFailure return subcommands.ExitFailure
} }
// Output matches // Output matches
err = m.outFunc(matches) err = m.outFunc(ctx, matches)
if err != nil { if err != nil {
return fatal("Error", err) return fatal("Error", err)
} }
// Optionally, delete matches
if m.delete { if m.delete {
// Delete matches
for _, h := range matches { for _, h := range matches {
err = h.Delete() err = h.DeleteWithContext(ctx)
if err != nil { if err != nil {
return fatal("Delete REST call failed", err) return fatal("Delete REST call failed", err)
} }
} }
} }
return subcommands.ExitSuccess return subcommands.ExitSuccess
} }
@@ -148,14 +156,14 @@ func (m *matchCmd) match(header *client.MessageHeader) bool {
return true return true
} }
func outputID(headers []*client.MessageHeader) error { func outputID(_ context.Context, headers []*client.MessageHeader) error {
for _, h := range headers { for _, h := range headers {
fmt.Println(h.ID) fmt.Println(h.ID)
} }
return nil return nil
} }
func outputJSON(headers []*client.MessageHeader) error { func outputJSON(_ context.Context, headers []*client.MessageHeader) error {
jsonEncoder := json.NewEncoder(os.Stdout) jsonEncoder := json.NewEncoder(os.Stdout)
jsonEncoder.SetEscapeHTML(false) jsonEncoder.SetEscapeHTML(false)
jsonEncoder.SetIndent("", " ") jsonEncoder.SetIndent("", " ")

View File

@@ -33,42 +33,46 @@ func (m *mboxCmd) SetFlags(f *flag.FlagSet) {
} }
func (m *mboxCmd) Execute( func (m *mboxCmd) Execute(
_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
mailbox := f.Arg(0) mailbox := f.Arg(0)
if mailbox == "" { if mailbox == "" {
return usage("mailbox required") return usage("mailbox required")
} }
// Setup REST client // Setup REST client
c, err := client.New(baseURL()) c, err := client.New(baseURL())
if err != nil { if err != nil {
return fatal("Couldn't build client", err) return fatal("Couldn't build client", err)
} }
// Get list // Get list
headers, err := c.ListMailbox(mailbox) headers, err := c.ListMailboxWithContext(ctx, mailbox)
if err != nil { if err != nil {
return fatal("List REST call failed", err) return fatal("List REST call failed", err)
} }
err = outputMbox(headers) err = outputMbox(ctx, headers)
if err != nil { if err != nil {
return fatal("Error", err) return fatal("Error", err)
} }
// Optionally, delete retrieved messages
if m.delete { if m.delete {
// Delete matches
for _, h := range headers { for _, h := range headers {
err = h.Delete() err = h.DeleteWithContext(ctx)
if err != nil { if err != nil {
return fatal("Delete REST call failed", err) return fatal("Delete REST call failed", err)
} }
} }
} }
return subcommands.ExitSuccess return subcommands.ExitSuccess
} }
// outputMbox renders messages in mbox format // outputMbox renders messages in mbox format.
// also used by match subcommand // It is also used by match subcommand.
func outputMbox(headers []*client.MessageHeader) error { func outputMbox(ctx context.Context, headers []*client.MessageHeader) error {
for _, h := range headers { for _, h := range headers {
source, err := h.GetSource() source, err := h.GetSourceWithContext(ctx)
if err != nil { if err != nil {
return fmt.Errorf("get source REST failed: %v", err) return fmt.Errorf("get source REST failed: %v", err)
} }

View File

@@ -3,6 +3,7 @@ package client
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@@ -41,33 +42,56 @@ func New(baseURL string, opts ...Option) (*Client, error) {
} }
// ListMailbox returns a list of messages for the requested mailbox // ListMailbox returns a list of messages for the requested mailbox
func (c *Client) ListMailbox(name string) (headers []*MessageHeader, err error) { func (c *Client) ListMailbox(name string) ([]*MessageHeader, error) {
return c.ListMailboxWithContext(context.Background(), name)
}
// ListMailboxWithContext returns a list of messages for the requested mailbox
func (c *Client) ListMailboxWithContext(ctx context.Context, name string) ([]*MessageHeader, error) {
uri := "/api/v1/mailbox/" + url.QueryEscape(name) uri := "/api/v1/mailbox/" + url.QueryEscape(name)
err = c.doJSON("GET", uri, &headers) headers := make([]*MessageHeader, 0, 32)
err := c.doJSON(ctx, "GET", uri, &headers)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Add Client ref to each MessageHeader for convenience funcs.
for _, h := range headers { for _, h := range headers {
h.client = c h.client = c
} }
return
return headers, nil
} }
// GetMessage returns the message details given a mailbox name and message ID. // GetMessage returns the message details given a mailbox name and message ID.
func (c *Client) GetMessage(name, id string) (message *Message, err error) { func (c *Client) GetMessage(name, id string) (message *Message, err error) {
return c.GetMessageWithContext(context.Background(), name, id)
}
// GetMessageWithContext returns the message details given a mailbox name and message ID.
func (c *Client) GetMessageWithContext(ctx context.Context, name, id string) (*Message, error) {
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
err = c.doJSON("GET", uri, &message) var message Message
err := c.doJSON(ctx, "GET", uri, &message)
if err != nil { if err != nil {
return nil, err return nil, err
} }
message.client = c message.client = c
return return &message, nil
} }
// MarkSeen marks the specified message as having been read. // MarkSeen marks the specified message as having been read.
func (c *Client) MarkSeen(name, id string) error { func (c *Client) MarkSeen(name, id string) error {
return c.MarkSeenWithContext(context.Background(), name, id)
}
// MarkSeenWithContext marks the specified message as having been read.
func (c *Client) MarkSeenWithContext(ctx context.Context, name, id string) error {
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
err := c.doJSON("PATCH", uri, nil) err := c.doJSON(ctx, "PATCH", uri, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -76,19 +100,25 @@ func (c *Client) MarkSeen(name, id string) error {
// GetMessageSource returns the message source given a mailbox name and message ID. // GetMessageSource returns the message source given a mailbox name and message ID.
func (c *Client) GetMessageSource(name, id string) (*bytes.Buffer, error) { func (c *Client) GetMessageSource(name, id string) (*bytes.Buffer, error) {
return c.GetMessageSourceWithContext(context.Background(), name, id)
}
// GetMessageSourceWithContext returns the message source given a mailbox name and message ID.
func (c *Client) GetMessageSourceWithContext(ctx context.Context, name, id string) (*bytes.Buffer, error) {
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id + "/source" uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id + "/source"
resp, err := c.do("GET", uri, nil) resp, err := c.do(ctx, "GET", uri, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = resp.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, return nil,
fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status) fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
_, err = buf.ReadFrom(resp.Body) _, err = buf.ReadFrom(resp.Body)
return buf, err return buf, err
@@ -96,29 +126,43 @@ func (c *Client) GetMessageSource(name, id string) (*bytes.Buffer, error) {
// DeleteMessage deletes a single message given the mailbox name and message ID. // DeleteMessage deletes a single message given the mailbox name and message ID.
func (c *Client) DeleteMessage(name, id string) error { func (c *Client) DeleteMessage(name, id string) error {
return c.DeleteMessageWithContext(context.Background(), name, id)
}
// DeleteMessageWithContext deletes a single message given the mailbox name and message ID.
func (c *Client) DeleteMessageWithContext(ctx context.Context, name, id string) error {
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
resp, err := c.do("DELETE", uri, nil) resp, err := c.do(ctx, "DELETE", uri, nil)
if err != nil { if err != nil {
return err return err
} }
_ = resp.Body.Close() _ = resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status) return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
} }
return nil return nil
} }
// PurgeMailbox deletes all messages in the given mailbox // PurgeMailbox deletes all messages in the given mailbox
func (c *Client) PurgeMailbox(name string) error { func (c *Client) PurgeMailbox(name string) error {
return c.PurgeMailboxWithContext(context.Background(), name)
}
// PurgeMailboxWithContext deletes all messages in the given mailbox
func (c *Client) PurgeMailboxWithContext(ctx context.Context, name string) error {
uri := "/api/v1/mailbox/" + url.QueryEscape(name) uri := "/api/v1/mailbox/" + url.QueryEscape(name)
resp, err := c.do("DELETE", uri, nil) resp, err := c.do(ctx, "DELETE", uri, nil)
if err != nil { if err != nil {
return err return err
} }
_ = resp.Body.Close() _ = resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status) return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
} }
return nil return nil
} }
@@ -130,17 +174,32 @@ type MessageHeader struct {
// GetMessage returns this message with content // GetMessage returns this message with content
func (h *MessageHeader) GetMessage() (message *Message, err error) { func (h *MessageHeader) GetMessage() (message *Message, err error) {
return h.client.GetMessage(h.Mailbox, h.ID) return h.GetMessageWithContext(context.Background())
}
// GetMessageWithContext returns this message with content
func (h *MessageHeader) GetMessageWithContext(ctx context.Context) (message *Message, err error) {
return h.client.GetMessageWithContext(ctx, h.Mailbox, h.ID)
} }
// GetSource returns the source for this message // GetSource returns the source for this message
func (h *MessageHeader) GetSource() (*bytes.Buffer, error) { func (h *MessageHeader) GetSource() (*bytes.Buffer, error) {
return h.client.GetMessageSource(h.Mailbox, h.ID) return h.GetSourceWithContext(context.Background())
}
// GetSourceWithContext returns the source for this message
func (h *MessageHeader) GetSourceWithContext(ctx context.Context) (*bytes.Buffer, error) {
return h.client.GetMessageSourceWithContext(ctx, h.Mailbox, h.ID)
} }
// Delete deletes this message from the mailbox // Delete deletes this message from the mailbox
func (h *MessageHeader) Delete() error { func (h *MessageHeader) Delete() error {
return h.client.DeleteMessage(h.Mailbox, h.ID) return h.DeleteWithContext(context.Background())
}
// DeleteWithContext deletes this message from the mailbox
func (h *MessageHeader) DeleteWithContext(ctx context.Context) error {
return h.client.DeleteMessageWithContext(ctx, h.Mailbox, h.ID)
} }
// Message represents an Inbucket message including content // Message represents an Inbucket message including content
@@ -151,10 +210,20 @@ type Message struct {
// GetSource returns the source for this message // GetSource returns the source for this message
func (m *Message) GetSource() (*bytes.Buffer, error) { func (m *Message) GetSource() (*bytes.Buffer, error) {
return m.client.GetMessageSource(m.Mailbox, m.ID) return m.GetSourceWithContext(context.Background())
}
// GetSourceWithContext returns the source for this message
func (m *Message) GetSourceWithContext(ctx context.Context) (*bytes.Buffer, error) {
return m.client.GetMessageSourceWithContext(ctx, m.Mailbox, m.ID)
} }
// Delete deletes this message from the mailbox // Delete deletes this message from the mailbox
func (m *Message) Delete() error { func (m *Message) Delete() error {
return m.client.DeleteMessage(m.Mailbox, m.ID) return m.DeleteWithContext(context.Background())
}
// DeleteWithContext deletes this message from the mailbox
func (m *Message) DeleteWithContext(ctx context.Context) error {
return m.client.DeleteMessageWithContext(ctx, m.Mailbox, m.ID)
} }

View File

@@ -2,6 +2,7 @@ package client
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -21,22 +22,24 @@ type restClient struct {
} }
// do performs an HTTP request with this client and returns the response. // do performs an HTTP request with this client and returns the response.
func (c *restClient) do(method, uri string, body []byte) (*http.Response, error) { func (c *restClient) do(ctx context.Context, method, uri string, body []byte) (*http.Response, error) {
url := c.baseURL.JoinPath(uri) url := c.baseURL.JoinPath(uri)
var r io.Reader var r io.Reader
if body != nil { if body != nil {
r = bytes.NewReader(body) r = bytes.NewReader(body)
} }
req, err := http.NewRequest(method, url.String(), r)
req, err := http.NewRequestWithContext(ctx, method, url.String(), r)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s for %q: %v", method, url, err) return nil, fmt.Errorf("%s for %q: %v", method, url, err)
} }
return c.client.Do(req) return c.client.Do(req)
} }
// doJSON performs an HTTP request with this client and marshalls the JSON response into v. // doJSON performs an HTTP request with this client and marshalls the JSON response into v.
func (c *restClient) doJSON(method string, uri string, v interface{}) error { func (c *restClient) doJSON(ctx context.Context, method string, uri string, v interface{}) error {
resp, err := c.do(method, uri, nil) resp, err := c.do(ctx, method, uri, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -2,6 +2,7 @@ package client
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -78,9 +79,11 @@ func TestDoTable(t *testing.T) {
for _, test := range tests { for _, test := range tests {
testname := fmt.Sprintf("%s,%s", test.method, test.wantURL) testname := fmt.Sprintf("%s,%s", test.method, test.wantURL)
t.Run(testname, func(t *testing.T) { t.Run(testname, func(t *testing.T) {
ctx := context.Background()
mth := &mockHTTPClient{} mth := &mockHTTPClient{}
c := &restClient{mth, test.base} c := &restClient{mth, test.base}
resp, err := c.do(test.method, test.uri, test.wantBody)
resp, err := c.do(ctx, test.method, test.uri, test.wantBody)
require.NoError(t, err) require.NoError(t, err)
err = resp.Body.Close() err = resp.Body.Close()
require.NoError(t, err) require.NoError(t, err)
@@ -107,7 +110,7 @@ func TestDoJSON(t *testing.T) {
c := &restClient{mth, baseURL} c := &restClient{mth, baseURL}
var v map[string]interface{} var v map[string]interface{}
err := c.doJSON("GET", "/doget", &v) err := c.doJSON(context.Background(), "GET", "/doget", &v)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -141,7 +144,7 @@ func TestDoJSONNilV(t *testing.T) {
mth := &mockHTTPClient{} mth := &mockHTTPClient{}
c := &restClient{mth, baseURL} c := &restClient{mth, baseURL}
err := c.doJSON("GET", "/doget", nil) err := c.doJSON(context.Background(), "GET", "/doget", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }