From 8c93c4b714d7c0487eb3962666605311a4a87cdd Mon Sep 17 00:00:00 2001 From: Junpei Tsuji Date: Fri, 18 May 2018 11:20:58 +0900 Subject: [PATCH] Support context for amazon app store --- amazon/validator.go | 13 ++++++++++--- amazon/validator_test.go | 9 ++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/amazon/validator.go b/amazon/validator.go index c2d2d6e..1d0b3fe 100644 --- a/amazon/validator.go +++ b/amazon/validator.go @@ -1,6 +1,7 @@ package amazon import ( + "context" "encoding/json" "errors" "fmt" @@ -49,7 +50,7 @@ type IAPResponseError struct { // IAPClient is an interface to call validation API in Amazon App Store type IAPClient interface { - Verify(string, string) (IAPResponse, error) + Verify(context.Context, string, string) (IAPResponse, error) } // Client implements IAPClient @@ -91,13 +92,19 @@ func NewWithConfig(config Config) Client { } // Verify sends receipts and gets validation result -func (c Client) Verify(userID string, receiptID string) (IAPResponse, error) { +func (c Client) Verify(ctx context.Context, userID string, receiptID string) (IAPResponse, error) { result := IAPResponse{} url := fmt.Sprintf("%v/version/1.0/verifyReceiptId/developer/%v/user/%v/receiptId/%v", c.URL, c.Secret, userID, receiptID) client := http.Client{ Timeout: c.TimeOut, } - resp, err := client.Get(url) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return result, err + } + req = req.WithContext(ctx) + + resp, err := client.Do(req) if err != nil { return result, fmt.Errorf("%v", err) } diff --git a/amazon/validator_test.go b/amazon/validator_test.go index 1d76197..681bc0d 100644 --- a/amazon/validator_test.go +++ b/amazon/validator_test.go @@ -1,6 +1,7 @@ package amazon import ( + "context" "errors" "fmt" "net/http" @@ -25,6 +26,7 @@ func TestHandle497Error(t *testing.T) { // status 400 expected = errors.New("Purchase token/app user mismatch") _, actual = client.Verify( + context.Background(), "99FD_DL23EMhrOGDnur9-ulvqomrSg6qyLPSD3CFE=", "q1YqVrJSSs7P1UvMTazKz9PLTCwoTswtyEktM9JLrShIzCvOzM-LL04tiTdW0lFKASo2NDEwMjCwMDM2MTC0AIqVAsUsLd1c4l18jIxdfTOK_N1d8kqLLHVLc8oK83OLgtPNCit9AoJdjJ3dXG2BGkqUrAxrAQ", ) @@ -47,6 +49,7 @@ func TestHandle400Error(t *testing.T) { // status 400 expected = errors.New("Failed to parse receipt Id") _, actual = client.Verify( + context.Background(), "99FD_DL23EMhrOGDnur9-ulvqomrSg6qyLPSD3CFE=", "q1YqVrJSSs7P1UvMTazKz9PLTCwoTswtyEktM9JLrShIzCvOzM-LL04tiTdW0lFKASo2NDEwMjCwMDM2MTC0AIqVAsUsLd1c4l18jIxdfTOK_N1d8kqLLHVLc8oK83OLgtPNCit9AoJdjJ3dXG2BGkqUrAxrAQ", ) @@ -56,7 +59,6 @@ func TestHandle400Error(t *testing.T) { } func TestNew(t *testing.T) { - t.Parallel() expected := Client{ URL: SandboxURL, TimeOut: time.Second * 5, @@ -70,7 +72,6 @@ func TestNew(t *testing.T) { } func TestNewWithEnvironment(t *testing.T) { - t.Parallel() expected := Client{ URL: ProductionURL, TimeOut: time.Second * 5, @@ -143,6 +144,7 @@ func TestVerify(t *testing.T) { } actual, _ := client.Verify( + context.Background(), "99FD_DL23EMhrOGDnur9-ulvqomrSg6qyLPSD3CFE=", "q1YqVrJSSs7P1UvMTazKz9PLTCwoTswtyEktM9JLrShIzCvOzM-LL04tiTdW0lFKASo2NDEwMjCwMDM2MTC0AIqVAsUsLd1c4l18jIxdfTOK_N1d8kqLLHVLc8oK83OLgtPNCit9AoJdjJ3dXG2BGkqUrAxrAQ", ) @@ -158,7 +160,8 @@ func TestVerifyTimeout(t *testing.T) { defer server.Close() expected := errors.New("") - _, actual := client.Verify("timeout", "timeout") + ctx := context.Background() + _, actual := client.Verify(ctx, "timeout", "timeout") if !reflect.DeepEqual(reflect.TypeOf(actual), reflect.TypeOf(expected)) { t.Errorf("got %v\nwant %v", actual, expected) }