diff --git a/amazon/validator.go b/amazon/validator.go index 1d0b3fe..0861c1b 100644 --- a/amazon/validator.go +++ b/amazon/validator.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "os" - "time" ) const ( @@ -25,13 +24,6 @@ func getSandboxURL() string { return url } -// Config is a configuration to initialize client -type Config struct { - IsProduction bool - Secret string - TimeOut time.Duration -} - // The IAPResponse type has the response properties type IAPResponse struct { ReceiptID string `json:"receiptId"` @@ -57,34 +49,31 @@ type IAPClient interface { type Client struct { URL string Secret string - TimeOut time.Duration + httpCli *http.Client } // New creates a client object -func New(secret string) IAPClient { - client := Client{ +func New(secret string) *Client { + client := &Client{ URL: getSandboxURL(), Secret: secret, - TimeOut: time.Second * 5, + httpCli: http.DefaultClient, } if os.Getenv("IAP_ENVIRONMENT") == "production" { client.URL = ProductionURL } + return client } -// NewWithConfig creates a client with configuration -func NewWithConfig(config Config) Client { - if config.TimeOut == 0 { - config.TimeOut = time.Second * 5 - } - - client := Client{ +// NewWithClient creates a client with a custom client. +func NewWithClient(secret string, cli *http.Client) *Client { + client := &Client{ URL: getSandboxURL(), - Secret: config.Secret, - TimeOut: config.TimeOut, + Secret: secret, + httpCli: cli, } - if config.IsProduction { + if os.Getenv("IAP_ENVIRONMENT") == "production" { client.URL = ProductionURL } @@ -92,19 +81,16 @@ func NewWithConfig(config Config) Client { } // Verify sends receipts and gets validation result -func (c Client) Verify(ctx context.Context, userID string, receiptID string) (IAPResponse, error) { +func (c *Client) Verify(ctx context.Context, userID string, receiptID string) (IAPResponse, error) { result := IAPResponse{} url := fmt.Sprintf("%v/version/1.0/verifyReceiptId/developer/%v/user/%v/receiptId/%v", c.URL, c.Secret, userID, receiptID) - client := http.Client{ - Timeout: c.TimeOut, - } req, err := http.NewRequest("GET", url, nil) if err != nil { return result, err } req = req.WithContext(ctx) - resp, err := client.Do(req) + resp, err := c.httpCli.Do(req) if err != nil { return result, fmt.Errorf("%v", err) } diff --git a/amazon/validator_test.go b/amazon/validator_test.go index 681bc0d..77e272d 100644 --- a/amazon/validator_test.go +++ b/amazon/validator_test.go @@ -59,10 +59,10 @@ func TestHandle400Error(t *testing.T) { } func TestNew(t *testing.T) { - expected := Client{ + expected := &Client{ URL: SandboxURL, - TimeOut: time.Second * 5, Secret: "developerSecret", + httpCli: http.DefaultClient, } actual := New("developerSecret") @@ -72,10 +72,10 @@ func TestNew(t *testing.T) { } func TestNewWithEnvironment(t *testing.T) { - expected := Client{ + expected := &Client{ URL: ProductionURL, - TimeOut: time.Second * 5, Secret: "developerSecret", + httpCli: http.DefaultClient, } os.Setenv("IAP_ENVIRONMENT", "production") @@ -87,40 +87,20 @@ func TestNewWithEnvironment(t *testing.T) { } } -func TestNewWithConfig(t *testing.T) { - t.Parallel() - config := Config{ - IsProduction: true, - Secret: "developerSecret", - TimeOut: time.Second * 2, +func TestNewWithClient(t *testing.T) { + expected := &Client{ + URL: ProductionURL, + Secret: "developerSecret", + httpCli: &http.Client{ + Timeout: time.Second * 2, + }, } + os.Setenv("IAP_ENVIRONMENT", "production") - expected := Client{ - URL: ProductionURL, - TimeOut: time.Second * 2, - Secret: "developerSecret", + cli := &http.Client{ + Timeout: time.Second * 2, } - - actual := NewWithConfig(config) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } -} - -func TestNewWithConfigTimeout(t *testing.T) { - t.Parallel() - config := Config{ - IsProduction: true, - Secret: "developerSecret", - } - - expected := Client{ - URL: ProductionURL, - TimeOut: time.Second * 5, - Secret: "developerSecret", - } - - actual := NewWithConfig(config) + actual := NewWithClient("developerSecret", cli) if !reflect.DeepEqual(actual, expected) { t.Errorf("got %v\nwant %v", actual, expected) } @@ -175,6 +155,6 @@ func testTools(code int, body string) (*httptest.Server, *Client) { fmt.Fprintln(w, body) })) - client := &Client{URL: server.URL, TimeOut: time.Second * 2, Secret: "developerSecret"} + client := &Client{URL: server.URL, Secret: "developerSecret", httpCli: &http.Client{Timeout: 2 * time.Second}} return server, client }