From 2a724a9d7c3559f57c1dfa904bb516dd528e0926 Mon Sep 17 00:00:00 2001 From: Junpei Tsuji Date: Thu, 17 May 2018 12:36:12 +0900 Subject: [PATCH] Support custom client --- appstore/validator.go | 53 +++++----- appstore/validator_test.go | 208 +++++++++++++++---------------------- 2 files changed, 114 insertions(+), 147 deletions(-) diff --git a/appstore/validator.go b/appstore/validator.go index 731109c..1a8c2fd 100644 --- a/appstore/validator.go +++ b/appstore/validator.go @@ -14,6 +14,8 @@ const ( SandboxURL string = "https://sandbox.itunes.apple.com/verifyReceipt" // ProductionURL is the endpoint for production environment. ProductionURL string = "https://buy.itunes.apple.com/verifyReceipt" + // ContentType is the request content-type for apple store. + ContentType string = "application/json; charset=utf-8" ) // Config is a configuration to initialize client @@ -30,7 +32,7 @@ type IAPClient interface { type Client struct { ProductionURL string SandboxURL string - TimeOut time.Duration + httpCli *http.Client } // HandleError returns error message by status code @@ -77,48 +79,43 @@ func HandleError(status int) error { } // New creates a client object -func New() Client { - client := Client{ +func New() *Client { + client := &Client{ ProductionURL: ProductionURL, SandboxURL: SandboxURL, - TimeOut: time.Second * 5, + httpCli: http.DefaultClient, } return client } -// NewWithConfig creates a client with configuration -func NewWithConfig(config Config) Client { - if config.TimeOut == 0 { - config.TimeOut = time.Second * 5 - } - - client := Client{ +// NewWithClient creates a client with a custom http client. +func NewWithClient(client *http.Client) *Client { + return &Client{ ProductionURL: ProductionURL, SandboxURL: SandboxURL, - TimeOut: config.TimeOut, + httpCli: client, } - - return client } // Verify sends receipts and gets validation result -func (c *Client) Verify(req IAPRequest, result interface{}) error { - client := http.Client{ - Timeout: c.TimeOut, - } - +func (c *Client) Verify(reqBody IAPRequest, result interface{}) error { b := new(bytes.Buffer) - json.NewEncoder(b).Encode(req) + json.NewEncoder(b).Encode(reqBody) - resp, err := client.Post(c.ProductionURL, "application/json; charset=utf-8", b) + req, err := http.NewRequest("POST", c.ProductionURL, b) + if err != nil { + return err + } + req.Header.Set("Content-Type", ContentType) + resp, err := c.httpCli.Do(req) if err != nil { return err } defer resp.Body.Close() - return c.parseResponse(resp, result, client, req) + return c.parseResponse(resp, result, reqBody) } -func (c *Client) parseResponse(resp *http.Response, result interface{}, client http.Client, req IAPRequest) error { +func (c *Client) parseResponse(resp *http.Response, result interface{}, reqBody IAPRequest) error { // Read the body now so that we can unmarshal it twice buf, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -138,8 +135,14 @@ func (c *Client) parseResponse(resp *http.Response, result interface{}, client h } if r.Status == 21007 { b := new(bytes.Buffer) - json.NewEncoder(b).Encode(req) - resp, err := client.Post(c.SandboxURL, "application/json; charset=utf-8", b) + json.NewEncoder(b).Encode(reqBody) + + req, err := http.NewRequest("POST", c.SandboxURL, b) + if err != nil { + return err + } + req.Header.Set("Content-Type", ContentType) + resp, err := c.httpCli.Do(req) if err != nil { return err } diff --git a/appstore/validator_test.go b/appstore/validator_test.go index 056e169..e7c8972 100644 --- a/appstore/validator_test.go +++ b/appstore/validator_test.go @@ -5,7 +5,6 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "os" "reflect" "strings" "testing" @@ -13,91 +12,84 @@ import ( ) func TestHandleError(t *testing.T) { - var expected, actual error - - // status 0 - expected = nil - actual = HandleError(0) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) + tests := []struct { + name string + in int + out error + }{ + { + name: "status 0", + in: 0, + out: nil, + }, + { + name: "status 21000", + in: 21000, + out: errors.New("The App Store could not read the JSON object you provided."), + }, + { + name: "status 21002", + in: 21002, + out: errors.New("The data in the receipt-data property was malformed or missing."), + }, + { + name: "status 21003", + in: 21003, + out: errors.New("The receipt could not be authenticated."), + }, + { + name: "status 21004", + in: 21004, + out: errors.New("The shared secret you provided does not match the shared secret on file for your account."), + }, + { + name: "status 21005", + in: 21005, + out: errors.New("The receipt server is not currently available."), + }, + { + name: "status 21007", + in: 21007, + out: errors.New("This receipt is from the test environment, but it was sent to the production environment for verification. Send it to the test environment instead."), + }, + { + name: "status 21008", + in: 21008, + out: errors.New("This receipt is from the production environment, but it was sent to the test environment for verification. Send it to the production environment instead."), + }, + { + name: "status 21010", + in: 21010, + out: errors.New("This receipt could not be authorized. Treat this the same as if a purchase was never made."), + }, + { + name: "status 21100 ~ 21199", + in: 21100, + out: errors.New("Internal data access error."), + }, + { + name: "status unknown", + in: 100, + out: errors.New("An unknown error occurred"), + }, } - // status 21000 - expected = errors.New("The App Store could not read the JSON object you provided.") - actual = HandleError(21000) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } + for _, v := range tests { + t.Run(v.name, func(t *testing.T) { + out := HandleError(v.in) - // status 21002 - expected = errors.New("The data in the receipt-data property was malformed or missing.") - actual = HandleError(21002) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } - - // status 21003 - expected = errors.New("The receipt could not be authenticated.") - actual = HandleError(21003) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } - - // status 21004 - expected = errors.New("The shared secret you provided does not match the shared secret on file for your account.") - actual = HandleError(21004) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } - - // status 21005 - expected = errors.New("The receipt server is not currently available.") - actual = HandleError(21005) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } - - // status 21007 - expected = errors.New("This receipt is from the test environment, but it was sent to the production environment for verification. Send it to the test environment instead.") - actual = HandleError(21007) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } - - // status 21008 - expected = errors.New("This receipt is from the production environment, but it was sent to the test environment for verification. Send it to the production environment instead.") - actual = HandleError(21008) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } - - // status 21010 - expected = errors.New("This receipt could not be authorized. Treat this the same as if a purchase was never made.") - actual = HandleError(21010) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } - - // status 21100 - 21199 - expected = errors.New("Internal data access error.") - actual = HandleError(21155) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } - - // status unknown - expected = errors.New("An unknown error occurred") - actual = HandleError(100) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) + if !reflect.DeepEqual(out, v.out) { + t.Errorf("input: %d\ngot: %v\nwant: %v\n", v.in, out, v.out) + } + }) } } func TestNew(t *testing.T) { - expected := Client{ + expected := &Client{ ProductionURL: ProductionURL, SandboxURL: SandboxURL, - TimeOut: time.Second * 5, + httpCli: http.DefaultClient, } actual := New() @@ -106,57 +98,31 @@ func TestNew(t *testing.T) { } } -func TestNewWithEnvironment(t *testing.T) { - expected := Client{ - ProductionURL: ProductionURL, - TimeOut: time.Second * 5, - SandboxURL: SandboxURL, - } - - os.Setenv("IAP_ENVIRONMENT", "production") - actual := New() - os.Clearenv() - - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } -} - -func TestNewWithConfig(t *testing.T) { - config := Config{ - TimeOut: time.Second * 2, - } - - expected := Client{ +func TestNewWithClient(t *testing.T) { + expected := &Client{ ProductionURL: ProductionURL, SandboxURL: SandboxURL, - TimeOut: time.Second * 2, + httpCli: &http.Client{ + Timeout: 10 * time.Second, + }, } - actual := NewWithConfig(config) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("got %v\nwant %v", actual, expected) - } -} - -func TestNewWithConfigTimeout(t *testing.T) { - config := Config{} - - expected := Client{ - ProductionURL: ProductionURL, - SandboxURL: SandboxURL, - TimeOut: time.Second * 5, - } - - actual := NewWithConfig(config) + actual := NewWithClient(&http.Client{ + Timeout: 10 * time.Second, + }) if !reflect.DeepEqual(actual, expected) { t.Errorf("got %v\nwant %v", actual, expected) } } func TestVerifyTimeout(t *testing.T) { - client := New() - client.TimeOut = time.Millisecond + client := &Client{ + ProductionURL: ProductionURL, + SandboxURL: SandboxURL, + httpCli: &http.Client{ + Timeout: time.Millisecond, + }, + } req := IAPRequest{ ReceiptData: "dummy data", @@ -220,7 +186,6 @@ func TestResponses(t *testing.T) { } client := New() - client.TimeOut = time.Second * 100 client.SandboxURL = "localhost" for i, tc := range testCases { @@ -262,7 +227,6 @@ func TestErrors(t *testing.T) { } client := New() - client.TimeOut = time.Second * 100 client.SandboxURL = "localhost" for i, tc := range testCases { @@ -280,7 +244,7 @@ func TestCannotReadBody(t *testing.T) { client := New() testResponse := http.Response{Body: ioutil.NopCloser(errReader(0))} - if client.parseResponse(&testResponse, IAPResponse{}, http.Client{}, IAPRequest{}) == nil { + if client.parseResponse(&testResponse, IAPResponse{}, IAPRequest{}) == nil { t.Errorf("expected redirectToSandbox to fail to read the body") } } @@ -289,7 +253,7 @@ func TestCannotUnmarshalBody(t *testing.T) { client := New() testResponse := http.Response{Body: ioutil.NopCloser(strings.NewReader(`{"status": true}`))} - if client.parseResponse(&testResponse, StatusResponse{}, http.Client{}, IAPRequest{}) == nil { + if client.parseResponse(&testResponse, StatusResponse{}, IAPRequest{}) == nil { t.Errorf("expected redirectToSandbox to fail to unmarshal the data") } }