1
0
mirror of https://github.com/kataras/iris.git synced 2025-12-17 09:57:01 +00:00

add new errors.Intercept package-level function

This commit is contained in:
Gerasimos (Makis) Maropoulos
2024-01-20 00:33:59 +02:00
parent fabbc271b9
commit 4e3c242044
3 changed files with 149 additions and 38 deletions

View File

@@ -141,13 +141,13 @@ type ResponseOnlyErrorFunc[T any] interface {
func(stdContext.Context, T) error
}
// ContextValidatorFunc is a function which takes a context and a generic type T and returns an error.
// ContextRequestFunc is a function which takes a context and a generic type T and returns an error.
// It is used to validate the context before calling a service function.
//
// See Validation package-level function.
type ContextValidatorFunc[T any] func(*context.Context, T) error
type ContextRequestFunc[T any] func(*context.Context, T) error
const contextValidatorFuncKey = "iris.errors.ContextValidatorFunc"
const contextRequestHandlerFuncKey = "iris.errors.ContextRequestHandler"
// Validation adds a context validator function to the context.
// It returns a middleware which can be used to validate the context before calling a service function.
@@ -164,31 +164,31 @@ const contextValidatorFuncKey = "iris.errors.ContextValidatorFunc"
// validation.Slice("hobbies", r.Hobbies).Length(1, 10),
// )
// }
func Validation[T any](validators ...ContextValidatorFunc[T]) context.Handler {
validator := joinContextValidators[T](validators)
func Validation[T any](validators ...ContextRequestFunc[T]) context.Handler {
validator := joinContextRequestFuncs[T](validators)
return func(ctx *context.Context) {
ctx.Values().Set(contextValidatorFuncKey, validator)
ctx.Values().Set(contextRequestHandlerFuncKey, validator)
ctx.Next()
}
}
func joinContextValidators[T any](validators []ContextValidatorFunc[T]) ContextValidatorFunc[T] {
if len(validators) == 0 || validators[0] == nil {
panic("at least one validator is required")
func joinContextRequestFuncs[T any](requestHandlerFuncs []ContextRequestFunc[T]) ContextRequestFunc[T] {
if len(requestHandlerFuncs) == 0 || requestHandlerFuncs[0] == nil {
panic("at least one context request handler function is required")
}
if len(validators) == 1 {
return validators[0]
if len(requestHandlerFuncs) == 1 {
return requestHandlerFuncs[0]
}
return func(ctx *context.Context, req T) error {
for _, validator := range validators {
if validator == nil {
for _, handler := range requestHandlerFuncs {
if handler == nil {
continue
}
if err := validator(ctx, req); err != nil {
if err := handler(ctx, req); err != nil {
return err
}
}
@@ -197,38 +197,102 @@ func joinContextValidators[T any](validators []ContextValidatorFunc[T]) ContextV
}
}
// ContextValidator is an interface which can be implemented by a request payload struct
// RequestHandler is an interface which can be implemented by a request payload struct
// in order to validate the context before calling a service function.
type ContextValidator interface {
ValidateContext(*context.Context) error
type RequestHandler interface {
HandleRequest(*context.Context) error
}
func validateContext[T any](ctx *context.Context, req T) bool {
func validateRequest[T any](ctx *context.Context, req T) bool {
var err error
// Always run the request's validator first,
// so dynamic validators can be customized per path and method.
if contextValidator, ok := any(&req).(ContextValidator); ok {
err = contextValidator.ValidateContext(ctx)
if contextRequestHandler, ok := any(&req).(RequestHandler); ok {
err = contextRequestHandler.HandleRequest(ctx)
}
if err == nil {
if v := ctx.Values().Get(contextValidatorFuncKey); v != nil {
if contextValidatorFunc, ok := v.(ContextValidatorFunc[T]); ok {
err = contextValidatorFunc(ctx, req)
} else if contextValidatorFunc, ok := v.(ContextValidatorFunc[*T]); ok { // or a pointer of T.
err = contextValidatorFunc(ctx, &req)
if v := ctx.Values().Get(contextRequestHandlerFuncKey); v != nil {
if contextRequestHandlerFunc, ok := v.(ContextRequestFunc[T]); ok && contextRequestHandlerFunc != nil {
err = contextRequestHandlerFunc(ctx, req)
} else if contextRequestHandlerFunc, ok := v.(ContextRequestFunc[*T]); ok && contextRequestHandlerFunc != nil { // or a pointer of T.
err = contextRequestHandlerFunc(ctx, &req)
}
}
}
if err != nil {
if HandleError(ctx, err) {
return false
return err == nil || !HandleError(ctx, err)
}
// ResponseHandler is an interface which can be implemented by a request payload struct
// in order to handle a response before sending it to the client.
type ResponseHandler[R any, RPointer *R] interface {
HandleResponse(ctx *context.Context, response RPointer) error
}
// ContextResponseFunc is a function which takes a context, a generic type T and a generic type R and returns an error.
type ContextResponseFunc[T, R any, RPointer *R] func(*context.Context, T, RPointer) error
const contextResponseHandlerFuncKey = "iris.errors.ContextResponseHandler"
func validateResponse[T, R any, RPointer *R](ctx *context.Context, req T, resp RPointer) bool {
var err error
if contextResponseHandler, ok := any(&req).(ResponseHandler[R, RPointer]); ok {
err = contextResponseHandler.HandleResponse(ctx, resp)
}
if err == nil {
if v := ctx.Values().Get(contextResponseHandlerFuncKey); v != nil {
if contextResponseHandlerFunc, ok := v.(ContextResponseFunc[T, R, RPointer]); ok && contextResponseHandlerFunc != nil {
err = contextResponseHandlerFunc(ctx, req, resp)
} else if contextResponseHandlerFunc, ok := v.(ContextResponseFunc[*T, R, RPointer]); ok && contextResponseHandlerFunc != nil {
err = contextResponseHandlerFunc(ctx, &req, resp)
}
}
}
return true
return err == nil || !HandleError(ctx, err)
}
// Intercept adds a context response handler function to the context.
// It returns a middleware which can be used to intercept the response before sending it to the client.
//
// Example Code:
//
// app.Post("/", errors.Intercept(func(ctx iris.Context, req *CreateRequest, resp *CreateResponse) error{ ... }), errors.CreateHandler(service.Create))
func Intercept[T, R any, RPointer *R](responseHandlers ...ContextResponseFunc[T, R, RPointer]) context.Handler {
responseHandler := joinContextResponseFuncs[T, R, RPointer](responseHandlers)
return func(ctx *context.Context) {
ctx.Values().Set(contextResponseHandlerFuncKey, responseHandler)
ctx.Next()
}
}
func joinContextResponseFuncs[T, R any, RPointer *R](responseHandlerFuncs []ContextResponseFunc[T, R, RPointer]) ContextResponseFunc[T, R, RPointer] {
if len(responseHandlerFuncs) == 0 || responseHandlerFuncs[0] == nil {
panic("at least one context response handler function is required")
}
if len(responseHandlerFuncs) == 1 {
return responseHandlerFuncs[0]
}
return func(ctx *context.Context, req T, resp RPointer) error {
for _, handler := range responseHandlerFuncs {
if handler == nil {
continue
}
if err := handler(ctx, req, resp); err != nil {
return err
}
}
return nil
}
}
func bindResponse[T, R any, F ResponseFunc[T, R]](ctx *context.Context, fn F, fnInput ...T) (R, bool) {
@@ -247,12 +311,18 @@ func bindResponse[T, R any, F ResponseFunc[T, R]](ctx *context.Context, fn F, fn
panic("invalid number of arguments")
}
if !validateContext(ctx, req) {
if !validateRequest(ctx, req) {
var resp R
return resp, false
}
resp, err := fn(ctx, req)
if err == nil {
if !validateResponse(ctx, req, &resp) {
return resp, false
}
}
return resp, !HandleError(ctx, err)
}
@@ -372,7 +442,7 @@ func List[T, R any, C constraints.Integer | constraints.Float, F ListResponseFun
return false
}
if !validateContext(ctx, filter) {
if !validateRequest(ctx, filter) {
return false
}
@@ -383,7 +453,11 @@ func List[T, R any, C constraints.Integer | constraints.Float, F ListResponseFun
}
resp := pagination.NewList(items, int64(totalCount), filter, listOpts)
return Handle(ctx, resp, nil)
if !validateResponse(ctx, filter, resp) {
return false
}
return Handle(ctx, resp, err)
}
// ListHandler handles a generic response and error from a service paginated call and sends a JSON response to the client.