mirror of
https://github.com/jhillyerd/inbucket.git
synced 2025-12-18 01:57:02 +00:00
feature: Context support for REST client (#496)
* rest/client: add WithContext methods Signed-off-by: James Hillyerd <james@hillyerd.com> * cmd/client: pass context Signed-off-by: James Hillyerd <james@hillyerd.com> --------- Signed-off-by: James Hillyerd <james@hillyerd.com>
This commit is contained in:
@@ -28,18 +28,20 @@ func (*listCmd) Usage() string {
|
|||||||
func (l *listCmd) SetFlags(f *flag.FlagSet) {}
|
func (l *listCmd) SetFlags(f *flag.FlagSet) {}
|
||||||
|
|
||||||
func (l *listCmd) Execute(
|
func (l *listCmd) Execute(
|
||||||
_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||||
mailbox := f.Arg(0)
|
mailbox := f.Arg(0)
|
||||||
if mailbox == "" {
|
if mailbox == "" {
|
||||||
return usage("mailbox required")
|
return usage("mailbox required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup rest client
|
// Setup rest client
|
||||||
c, err := client.New(baseURL())
|
c, err := client.New(baseURL())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("Couldn't build client", err)
|
return fatal("Couldn't build client", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get list
|
// Get list
|
||||||
headers, err := c.ListMailbox(mailbox)
|
headers, err := c.ListMailboxWithContext(ctx, mailbox)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("REST call failed", err)
|
return fatal("REST call failed", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
|
|
||||||
type matchCmd struct {
|
type matchCmd struct {
|
||||||
output string
|
output string
|
||||||
outFunc func(headers []*client.MessageHeader) error
|
outFunc func(ctx context.Context, headers []*client.MessageHeader) error
|
||||||
delete bool
|
delete bool
|
||||||
// match criteria
|
// match criteria
|
||||||
from regexFlag
|
from regexFlag
|
||||||
@@ -51,11 +51,12 @@ func (m *matchCmd) SetFlags(f *flag.FlagSet) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *matchCmd) Execute(
|
func (m *matchCmd) Execute(
|
||||||
_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||||
mailbox := f.Arg(0)
|
mailbox := f.Arg(0)
|
||||||
if mailbox == "" {
|
if mailbox == "" {
|
||||||
return usage("mailbox required")
|
return usage("mailbox required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select output function
|
// Select output function
|
||||||
switch m.output {
|
switch m.output {
|
||||||
case "id":
|
case "id":
|
||||||
@@ -67,16 +68,19 @@ func (m *matchCmd) Execute(
|
|||||||
default:
|
default:
|
||||||
return usage("unknown output type: " + m.output)
|
return usage("unknown output type: " + m.output)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup REST client
|
// Setup REST client
|
||||||
c, err := client.New(baseURL())
|
c, err := client.New(baseURL())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("Couldn't build client", err)
|
return fatal("Couldn't build client", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get list
|
// Get list
|
||||||
headers, err := c.ListMailbox(mailbox)
|
headers, err := c.ListMailboxWithContext(ctx, mailbox)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("List REST call failed", err)
|
return fatal("List REST call failed", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find matches
|
// Find matches
|
||||||
matches := make([]*client.MessageHeader, 0, len(headers))
|
matches := make([]*client.MessageHeader, 0, len(headers))
|
||||||
for _, h := range headers {
|
for _, h := range headers {
|
||||||
@@ -84,24 +88,28 @@ func (m *matchCmd) Execute(
|
|||||||
matches = append(matches, h)
|
matches = append(matches, h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return error status if no matches
|
// Return error status if no matches
|
||||||
if len(matches) == 0 {
|
if len(matches) == 0 {
|
||||||
return subcommands.ExitFailure
|
return subcommands.ExitFailure
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output matches
|
// Output matches
|
||||||
err = m.outFunc(matches)
|
err = m.outFunc(ctx, matches)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("Error", err)
|
return fatal("Error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Optionally, delete matches
|
||||||
if m.delete {
|
if m.delete {
|
||||||
// Delete matches
|
|
||||||
for _, h := range matches {
|
for _, h := range matches {
|
||||||
err = h.Delete()
|
err = h.DeleteWithContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("Delete REST call failed", err)
|
return fatal("Delete REST call failed", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return subcommands.ExitSuccess
|
return subcommands.ExitSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,14 +156,14 @@ func (m *matchCmd) match(header *client.MessageHeader) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func outputID(headers []*client.MessageHeader) error {
|
func outputID(_ context.Context, headers []*client.MessageHeader) error {
|
||||||
for _, h := range headers {
|
for _, h := range headers {
|
||||||
fmt.Println(h.ID)
|
fmt.Println(h.ID)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func outputJSON(headers []*client.MessageHeader) error {
|
func outputJSON(_ context.Context, headers []*client.MessageHeader) error {
|
||||||
jsonEncoder := json.NewEncoder(os.Stdout)
|
jsonEncoder := json.NewEncoder(os.Stdout)
|
||||||
jsonEncoder.SetEscapeHTML(false)
|
jsonEncoder.SetEscapeHTML(false)
|
||||||
jsonEncoder.SetIndent("", " ")
|
jsonEncoder.SetIndent("", " ")
|
||||||
|
|||||||
@@ -33,42 +33,46 @@ func (m *mboxCmd) SetFlags(f *flag.FlagSet) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mboxCmd) Execute(
|
func (m *mboxCmd) Execute(
|
||||||
_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||||
mailbox := f.Arg(0)
|
mailbox := f.Arg(0)
|
||||||
if mailbox == "" {
|
if mailbox == "" {
|
||||||
return usage("mailbox required")
|
return usage("mailbox required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup REST client
|
// Setup REST client
|
||||||
c, err := client.New(baseURL())
|
c, err := client.New(baseURL())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("Couldn't build client", err)
|
return fatal("Couldn't build client", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get list
|
// Get list
|
||||||
headers, err := c.ListMailbox(mailbox)
|
headers, err := c.ListMailboxWithContext(ctx, mailbox)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("List REST call failed", err)
|
return fatal("List REST call failed", err)
|
||||||
}
|
}
|
||||||
err = outputMbox(headers)
|
err = outputMbox(ctx, headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("Error", err)
|
return fatal("Error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Optionally, delete retrieved messages
|
||||||
if m.delete {
|
if m.delete {
|
||||||
// Delete matches
|
|
||||||
for _, h := range headers {
|
for _, h := range headers {
|
||||||
err = h.Delete()
|
err = h.DeleteWithContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fatal("Delete REST call failed", err)
|
return fatal("Delete REST call failed", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return subcommands.ExitSuccess
|
return subcommands.ExitSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
// outputMbox renders messages in mbox format
|
// outputMbox renders messages in mbox format.
|
||||||
// also used by match subcommand
|
// It is also used by match subcommand.
|
||||||
func outputMbox(headers []*client.MessageHeader) error {
|
func outputMbox(ctx context.Context, headers []*client.MessageHeader) error {
|
||||||
for _, h := range headers {
|
for _, h := range headers {
|
||||||
source, err := h.GetSource()
|
source, err := h.GetSourceWithContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get source REST failed: %v", err)
|
return fmt.Errorf("get source REST failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package client
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -41,33 +42,56 @@ func New(baseURL string, opts ...Option) (*Client, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListMailbox returns a list of messages for the requested mailbox
|
// ListMailbox returns a list of messages for the requested mailbox
|
||||||
func (c *Client) ListMailbox(name string) (headers []*MessageHeader, err error) {
|
func (c *Client) ListMailbox(name string) ([]*MessageHeader, error) {
|
||||||
|
return c.ListMailboxWithContext(context.Background(), name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListMailboxWithContext returns a list of messages for the requested mailbox
|
||||||
|
func (c *Client) ListMailboxWithContext(ctx context.Context, name string) ([]*MessageHeader, error) {
|
||||||
uri := "/api/v1/mailbox/" + url.QueryEscape(name)
|
uri := "/api/v1/mailbox/" + url.QueryEscape(name)
|
||||||
err = c.doJSON("GET", uri, &headers)
|
headers := make([]*MessageHeader, 0, 32)
|
||||||
|
|
||||||
|
err := c.doJSON(ctx, "GET", uri, &headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add Client ref to each MessageHeader for convenience funcs.
|
||||||
for _, h := range headers {
|
for _, h := range headers {
|
||||||
h.client = c
|
h.client = c
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
return headers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessage returns the message details given a mailbox name and message ID.
|
// GetMessage returns the message details given a mailbox name and message ID.
|
||||||
func (c *Client) GetMessage(name, id string) (message *Message, err error) {
|
func (c *Client) GetMessage(name, id string) (message *Message, err error) {
|
||||||
|
return c.GetMessageWithContext(context.Background(), name, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessageWithContext returns the message details given a mailbox name and message ID.
|
||||||
|
func (c *Client) GetMessageWithContext(ctx context.Context, name, id string) (*Message, error) {
|
||||||
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
|
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
|
||||||
err = c.doJSON("GET", uri, &message)
|
var message Message
|
||||||
|
|
||||||
|
err := c.doJSON(ctx, "GET", uri, &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
message.client = c
|
message.client = c
|
||||||
return
|
return &message, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkSeen marks the specified message as having been read.
|
// MarkSeen marks the specified message as having been read.
|
||||||
func (c *Client) MarkSeen(name, id string) error {
|
func (c *Client) MarkSeen(name, id string) error {
|
||||||
|
return c.MarkSeenWithContext(context.Background(), name, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkSeenWithContext marks the specified message as having been read.
|
||||||
|
func (c *Client) MarkSeenWithContext(ctx context.Context, name, id string) error {
|
||||||
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
|
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
|
||||||
err := c.doJSON("PATCH", uri, nil)
|
err := c.doJSON(ctx, "PATCH", uri, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -76,19 +100,25 @@ func (c *Client) MarkSeen(name, id string) error {
|
|||||||
|
|
||||||
// GetMessageSource returns the message source given a mailbox name and message ID.
|
// GetMessageSource returns the message source given a mailbox name and message ID.
|
||||||
func (c *Client) GetMessageSource(name, id string) (*bytes.Buffer, error) {
|
func (c *Client) GetMessageSource(name, id string) (*bytes.Buffer, error) {
|
||||||
|
return c.GetMessageSourceWithContext(context.Background(), name, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessageSourceWithContext returns the message source given a mailbox name and message ID.
|
||||||
|
func (c *Client) GetMessageSourceWithContext(ctx context.Context, name, id string) (*bytes.Buffer, error) {
|
||||||
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id + "/source"
|
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id + "/source"
|
||||||
resp, err := c.do("GET", uri, nil)
|
resp, err := c.do(ctx, "GET", uri, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil,
|
return nil,
|
||||||
fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
|
fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
_, err = buf.ReadFrom(resp.Body)
|
_, err = buf.ReadFrom(resp.Body)
|
||||||
return buf, err
|
return buf, err
|
||||||
@@ -96,29 +126,43 @@ func (c *Client) GetMessageSource(name, id string) (*bytes.Buffer, error) {
|
|||||||
|
|
||||||
// DeleteMessage deletes a single message given the mailbox name and message ID.
|
// DeleteMessage deletes a single message given the mailbox name and message ID.
|
||||||
func (c *Client) DeleteMessage(name, id string) error {
|
func (c *Client) DeleteMessage(name, id string) error {
|
||||||
|
return c.DeleteMessageWithContext(context.Background(), name, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMessageWithContext deletes a single message given the mailbox name and message ID.
|
||||||
|
func (c *Client) DeleteMessageWithContext(ctx context.Context, name, id string) error {
|
||||||
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
|
uri := "/api/v1/mailbox/" + url.QueryEscape(name) + "/" + id
|
||||||
resp, err := c.do("DELETE", uri, nil)
|
resp, err := c.do(ctx, "DELETE", uri, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
|
return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PurgeMailbox deletes all messages in the given mailbox
|
// PurgeMailbox deletes all messages in the given mailbox
|
||||||
func (c *Client) PurgeMailbox(name string) error {
|
func (c *Client) PurgeMailbox(name string) error {
|
||||||
|
return c.PurgeMailboxWithContext(context.Background(), name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PurgeMailboxWithContext deletes all messages in the given mailbox
|
||||||
|
func (c *Client) PurgeMailboxWithContext(ctx context.Context, name string) error {
|
||||||
uri := "/api/v1/mailbox/" + url.QueryEscape(name)
|
uri := "/api/v1/mailbox/" + url.QueryEscape(name)
|
||||||
resp, err := c.do("DELETE", uri, nil)
|
resp, err := c.do(ctx, "DELETE", uri, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
|
return fmt.Errorf("unexpected HTTP response status %v: %s", resp.StatusCode, resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,17 +174,32 @@ type MessageHeader struct {
|
|||||||
|
|
||||||
// GetMessage returns this message with content
|
// GetMessage returns this message with content
|
||||||
func (h *MessageHeader) GetMessage() (message *Message, err error) {
|
func (h *MessageHeader) GetMessage() (message *Message, err error) {
|
||||||
return h.client.GetMessage(h.Mailbox, h.ID)
|
return h.GetMessageWithContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessageWithContext returns this message with content
|
||||||
|
func (h *MessageHeader) GetMessageWithContext(ctx context.Context) (message *Message, err error) {
|
||||||
|
return h.client.GetMessageWithContext(ctx, h.Mailbox, h.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSource returns the source for this message
|
// GetSource returns the source for this message
|
||||||
func (h *MessageHeader) GetSource() (*bytes.Buffer, error) {
|
func (h *MessageHeader) GetSource() (*bytes.Buffer, error) {
|
||||||
return h.client.GetMessageSource(h.Mailbox, h.ID)
|
return h.GetSourceWithContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSourceWithContext returns the source for this message
|
||||||
|
func (h *MessageHeader) GetSourceWithContext(ctx context.Context) (*bytes.Buffer, error) {
|
||||||
|
return h.client.GetMessageSourceWithContext(ctx, h.Mailbox, h.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete deletes this message from the mailbox
|
// Delete deletes this message from the mailbox
|
||||||
func (h *MessageHeader) Delete() error {
|
func (h *MessageHeader) Delete() error {
|
||||||
return h.client.DeleteMessage(h.Mailbox, h.ID)
|
return h.DeleteWithContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteWithContext deletes this message from the mailbox
|
||||||
|
func (h *MessageHeader) DeleteWithContext(ctx context.Context) error {
|
||||||
|
return h.client.DeleteMessageWithContext(ctx, h.Mailbox, h.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Message represents an Inbucket message including content
|
// Message represents an Inbucket message including content
|
||||||
@@ -151,10 +210,20 @@ type Message struct {
|
|||||||
|
|
||||||
// GetSource returns the source for this message
|
// GetSource returns the source for this message
|
||||||
func (m *Message) GetSource() (*bytes.Buffer, error) {
|
func (m *Message) GetSource() (*bytes.Buffer, error) {
|
||||||
return m.client.GetMessageSource(m.Mailbox, m.ID)
|
return m.GetSourceWithContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSourceWithContext returns the source for this message
|
||||||
|
func (m *Message) GetSourceWithContext(ctx context.Context) (*bytes.Buffer, error) {
|
||||||
|
return m.client.GetMessageSourceWithContext(ctx, m.Mailbox, m.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete deletes this message from the mailbox
|
// Delete deletes this message from the mailbox
|
||||||
func (m *Message) Delete() error {
|
func (m *Message) Delete() error {
|
||||||
return m.client.DeleteMessage(m.Mailbox, m.ID)
|
return m.DeleteWithContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteWithContext deletes this message from the mailbox
|
||||||
|
func (m *Message) DeleteWithContext(ctx context.Context) error {
|
||||||
|
return m.client.DeleteMessageWithContext(ctx, m.Mailbox, m.ID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package client
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -21,22 +22,24 @@ type restClient struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// do performs an HTTP request with this client and returns the response.
|
// do performs an HTTP request with this client and returns the response.
|
||||||
func (c *restClient) do(method, uri string, body []byte) (*http.Response, error) {
|
func (c *restClient) do(ctx context.Context, method, uri string, body []byte) (*http.Response, error) {
|
||||||
url := c.baseURL.JoinPath(uri)
|
url := c.baseURL.JoinPath(uri)
|
||||||
var r io.Reader
|
var r io.Reader
|
||||||
if body != nil {
|
if body != nil {
|
||||||
r = bytes.NewReader(body)
|
r = bytes.NewReader(body)
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest(method, url.String(), r)
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, url.String(), r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%s for %q: %v", method, url, err)
|
return nil, fmt.Errorf("%s for %q: %v", method, url, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.client.Do(req)
|
return c.client.Do(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// doJSON performs an HTTP request with this client and marshalls the JSON response into v.
|
// doJSON performs an HTTP request with this client and marshalls the JSON response into v.
|
||||||
func (c *restClient) doJSON(method string, uri string, v interface{}) error {
|
func (c *restClient) doJSON(ctx context.Context, method string, uri string, v interface{}) error {
|
||||||
resp, err := c.do(method, uri, nil)
|
resp, err := c.do(ctx, method, uri, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package client
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -78,9 +79,11 @@ func TestDoTable(t *testing.T) {
|
|||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
testname := fmt.Sprintf("%s,%s", test.method, test.wantURL)
|
testname := fmt.Sprintf("%s,%s", test.method, test.wantURL)
|
||||||
t.Run(testname, func(t *testing.T) {
|
t.Run(testname, func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
mth := &mockHTTPClient{}
|
mth := &mockHTTPClient{}
|
||||||
c := &restClient{mth, test.base}
|
c := &restClient{mth, test.base}
|
||||||
resp, err := c.do(test.method, test.uri, test.wantBody)
|
|
||||||
|
resp, err := c.do(ctx, test.method, test.uri, test.wantBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -107,7 +110,7 @@ func TestDoJSON(t *testing.T) {
|
|||||||
c := &restClient{mth, baseURL}
|
c := &restClient{mth, baseURL}
|
||||||
|
|
||||||
var v map[string]interface{}
|
var v map[string]interface{}
|
||||||
err := c.doJSON("GET", "/doget", &v)
|
err := c.doJSON(context.Background(), "GET", "/doget", &v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -141,7 +144,7 @@ func TestDoJSONNilV(t *testing.T) {
|
|||||||
mth := &mockHTTPClient{}
|
mth := &mockHTTPClient{}
|
||||||
c := &restClient{mth, baseURL}
|
c := &restClient{mth, baseURL}
|
||||||
|
|
||||||
err := c.doJSON("GET", "/doget", nil)
|
err := c.doJSON(context.Background(), "GET", "/doget", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user