diff --git a/pkg/rest/client/apiv1_client.go b/pkg/rest/client/apiv1_client.go index ce8990a..5a50c56 100644 --- a/pkg/rest/client/apiv1_client.go +++ b/pkg/rest/client/apiv1_client.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "net/url" - "time" "github.com/inbucket/inbucket/v3/pkg/rest/model" ) @@ -18,15 +17,22 @@ type Client struct { // New creates a new v1 REST API client given the base URL of an Inbucket server, ex: // "http://localhost:9000" -func New(baseURL string) (*Client, error) { +func New(baseURL string, opts ...Option) (*Client, error) { parsedURL, err := url.Parse(baseURL) if err != nil { return nil, err } + + mergedOpts := getDefaultOptions() + for _, opt := range opts { + opt.apply(mergedOpts) + } + c := &Client{ restClient{ client: &http.Client{ - Timeout: 30 * time.Second, + Timeout: mergedOpts.timeout, + Transport: mergedOpts.transport, }, baseURL: parsedURL, }, diff --git a/pkg/rest/client/apiv1_client_opts.go b/pkg/rest/client/apiv1_client_opts.go new file mode 100644 index 0000000..49fd67c --- /dev/null +++ b/pkg/rest/client/apiv1_client_opts.go @@ -0,0 +1,38 @@ +package client + +import ( + "net/http" + "time" +) + +// options is a struct that holds the options for the rest client +type options struct { + transport http.RoundTripper + timeout time.Duration +} + +type Option interface { + apply(*options) +} + +func getDefaultOptions() *options { + return &options{ + timeout: 30 * time.Second, + } +} + +type transportOption struct { + transport http.RoundTripper +} + +func (t transportOption) apply(opts *options) { + opts.transport = t.transport +} + +// WithOptTransport sets the transport for the rest client. +// Transport specifies the mechanism by which individual +// HTTP requests are made. +// If nil, http.DefaultTransport is used. +func WithOptTransport(transport http.RoundTripper) Option { + return transportOption{transport} +} diff --git a/pkg/rest/client/apiv1_client_test.go b/pkg/rest/client/apiv1_client_test.go index 22eb2f5..f7f84e2 100644 --- a/pkg/rest/client/apiv1_client_test.go +++ b/pkg/rest/client/apiv1_client_test.go @@ -1,6 +1,8 @@ package client_test import ( + "bytes" + "io" "net/http" "net/http/httptest" "testing" @@ -210,6 +212,33 @@ func TestClientV1GetMessageSource(t *testing.T) { } } +func TestClientV1WithCustomTransport(t *testing.T) { + // Call setup, passing a custom roundtripper and make sure it was used during the request. + mockRoundTripper := &mockRoundTripper{ResponseBody: "Custom Transport"} + c, router, teardown := setup(client.WithOptTransport(mockRoundTripper)) + + defer teardown() + + router.Path("/api/v1/mailbox/testbox/20170107T224128-0000/source").Methods("GET"). + Handler(&jsonHandler{json: `message source`}) + + // Method under test. + source, err := c.GetMessageSource("testbox", "20170107T224128-0000") + if err != nil { + t.Fatal(err) + } + + want := mockRoundTripper.ResponseBody + got := source.String() + if got != want { + t.Errorf("Source got %q, want %q", got, want) + } + + if mockRoundTripper.CallCount != 1 { + t.Errorf("RoundTripper called %v times, want 1", mockRoundTripper.CallCount) + } +} + func TestClientV1DeleteMessage(t *testing.T) { // Setup. c, router, teardown := setup() @@ -337,11 +366,24 @@ func TestClientV1MessageHeader(t *testing.T) { } } +type mockRoundTripper struct { + ResponseBody string + CallCount int +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.CallCount++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(m.ResponseBody)), + }, nil +} + // setup returns a client, router and server for API testing. -func setup() (c *client.Client, router *mux.Router, teardown func()) { +func setup(opts ...client.Option) (c *client.Client, router *mux.Router, teardown func()) { router = mux.NewRouter() server := httptest.NewServer(router) - c, err := client.New(server.URL) + c, err := client.New(server.URL, opts...) if err != nil { panic(err) }