diff --git a/appstore/validator.go b/appstore/validator.go index 1a8c2fd..1e4af86 100644 --- a/appstore/validator.go +++ b/appstore/validator.go @@ -2,11 +2,11 @@ package appstore import ( "bytes" + "context" "encoding/json" "errors" "io/ioutil" "net/http" - "time" ) const ( @@ -18,14 +18,9 @@ const ( ContentType string = "application/json; charset=utf-8" ) -// Config is a configuration to initialize client -type Config struct { - TimeOut time.Duration -} - // IAPClient is an interface to call validation API in App Store type IAPClient interface { - Verify(IAPRequest, interface{}) error + Verify(ctx context.Context, reqBody IAPRequest, resp interface{}) error } // Client implements IAPClient @@ -98,7 +93,7 @@ func NewWithClient(client *http.Client) *Client { } // Verify sends receipts and gets validation result -func (c *Client) Verify(reqBody IAPRequest, result interface{}) error { +func (c *Client) Verify(ctx context.Context, reqBody IAPRequest, result interface{}) error { b := new(bytes.Buffer) json.NewEncoder(b).Encode(reqBody) @@ -107,15 +102,16 @@ func (c *Client) Verify(reqBody IAPRequest, result interface{}) error { return err } req.Header.Set("Content-Type", ContentType) + req = req.WithContext(ctx) resp, err := c.httpCli.Do(req) if err != nil { return err } defer resp.Body.Close() - return c.parseResponse(resp, result, reqBody) + return c.parseResponse(resp, result, ctx, reqBody) } -func (c *Client) parseResponse(resp *http.Response, result interface{}, reqBody IAPRequest) error { +func (c *Client) parseResponse(resp *http.Response, result interface{}, ctx context.Context, reqBody IAPRequest) error { // Read the body now so that we can unmarshal it twice buf, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -142,6 +138,7 @@ func (c *Client) parseResponse(resp *http.Response, result interface{}, reqBody return err } req.Header.Set("Content-Type", ContentType) + req = req.WithContext(ctx) resp, err := c.httpCli.Do(req) if err != nil { return err diff --git a/appstore/validator_test.go b/appstore/validator_test.go index e7c8972..408bbc2 100644 --- a/appstore/validator_test.go +++ b/appstore/validator_test.go @@ -1,6 +1,7 @@ package appstore import ( + "context" "errors" "io/ioutil" "net/http" @@ -128,10 +129,31 @@ func TestVerifyTimeout(t *testing.T) { ReceiptData: "dummy data", } result := &IAPResponse{} - err := client.Verify(req, result) + ctx := context.Background() + err := client.Verify(ctx, req, result) if err == nil { t.Errorf("error should be occurred because of timeout") } + t.Log(err) +} + +func TestVerifyWithCancel(t *testing.T) { + client := New() + + req := IAPRequest{ + ReceiptData: "dummy data", + } + result := &IAPResponse{} + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancelFunc() + }() + err := client.Verify(ctx, req, result) + if err == nil { + t.Errorf("error should be occurred because of context cancel") + } + t.Log(err) } func TestVerifyBadURL(t *testing.T) { @@ -142,7 +164,8 @@ func TestVerifyBadURL(t *testing.T) { ReceiptData: "dummy data", } result := &IAPResponse{} - err := client.Verify(req, result) + ctx := context.Background() + err := client.Verify(ctx, req, result) if err == nil { t.Errorf("error should be occurred because the server is not real") } @@ -195,7 +218,8 @@ func TestResponses(t *testing.T) { client.SandboxURL = tc.sandboxServ.URL } - err := client.Verify(req, result) + ctx := context.Background() + err := client.Verify(ctx, req, result) if err != nil { t.Errorf("Test case %d - %s", i, err.Error()) } @@ -233,7 +257,8 @@ func TestErrors(t *testing.T) { defer tc.testServer.Close() client.ProductionURL = tc.testServer.URL - err := client.Verify(req, result) + ctx := context.Background() + err := client.Verify(ctx, req, result) if err == nil { t.Errorf("Test case %d - expected error to be not nil since the sandbox is not responding", i) } @@ -244,7 +269,8 @@ func TestCannotReadBody(t *testing.T) { client := New() testResponse := http.Response{Body: ioutil.NopCloser(errReader(0))} - if client.parseResponse(&testResponse, IAPResponse{}, IAPRequest{}) == nil { + ctx := context.Background() + if client.parseResponse(&testResponse, IAPResponse{}, ctx, IAPRequest{}) == nil { t.Errorf("expected redirectToSandbox to fail to read the body") } } @@ -253,7 +279,8 @@ func TestCannotUnmarshalBody(t *testing.T) { client := New() testResponse := http.Response{Body: ioutil.NopCloser(strings.NewReader(`{"status": true}`))} - if client.parseResponse(&testResponse, StatusResponse{}, IAPRequest{}) == nil { + ctx := context.Background() + if client.parseResponse(&testResponse, StatusResponse{}, ctx, IAPRequest{}) == nil { t.Errorf("expected redirectToSandbox to fail to unmarshal the data") } }