diff --git a/appstore/model.go b/appstore/model.go index 37d7a25..a4c3476 100644 --- a/appstore/model.go +++ b/appstore/model.go @@ -102,7 +102,7 @@ type ( // We defined each field by the current IAP response, but some fields are not mentioned // in the following Apple's document; // https://developer.apple.com/library/ios/releasenotes/General/ValidateAppStoreReceipt/Chapters/ReceiptFields.html - // If you get other types or fileds from the IAP response, you should use the struct you defined. + // If you get other types or fields from the IAP response, you should use the struct you defined. IAPResponse struct { Status int `json:"status"` Environment string `json:"environment"` @@ -112,4 +112,10 @@ type ( PendingRenewalInfo []PendingRenewalInfo `json:"pending_renewal_info"` IsRetryable bool `json:"is-retryable"` } + + // The HttpStatusResponse struct contains the status code returned by the store + // Used as a workaround to detect when to hit the production appstore or sandbox appstore regardless of receipt type + StatusResponse struct { + Status int `json:"status"` + } ) diff --git a/appstore/validator.go b/appstore/validator.go index 31c9930..731109c 100644 --- a/appstore/validator.go +++ b/appstore/validator.go @@ -4,8 +4,8 @@ import ( "bytes" "encoding/json" "errors" + "io/ioutil" "net/http" - "os" "time" ) @@ -18,8 +18,7 @@ const ( // Config is a configuration to initialize client type Config struct { - IsProduction bool - TimeOut time.Duration + TimeOut time.Duration } // IAPClient is an interface to call validation API in App Store @@ -29,8 +28,9 @@ type IAPClient interface { // Client implements IAPClient type Client struct { - URL string - TimeOut time.Duration + ProductionURL string + SandboxURL string + TimeOut time.Duration } // HandleError returns error message by status code @@ -79,11 +79,9 @@ func HandleError(status int) error { // New creates a client object func New() Client { client := Client{ - URL: SandboxURL, - TimeOut: time.Second * 5, - } - if os.Getenv("IAP_ENVIRONMENT") == "production" { - client.URL = ProductionURL + ProductionURL: ProductionURL, + SandboxURL: SandboxURL, + TimeOut: time.Second * 5, } return client } @@ -95,11 +93,9 @@ func NewWithConfig(config Config) Client { } client := Client{ - URL: SandboxURL, - TimeOut: config.TimeOut, - } - if config.IsProduction { - client.URL = ProductionURL + ProductionURL: ProductionURL, + SandboxURL: SandboxURL, + TimeOut: config.TimeOut, } return client @@ -114,13 +110,43 @@ func (c *Client) Verify(req IAPRequest, result interface{}) error { b := new(bytes.Buffer) json.NewEncoder(b).Encode(req) - resp, err := client.Post(c.URL, "application/json; charset=utf-8", b) + resp, err := client.Post(c.ProductionURL, "application/json; charset=utf-8", b) if err != nil { return err } defer resp.Body.Close() - - err = json.NewDecoder(resp.Body).Decode(result) - - return err + return c.parseResponse(resp, result, client, req) +} + +func (c *Client) parseResponse(resp *http.Response, result interface{}, client http.Client, req IAPRequest) error { + // Read the body now so that we can unmarshal it twice + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + + err = json.Unmarshal(buf, &result) + if err != nil { + return err + } + + // https://developer.apple.com/library/content/technotes/tn2413/_index.html#//apple_ref/doc/uid/DTS40016228-CH1-RECEIPTURL + var r StatusResponse + err = json.Unmarshal(buf, &r) + if err != nil { + return err + } + if r.Status == 21007 { + b := new(bytes.Buffer) + json.NewEncoder(b).Encode(req) + resp, err := client.Post(c.SandboxURL, "application/json; charset=utf-8", b) + if err != nil { + return err + } + defer resp.Body.Close() + + return json.NewDecoder(resp.Body).Decode(result) + } + + return nil } diff --git a/appstore/validator_test.go b/appstore/validator_test.go index 52436c1..056e169 100644 --- a/appstore/validator_test.go +++ b/appstore/validator_test.go @@ -2,8 +2,12 @@ package appstore import ( "errors" + "io/ioutil" + "net/http" + "net/http/httptest" "os" "reflect" + "strings" "testing" "time" ) @@ -91,8 +95,9 @@ func TestHandleError(t *testing.T) { func TestNew(t *testing.T) { expected := Client{ - URL: "https://sandbox.itunes.apple.com/verifyReceipt", - TimeOut: time.Second * 5, + ProductionURL: ProductionURL, + SandboxURL: SandboxURL, + TimeOut: time.Second * 5, } actual := New() @@ -103,8 +108,9 @@ func TestNew(t *testing.T) { func TestNewWithEnvironment(t *testing.T) { expected := Client{ - URL: "https://buy.itunes.apple.com/verifyReceipt", - TimeOut: time.Second * 5, + ProductionURL: ProductionURL, + TimeOut: time.Second * 5, + SandboxURL: SandboxURL, } os.Setenv("IAP_ENVIRONMENT", "production") @@ -118,13 +124,13 @@ func TestNewWithEnvironment(t *testing.T) { func TestNewWithConfig(t *testing.T) { config := Config{ - IsProduction: true, - TimeOut: time.Second * 2, + TimeOut: time.Second * 2, } expected := Client{ - URL: "https://buy.itunes.apple.com/verifyReceipt", - TimeOut: time.Second * 2, + ProductionURL: ProductionURL, + SandboxURL: SandboxURL, + TimeOut: time.Second * 2, } actual := NewWithConfig(config) @@ -134,13 +140,12 @@ func TestNewWithConfig(t *testing.T) { } func TestNewWithConfigTimeout(t *testing.T) { - config := Config{ - IsProduction: true, - } + config := Config{} expected := Client{ - URL: "https://buy.itunes.apple.com/verifyReceipt", - TimeOut: time.Second * 5, + ProductionURL: ProductionURL, + SandboxURL: SandboxURL, + TimeOut: time.Second * 5, } actual := NewWithConfig(config) @@ -149,9 +154,9 @@ func TestNewWithConfigTimeout(t *testing.T) { } } -func TestVerify(t *testing.T) { +func TestVerifyTimeout(t *testing.T) { client := New() - client.TimeOut = time.Millisecond * 100 + client.TimeOut = time.Millisecond req := IAPRequest{ ReceiptData: "dummy data", @@ -161,13 +166,150 @@ func TestVerify(t *testing.T) { if err == nil { t.Errorf("error should be occurred because of timeout") } +} - client = New() - expected := &IAPResponse{ - Status: 21002, +func TestVerifyBadURL(t *testing.T) { + client := New() + client.ProductionURL = "127.0.0.1" + + req := IAPRequest{ + ReceiptData: "dummy data", } - client.Verify(req, result) - if !reflect.DeepEqual(result, expected) { - t.Errorf("got %v\nwant %v", result, expected) + result := &IAPResponse{} + err := client.Verify(req, result) + if err == nil { + t.Errorf("error should be occurred because the server is not real") } } + +func TestResponses(t *testing.T) { + req := IAPRequest{ + ReceiptData: "dummy data", + } + result := &IAPResponse{} + + type testCase struct { + testServer *httptest.Server + sandboxServ *httptest.Server + expected *IAPResponse + } + + testCases := []testCase{ + // VerifySandboxReceipt + { + testServer: httptest.NewServer(serverWithResponse(http.StatusOK, `{"status": 21007}`)), + sandboxServ: httptest.NewServer(serverWithResponse(http.StatusOK, `{"status": 0}`)), + expected: &IAPResponse{ + Status: 0, + }, + }, + // VerifyBadPayload + { + testServer: httptest.NewServer(serverWithResponse(http.StatusOK, `{"status": 21002}`)), + expected: &IAPResponse{ + Status: 21002, + }, + }, + // SuccessPayload + { + testServer: httptest.NewServer(serverWithResponse(http.StatusBadRequest, `{"status": 0}`)), + expected: &IAPResponse{ + Status: 0, + }, + }, + } + + client := New() + client.TimeOut = time.Second * 100 + client.SandboxURL = "localhost" + + for i, tc := range testCases { + defer tc.testServer.Close() + client.ProductionURL = tc.testServer.URL + if tc.sandboxServ != nil { + client.SandboxURL = tc.sandboxServ.URL + } + + err := client.Verify(req, result) + if err != nil { + t.Errorf("Test case %d - %s", i, err.Error()) + } + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("Test case %d - got %v\nwant %v", i, result, tc.expected) + } + } +} + +func TestErrors(t *testing.T) { + req := IAPRequest{ + ReceiptData: "dummy data", + } + result := &IAPResponse{} + + type testCase struct { + testServer *httptest.Server + } + + testCases := []testCase{ + // VerifySandboxReceiptFailure + { + testServer: httptest.NewServer(serverWithResponse(http.StatusOK, `{"status": 21007}`)), + }, + // VerifyBadResponse + { + testServer: httptest.NewServer(serverWithResponse(http.StatusInternalServerError, `qwerty!@#$%^`)), + }, + } + + client := New() + client.TimeOut = time.Second * 100 + client.SandboxURL = "localhost" + + for i, tc := range testCases { + defer tc.testServer.Close() + client.ProductionURL = tc.testServer.URL + + err := client.Verify(req, result) + if err == nil { + t.Errorf("Test case %d - expected error to be not nil since the sandbox is not responding", i) + } + } +} + +func TestCannotReadBody(t *testing.T) { + client := New() + testResponse := http.Response{Body: ioutil.NopCloser(errReader(0))} + + if client.parseResponse(&testResponse, IAPResponse{}, http.Client{}, IAPRequest{}) == nil { + t.Errorf("expected redirectToSandbox to fail to read the body") + } +} + +func TestCannotUnmarshalBody(t *testing.T) { + client := New() + testResponse := http.Response{Body: ioutil.NopCloser(strings.NewReader(`{"status": true}`))} + + if client.parseResponse(&testResponse, StatusResponse{}, http.Client{}, IAPRequest{}) == nil { + t.Errorf("expected redirectToSandbox to fail to unmarshal the data") + } +} + +type errReader int + +func (errReader) Read(p []byte) (n int, err error) { + return 0, errors.New("test error") +} + +func serverWithResponse(statusCode int, response string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if "POST" == r.Method { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(response)) + return + } else { + w.Write([]byte(`unsupported request`)) + } + + w.WriteHeader(statusCode) + }) +}