diff --git a/appstore/validator.go b/appstore/validator.go index 31c9930..02e4922 100644 --- a/appstore/validator.go +++ b/appstore/validator.go @@ -31,6 +31,7 @@ type IAPClient interface { type Client struct { URL string TimeOut time.Duration + SandboxURL string } // HandleError returns error message by status code @@ -81,6 +82,7 @@ func New() Client { client := Client{ URL: SandboxURL, TimeOut: time.Second * 5, + SandboxURL: SandboxURL, } if os.Getenv("IAP_ENVIRONMENT") == "production" { client.URL = ProductionURL @@ -121,6 +123,26 @@ func (c *Client) Verify(req IAPRequest, result interface{}) error { defer resp.Body.Close() err = json.NewDecoder(resp.Body).Decode(result) + if err != nil { + return err + } - return err + // Always verify your receipt first with the production URL; proceed to verify with the sandbox URL if you receive + // a 21007 status code + // + // https://developer.apple.com/library/content/technotes/tn2413/_index.html#//apple_ref/doc/uid/DTS40016228-CH1-RECEIPTURL + r, ok := result.(*IAPResponse) + if ok && r.Status == 21007 { + b = new(bytes.Buffer) + json.NewEncoder(b).Encode(req) + resp, err := client.Post(c.SandboxURL, "application/json; charset=utf-8", b) + if err != nil { + return err + } + defer resp.Body.Close() + + return json.NewDecoder(resp.Body).Decode(result) + } + + return nil } diff --git a/appstore/validator_test.go b/appstore/validator_test.go index 52436c1..aa78990 100644 --- a/appstore/validator_test.go +++ b/appstore/validator_test.go @@ -6,6 +6,8 @@ import ( "reflect" "testing" "time" + "net/http/httptest" + "net/http" ) func TestHandleError(t *testing.T) { @@ -91,8 +93,9 @@ func TestHandleError(t *testing.T) { func TestNew(t *testing.T) { expected := Client{ - URL: "https://sandbox.itunes.apple.com/verifyReceipt", + URL: SandboxURL, TimeOut: time.Second * 5, + SandboxURL:SandboxURL, } actual := New() @@ -103,8 +106,9 @@ func TestNew(t *testing.T) { func TestNewWithEnvironment(t *testing.T) { expected := Client{ - URL: "https://buy.itunes.apple.com/verifyReceipt", + URL: ProductionURL, TimeOut: time.Second * 5, + SandboxURL:SandboxURL, } os.Setenv("IAP_ENVIRONMENT", "production") @@ -123,7 +127,7 @@ func TestNewWithConfig(t *testing.T) { } expected := Client{ - URL: "https://buy.itunes.apple.com/verifyReceipt", + URL: ProductionURL, TimeOut: time.Second * 2, } @@ -139,7 +143,7 @@ func TestNewWithConfigTimeout(t *testing.T) { } expected := Client{ - URL: "https://buy.itunes.apple.com/verifyReceipt", + URL: ProductionURL, TimeOut: time.Second * 5, } @@ -149,9 +153,9 @@ func TestNewWithConfigTimeout(t *testing.T) { } } -func TestVerify(t *testing.T) { +func TestVerifyTimeout(t *testing.T) { client := New() - client.TimeOut = time.Millisecond * 100 + client.TimeOut = time.Millisecond req := IAPRequest{ ReceiptData: "dummy data", @@ -161,13 +165,172 @@ func TestVerify(t *testing.T) { if err == nil { t.Errorf("error should be occurred because of timeout") } +} - client = New() +func TestVerifyBadURL(t *testing.T) { + client := New() + client.URL = "127.0.0.1" + + req := IAPRequest{ + ReceiptData: "dummy data", + } + result := &IAPResponse{} + err := client.Verify(req, result) + if err == nil { + t.Errorf("error should be occurred because the server is not real") + } +} + +func TestVerifyBadPayload(t *testing.T) { + s := httptest.NewServer(badPayload()) + defer s.Close() + + client := New() + client.URL = s.URL expected := &IAPResponse{ Status: 21002, } - client.Verify(req, result) + req := IAPRequest{ + ReceiptData: "dummy data", + } + result := &IAPResponse{} + + err := client.Verify(req, result) + if err != nil { + t.Errorf("got error %s", err) + } if !reflect.DeepEqual(result, expected) { t.Errorf("got %v\nwant %v", result, expected) } } + +func TestVerifyBadResponse(t *testing.T) { + s := httptest.NewServer(invalidResponse()) + defer s.Close() + + client := New() + client.URL = s.URL + req := IAPRequest{ + ReceiptData: "dummy data", + } + result := &IAPResponse{} + + err := client.Verify(req, result) + if err == nil { + t.Errorf("expected an error because Verify could not unmarshal server response") + } +} + +func TestVerifySandboxReceipt(t *testing.T) { + s := httptest.NewServer(redirectToSandbox()) + defer s.Close() + + sandboxServ := httptest.NewServer(sandboxSuccess()) + defer sandboxServ.Close() + + client := New() + client.URL = s.URL + client.TimeOut = time.Second * 100 + client.SandboxURL = sandboxServ.URL + + expected := &IAPResponse{ + Status: 0, + } + req := IAPRequest{ + ReceiptData: "dummy data", + } + result := &IAPResponse{} + + err := client.Verify(req, result) + if err != nil { + t.Errorf("got error %s", err) + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("got %v\nwant %v", result, expected) + } +} + +func TestVerifySandboxReceiptFailure(t *testing.T) { + s := httptest.NewServer(redirectToSandbox()) + defer s.Close() + + sandboxServ := httptest.NewServer(sandboxTimeout()) + defer sandboxServ.Close() + + client := New() + client.URL = s.URL + client.TimeOut = time.Second * 100 + client.SandboxURL = sandboxServ.URL + + req := IAPRequest{ + ReceiptData: "dummy data", + } + result := &IAPResponse{} + + err := client.Verify(req, result) + if err == nil { + t.Errorf("expected error to be not nil since the sandbox is not responding") + } +} + +func badPayload() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if "POST" == r.Method { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status": 21002}`)) + return + } else { + w.Write([]byte(`unsupported request`)) + } + + w.WriteHeader(http.StatusBadRequest) + }) +} + +func redirectToSandbox() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if "POST" == r.Method { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status": 21007}`)) + return + } else { + w.Write([]byte(`unsupported request`)) + } + + w.WriteHeader(http.StatusOK) + }) +} + +func sandboxSuccess() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if "POST" == r.Method { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status": 0}`)) + return + } else { + w.Write([]byte(`unsupported request`)) + } + + w.WriteHeader(http.StatusOK) + }) +} + +func sandboxTimeout() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Do nothing and just dont return anything either + }) +} + +func invalidResponse() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if "POST" == r.Method { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`qwerty!@#$%^`)) + return + } else { + w.Write([]byte(`unsupported request`)) + } + + w.WriteHeader(http.StatusOK) + }) +}