Merge pull request #45 from Timothylock/master

Have app store hit prod and then sandbox if 21007
This commit is contained in:
Junpei Tsuji
2018-02-08 17:14:11 +09:00
committed by GitHub
3 changed files with 216 additions and 42 deletions

View File

@@ -102,7 +102,7 @@ type (
// We defined each field by the current IAP response, but some fields are not mentioned // We defined each field by the current IAP response, but some fields are not mentioned
// in the following Apple's document; // in the following Apple's document;
// https://developer.apple.com/library/ios/releasenotes/General/ValidateAppStoreReceipt/Chapters/ReceiptFields.html // https://developer.apple.com/library/ios/releasenotes/General/ValidateAppStoreReceipt/Chapters/ReceiptFields.html
// If you get other types or fileds from the IAP response, you should use the struct you defined. // If you get other types or fields from the IAP response, you should use the struct you defined.
IAPResponse struct { IAPResponse struct {
Status int `json:"status"` Status int `json:"status"`
Environment string `json:"environment"` Environment string `json:"environment"`
@@ -112,4 +112,10 @@ type (
PendingRenewalInfo []PendingRenewalInfo `json:"pending_renewal_info"` PendingRenewalInfo []PendingRenewalInfo `json:"pending_renewal_info"`
IsRetryable bool `json:"is-retryable"` IsRetryable bool `json:"is-retryable"`
} }
// The HttpStatusResponse struct contains the status code returned by the store
// Used as a workaround to detect when to hit the production appstore or sandbox appstore regardless of receipt type
StatusResponse struct {
Status int `json:"status"`
}
) )

View File

@@ -4,8 +4,8 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"io/ioutil"
"net/http" "net/http"
"os"
"time" "time"
) )
@@ -18,7 +18,6 @@ const (
// Config is a configuration to initialize client // Config is a configuration to initialize client
type Config struct { type Config struct {
IsProduction bool
TimeOut time.Duration TimeOut time.Duration
} }
@@ -29,7 +28,8 @@ type IAPClient interface {
// Client implements IAPClient // Client implements IAPClient
type Client struct { type Client struct {
URL string ProductionURL string
SandboxURL string
TimeOut time.Duration TimeOut time.Duration
} }
@@ -79,12 +79,10 @@ func HandleError(status int) error {
// New creates a client object // New creates a client object
func New() Client { func New() Client {
client := Client{ client := Client{
URL: SandboxURL, ProductionURL: ProductionURL,
SandboxURL: SandboxURL,
TimeOut: time.Second * 5, TimeOut: time.Second * 5,
} }
if os.Getenv("IAP_ENVIRONMENT") == "production" {
client.URL = ProductionURL
}
return client return client
} }
@@ -95,12 +93,10 @@ func NewWithConfig(config Config) Client {
} }
client := Client{ client := Client{
URL: SandboxURL, ProductionURL: ProductionURL,
SandboxURL: SandboxURL,
TimeOut: config.TimeOut, TimeOut: config.TimeOut,
} }
if config.IsProduction {
client.URL = ProductionURL
}
return client return client
} }
@@ -114,13 +110,43 @@ func (c *Client) Verify(req IAPRequest, result interface{}) error {
b := new(bytes.Buffer) b := new(bytes.Buffer)
json.NewEncoder(b).Encode(req) json.NewEncoder(b).Encode(req)
resp, err := client.Post(c.URL, "application/json; charset=utf-8", b) resp, err := client.Post(c.ProductionURL, "application/json; charset=utf-8", b)
if err != nil {
return err
}
defer resp.Body.Close()
return c.parseResponse(resp, result, client, req)
}
func (c *Client) parseResponse(resp *http.Response, result interface{}, client http.Client, req IAPRequest) error {
// Read the body now so that we can unmarshal it twice
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
err = json.Unmarshal(buf, &result)
if err != nil {
return err
}
// https://developer.apple.com/library/content/technotes/tn2413/_index.html#//apple_ref/doc/uid/DTS40016228-CH1-RECEIPTURL
var r StatusResponse
err = json.Unmarshal(buf, &r)
if err != nil {
return err
}
if r.Status == 21007 {
b := new(bytes.Buffer)
json.NewEncoder(b).Encode(req)
resp, err := client.Post(c.SandboxURL, "application/json; charset=utf-8", b)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
err = json.NewDecoder(resp.Body).Decode(result) return json.NewDecoder(resp.Body).Decode(result)
}
return err return nil
} }

View File

@@ -2,8 +2,12 @@ package appstore
import ( import (
"errors" "errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"os" "os"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
) )
@@ -91,7 +95,8 @@ func TestHandleError(t *testing.T) {
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
expected := Client{ expected := Client{
URL: "https://sandbox.itunes.apple.com/verifyReceipt", ProductionURL: ProductionURL,
SandboxURL: SandboxURL,
TimeOut: time.Second * 5, TimeOut: time.Second * 5,
} }
@@ -103,8 +108,9 @@ func TestNew(t *testing.T) {
func TestNewWithEnvironment(t *testing.T) { func TestNewWithEnvironment(t *testing.T) {
expected := Client{ expected := Client{
URL: "https://buy.itunes.apple.com/verifyReceipt", ProductionURL: ProductionURL,
TimeOut: time.Second * 5, TimeOut: time.Second * 5,
SandboxURL: SandboxURL,
} }
os.Setenv("IAP_ENVIRONMENT", "production") os.Setenv("IAP_ENVIRONMENT", "production")
@@ -118,12 +124,12 @@ func TestNewWithEnvironment(t *testing.T) {
func TestNewWithConfig(t *testing.T) { func TestNewWithConfig(t *testing.T) {
config := Config{ config := Config{
IsProduction: true,
TimeOut: time.Second * 2, TimeOut: time.Second * 2,
} }
expected := Client{ expected := Client{
URL: "https://buy.itunes.apple.com/verifyReceipt", ProductionURL: ProductionURL,
SandboxURL: SandboxURL,
TimeOut: time.Second * 2, TimeOut: time.Second * 2,
} }
@@ -134,12 +140,11 @@ func TestNewWithConfig(t *testing.T) {
} }
func TestNewWithConfigTimeout(t *testing.T) { func TestNewWithConfigTimeout(t *testing.T) {
config := Config{ config := Config{}
IsProduction: true,
}
expected := Client{ expected := Client{
URL: "https://buy.itunes.apple.com/verifyReceipt", ProductionURL: ProductionURL,
SandboxURL: SandboxURL,
TimeOut: time.Second * 5, TimeOut: time.Second * 5,
} }
@@ -149,9 +154,9 @@ func TestNewWithConfigTimeout(t *testing.T) {
} }
} }
func TestVerify(t *testing.T) { func TestVerifyTimeout(t *testing.T) {
client := New() client := New()
client.TimeOut = time.Millisecond * 100 client.TimeOut = time.Millisecond
req := IAPRequest{ req := IAPRequest{
ReceiptData: "dummy data", ReceiptData: "dummy data",
@@ -161,13 +166,150 @@ func TestVerify(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("error should be occurred because of timeout") t.Errorf("error should be occurred because of timeout")
} }
}
client = New() func TestVerifyBadURL(t *testing.T) {
expected := &IAPResponse{ client := New()
Status: 21002, client.ProductionURL = "127.0.0.1"
req := IAPRequest{
ReceiptData: "dummy data",
} }
client.Verify(req, result) result := &IAPResponse{}
if !reflect.DeepEqual(result, expected) { err := client.Verify(req, result)
t.Errorf("got %v\nwant %v", result, expected) if err == nil {
t.Errorf("error should be occurred because the server is not real")
} }
} }
func TestResponses(t *testing.T) {
req := IAPRequest{
ReceiptData: "dummy data",
}
result := &IAPResponse{}
type testCase struct {
testServer *httptest.Server
sandboxServ *httptest.Server
expected *IAPResponse
}
testCases := []testCase{
// VerifySandboxReceipt
{
testServer: httptest.NewServer(serverWithResponse(http.StatusOK, `{"status": 21007}`)),
sandboxServ: httptest.NewServer(serverWithResponse(http.StatusOK, `{"status": 0}`)),
expected: &IAPResponse{
Status: 0,
},
},
// VerifyBadPayload
{
testServer: httptest.NewServer(serverWithResponse(http.StatusOK, `{"status": 21002}`)),
expected: &IAPResponse{
Status: 21002,
},
},
// SuccessPayload
{
testServer: httptest.NewServer(serverWithResponse(http.StatusBadRequest, `{"status": 0}`)),
expected: &IAPResponse{
Status: 0,
},
},
}
client := New()
client.TimeOut = time.Second * 100
client.SandboxURL = "localhost"
for i, tc := range testCases {
defer tc.testServer.Close()
client.ProductionURL = tc.testServer.URL
if tc.sandboxServ != nil {
client.SandboxURL = tc.sandboxServ.URL
}
err := client.Verify(req, result)
if err != nil {
t.Errorf("Test case %d - %s", i, err.Error())
}
if !reflect.DeepEqual(result, tc.expected) {
t.Errorf("Test case %d - got %v\nwant %v", i, result, tc.expected)
}
}
}
func TestErrors(t *testing.T) {
req := IAPRequest{
ReceiptData: "dummy data",
}
result := &IAPResponse{}
type testCase struct {
testServer *httptest.Server
}
testCases := []testCase{
// VerifySandboxReceiptFailure
{
testServer: httptest.NewServer(serverWithResponse(http.StatusOK, `{"status": 21007}`)),
},
// VerifyBadResponse
{
testServer: httptest.NewServer(serverWithResponse(http.StatusInternalServerError, `qwerty!@#$%^`)),
},
}
client := New()
client.TimeOut = time.Second * 100
client.SandboxURL = "localhost"
for i, tc := range testCases {
defer tc.testServer.Close()
client.ProductionURL = tc.testServer.URL
err := client.Verify(req, result)
if err == nil {
t.Errorf("Test case %d - expected error to be not nil since the sandbox is not responding", i)
}
}
}
func TestCannotReadBody(t *testing.T) {
client := New()
testResponse := http.Response{Body: ioutil.NopCloser(errReader(0))}
if client.parseResponse(&testResponse, IAPResponse{}, http.Client{}, IAPRequest{}) == nil {
t.Errorf("expected redirectToSandbox to fail to read the body")
}
}
func TestCannotUnmarshalBody(t *testing.T) {
client := New()
testResponse := http.Response{Body: ioutil.NopCloser(strings.NewReader(`{"status": true}`))}
if client.parseResponse(&testResponse, StatusResponse{}, http.Client{}, IAPRequest{}) == nil {
t.Errorf("expected redirectToSandbox to fail to unmarshal the data")
}
}
type errReader int
func (errReader) Read(p []byte) (n int, err error) {
return 0, errors.New("test error")
}
func serverWithResponse(statusCode int, response string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if "POST" == r.Method {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
return
} else {
w.Write([]byte(`unsupported request`))
}
w.WriteHeader(statusCode)
})
}