mirror of
https://github.com/kataras/iris.git
synced 2026-01-01 17:27:03 +00:00
New JWT features and changes (examples updated). Improvements on the Context User and Private Error features
TODO: Write the new e-book JWT section and the HISTORY entry of the chnages and add a simple example on site docs
This commit is contained in:
@@ -32,7 +32,6 @@ Most of the experimental handlers are ported to work with _iris_'s handler form,
|
||||
| [casbin](https://github.com/iris-contrib/middleware/tree/master/casbin)| An authorization library that supports access control models like ACL, RBAC, ABAC | [iris-contrib/middleware/casbin/_examples](https://github.com/iris-contrib/middleware/tree/master/casbin/_examples) |
|
||||
| [sentry-go (ex. raven)](https://github.com/getsentry/sentry-go/tree/master/iris)| Sentry client in Go | [sentry-go/example/iris](https://github.com/getsentry/sentry-go/blob/master/example/iris/main.go) | <!-- raven was deprecated by its company, the successor is sentry-go, they contain an Iris middleware. -->
|
||||
| [csrf](https://github.com/iris-contrib/middleware/tree/master/csrf)| Cross-Site Request Forgery Protection | [iris-contrib/middleware/csrf/_example](https://github.com/iris-contrib/middleware/blob/master/csrf/_example/main.go) |
|
||||
| [go-i18n](https://github.com/iris-contrib/middleware/tree/master/go-i18n)| i18n Iris Loader for nicksnyder/go-i18n | [iris-contrib/middleware/go-i18n/_example](https://github.com/iris-contrib/middleware/blob/master/go-i18n/_example/main.go) |
|
||||
| [throttler](https://github.com/iris-contrib/middleware/tree/master/throttler)| Rate limiting access to HTTP endpoints | [iris-contrib/middleware/throttler/_example](https://github.com/iris-contrib/middleware/blob/master/throttler/_example/main.go) |
|
||||
|
||||
Third-Party Handlers
|
||||
|
||||
@@ -2,6 +2,7 @@ package jwt
|
||||
|
||||
import (
|
||||
"github.com/square/go-jose/v3"
|
||||
"github.com/square/go-jose/v3/json"
|
||||
"github.com/square/go-jose/v3/jwt"
|
||||
)
|
||||
|
||||
@@ -14,11 +15,19 @@ type (
|
||||
// epoch, including leap seconds. Non-integer values can be represented
|
||||
// in the serialized format, but we round to the nearest second.
|
||||
NumericDate = jwt.NumericDate
|
||||
// Expected defines values used for protected claims validation.
|
||||
// If field has zero value then validation is skipped.
|
||||
Expected = jwt.Expected
|
||||
)
|
||||
|
||||
var (
|
||||
// NewNumericDate constructs NumericDate from time.Time value.
|
||||
NewNumericDate = jwt.NewNumericDate
|
||||
// Marshal returns the JSON encoding of v.
|
||||
Marshal = json.Marshal
|
||||
// Unmarshal parses the JSON-encoded data and stores the result
|
||||
// in the value pointed to by v.
|
||||
Unmarshal = json.Unmarshal
|
||||
)
|
||||
|
||||
type (
|
||||
|
||||
131
middleware/jwt/blocklist.go
Normal file
131
middleware/jwt/blocklist.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
stdContext "context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Blocklist is an in-memory storage of tokens that should be
|
||||
// immediately invalidated by the server-side.
|
||||
// The most common way to invalidate a token, e.g. on user logout,
|
||||
// is to make the client-side remove the token itself.
|
||||
// However, if someone else has access to that token,
|
||||
// it could be still valid for new requests until its expiration.
|
||||
type Blocklist struct {
|
||||
entries map[string]time.Time // key = token | value = expiration time (to remove expired).
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBlocklist returns a new up and running in-memory Token Blocklist.
|
||||
// The returned value can be set to the JWT instance's Blocklist field.
|
||||
func NewBlocklist(gcEvery time.Duration) *Blocklist {
|
||||
return NewBlocklistContext(stdContext.Background(), gcEvery)
|
||||
}
|
||||
|
||||
// NewBlocklistContext same as `NewBlocklist`
|
||||
// but it also accepts a standard Go Context for GC cancelation.
|
||||
func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) *Blocklist {
|
||||
b := &Blocklist{
|
||||
entries: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
if gcEvery > 0 {
|
||||
go b.runGC(ctx, gcEvery)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Set upserts a given token, with its expiration time,
|
||||
// to the block list, so it's immediately invalidated by the server-side.
|
||||
func (b *Blocklist) Set(token string, expiresAt time.Time) {
|
||||
b.mu.Lock()
|
||||
b.entries[token] = expiresAt
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
// Del removes a "token" from the block list.
|
||||
func (b *Blocklist) Del(token string) {
|
||||
b.mu.Lock()
|
||||
delete(b.entries, token)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
// Count returns the total amount of blocked tokens.
|
||||
func (b *Blocklist) Count() int {
|
||||
b.mu.RLock()
|
||||
n := len(b.entries)
|
||||
b.mu.RUnlock()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Has reports whether the given "token" is blocked by the server.
|
||||
// This method is called before the token verification,
|
||||
// so even if was expired it is removed from the block list.
|
||||
func (b *Blocklist) Has(token string) bool {
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
_, ok := b.entries[token]
|
||||
b.mu.RUnlock()
|
||||
|
||||
/* No, the Blocklist will be used after the token is parsed,
|
||||
there we can call the Del method if err was ErrExpired.
|
||||
if ok {
|
||||
// As an extra step, to keep the list size as small as possible,
|
||||
// we delete it from list if it's going to be expired
|
||||
// ~in the next `blockedExpireLeeway` seconds.~
|
||||
// - Let's keep it easier for testing by not setting a leeway.
|
||||
// if time.Now().Add(blockedExpireLeeway).After(expiresAt) {
|
||||
if time.Now().After(expiresAt) {
|
||||
b.Del(token)
|
||||
}
|
||||
}*/
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// GC iterates over all entries and removes expired tokens.
|
||||
// This method is helpful to keep the list size small.
|
||||
// Depending on the application, the GC method can be scheduled
|
||||
// to called every half or a whole hour.
|
||||
// A good value for a GC cron task is the JWT's max age (default).
|
||||
func (b *Blocklist) GC() int {
|
||||
now := time.Now()
|
||||
var markedForDeletion []string
|
||||
|
||||
b.mu.RLock()
|
||||
for token, expiresAt := range b.entries {
|
||||
if now.After(expiresAt) {
|
||||
markedForDeletion = append(markedForDeletion, token)
|
||||
}
|
||||
}
|
||||
b.mu.RUnlock()
|
||||
|
||||
n := len(markedForDeletion)
|
||||
if n > 0 {
|
||||
for _, token := range markedForDeletion {
|
||||
b.Del(token)
|
||||
}
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (b *Blocklist) runGC(ctx stdContext.Context, every time.Duration) {
|
||||
t := time.NewTicker(every)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return
|
||||
case <-t.C:
|
||||
b.GC()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package jwt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -85,6 +83,9 @@ func FromJSON(jsonKey string) TokenExtractor {
|
||||
//
|
||||
// The `RSA(privateFile, publicFile, password)` package-level helper function
|
||||
// can be used to decode the SignKey and VerifyKey.
|
||||
//
|
||||
// For an easy use look the `HMAC` package-level function
|
||||
// and the its `NewUser` and `VerifyUser` methods.
|
||||
type JWT struct {
|
||||
// MaxAge is the expiration duration of the generated tokens.
|
||||
MaxAge time.Duration
|
||||
@@ -109,6 +110,17 @@ type JWT struct {
|
||||
Encrypter jose.Encrypter
|
||||
// DecriptionKey is used to decrypt the token (private key)
|
||||
DecriptionKey interface{}
|
||||
|
||||
// Blocklist holds the invalidated-by-server tokens (that are not yet expired).
|
||||
// It is not initialized by default.
|
||||
// Initialization Usage:
|
||||
// j.UseBlocklist()
|
||||
// OR
|
||||
// j.Blocklist = jwt.NewBlocklist(gcEveryDuration)
|
||||
// Usage:
|
||||
// - ctx.Logout()
|
||||
// - j.Invalidate(ctx)
|
||||
Blocklist *Blocklist
|
||||
}
|
||||
|
||||
type privateKey interface{ Public() crypto.PublicKey }
|
||||
@@ -284,64 +296,68 @@ func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorit
|
||||
return nil
|
||||
}
|
||||
|
||||
// Expiry returns a new standard Claims with
|
||||
// the `Expiry` and `IssuedAt` fields of the "claims" filled
|
||||
// based on the given "maxAge" duration.
|
||||
//
|
||||
// See the `JWT.Expiry` method too.
|
||||
func Expiry(maxAge time.Duration, claims Claims) Claims {
|
||||
now := time.Now()
|
||||
claims.Expiry = NewNumericDate(now.Add(maxAge))
|
||||
claims.IssuedAt = NewNumericDate(now)
|
||||
return claims
|
||||
// UseBlocklist initializes the Blocklist.
|
||||
// Should be called on jwt middleware creation-time,
|
||||
// after this, the developer can use the Context.Logout method
|
||||
// to invalidate a verified token by the server-side.
|
||||
func (j *JWT) UseBlocklist() {
|
||||
gcEvery := 30 * time.Minute
|
||||
if j.MaxAge > 0 {
|
||||
gcEvery = j.MaxAge
|
||||
}
|
||||
j.Blocklist = NewBlocklist(gcEvery)
|
||||
}
|
||||
|
||||
// Expiry method same as `Expiry` package-level function,
|
||||
// it returns a Claims with the expiration fields of the "claims"
|
||||
// filled based on the JWT's `MaxAge` field.
|
||||
// Only use it when this standard "claims"
|
||||
// is embedded on a custom claims structure.
|
||||
// Usage:
|
||||
// type UserClaims struct {
|
||||
// jwt.Claims
|
||||
// Username string
|
||||
// }
|
||||
// [...]
|
||||
// standardClaims := j.Expiry(jwt.Claims{...})
|
||||
// customClaims := UserClaims{
|
||||
// Claims: standardClaims,
|
||||
// Username: "kataras",
|
||||
// }
|
||||
// j.WriteToken(ctx, customClaims)
|
||||
func (j *JWT) Expiry(claims Claims) Claims {
|
||||
return Expiry(j.MaxAge, claims)
|
||||
// ExpiryMap adds the expiration based on the "maxAge" to the "claims" map.
|
||||
// It's called automatically on `Token` method.
|
||||
func ExpiryMap(maxAge time.Duration, claims context.Map) {
|
||||
now := time.Now()
|
||||
if claims["exp"] == nil {
|
||||
claims["exp"] = NewNumericDate(now.Add(maxAge))
|
||||
}
|
||||
|
||||
if claims["iat"] == nil {
|
||||
claims["iat"] = NewNumericDate(now)
|
||||
}
|
||||
}
|
||||
|
||||
// Token generates and returns a new token string.
|
||||
// See `VerifyToken` too.
|
||||
func (j *JWT) Token(claims interface{}) (string, error) {
|
||||
// switch c := claims.(type) {
|
||||
// case Claims:
|
||||
// claims = Expiry(j.MaxAge, c)
|
||||
// case map[string]interface{}: let's not support map.
|
||||
// now := time.Now()
|
||||
// c["iat"] = now.Unix()
|
||||
// c["exp"] = now.Add(j.MaxAge).Unix()
|
||||
// }
|
||||
if c, ok := claims.(Claims); ok {
|
||||
claims = Expiry(j.MaxAge, c)
|
||||
return j.token(j.MaxAge, claims)
|
||||
}
|
||||
|
||||
func (j *JWT) token(maxAge time.Duration, claims interface{}) (string, error) {
|
||||
if claims == nil {
|
||||
return "", ErrInvalidKey
|
||||
}
|
||||
|
||||
c, nErr := normalize(claims)
|
||||
if nErr != nil {
|
||||
return "", nErr
|
||||
}
|
||||
|
||||
// Set expiration, if missing.
|
||||
ExpiryMap(maxAge, c)
|
||||
|
||||
var (
|
||||
token string
|
||||
err error
|
||||
)
|
||||
|
||||
// jwt.Builder and jwt.NestedBuilder contain same methods but they are not the same.
|
||||
//
|
||||
// Note that the .Claims method there, converts a Struct to a map under the hoods.
|
||||
// That means that we will not have any performance cost
|
||||
// if we do it by ourselves and pass always a Map there.
|
||||
// That gives us the option to allow user to pass ANY go struct
|
||||
// and we can add the "exp", "nbf", "iat" map values by ourselves
|
||||
// based on the j.MaxAge.
|
||||
// (^ done, see normalize, all methods are
|
||||
// changed to accept totally custom types, no need to embed the standard Claims anymore).
|
||||
if j.DecriptionKey != nil {
|
||||
token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(claims).CompactSerialize()
|
||||
token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(c).CompactSerialize()
|
||||
} else {
|
||||
token, err = jwt.Signed(j.Signer).Claims(claims).CompactSerialize()
|
||||
token, err = jwt.Signed(j.Signer).Claims(c).CompactSerialize()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -351,39 +367,6 @@ func (j *JWT) Token(claims interface{}) (string, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
/* Let's no support maps, typed claim is the way to go.
|
||||
// validateMapClaims validates claims of map type.
|
||||
func validateMapClaims(m map[string]interface{}, e jwt.Expected, leeway time.Duration) error {
|
||||
if !e.Time.IsZero() {
|
||||
if v, ok := m["nbf"]; ok {
|
||||
if notBefore, ok := v.(NumericDate); ok {
|
||||
if e.Time.Add(leeway).Before(notBefore.Time()) {
|
||||
return ErrNotValidYet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := m["exp"]; ok {
|
||||
if exp, ok := v.(int64); ok {
|
||||
if e.Time.Add(-leeway).Before(time.Unix(exp, 0)) {
|
||||
return ErrExpired
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := m["iat"]; ok {
|
||||
if issuedAt, ok := v.(int64); ok {
|
||||
if e.Time.Add(leeway).Before(time.Unix(issuedAt, 0)) {
|
||||
return ErrIssuedInTheFuture
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
|
||||
// WriteToken is a helper which just generates(calls the `Token` method) and writes
|
||||
// a new token to the client in plain text format.
|
||||
//
|
||||
@@ -399,91 +382,122 @@ func (j *JWT) WriteToken(ctx *context.Context, claims interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrMissing when token cannot be extracted from the request.
|
||||
ErrMissing = errors.New("token is missing")
|
||||
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
|
||||
ErrExpired = errors.New("token is expired (exp)")
|
||||
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
|
||||
ErrNotValidYet = errors.New("token not valid yet (nbf)")
|
||||
// ErrIssuedInTheFuture indicates that the iat field is in the future.
|
||||
ErrIssuedInTheFuture = errors.New("token issued in the future (iat)")
|
||||
)
|
||||
|
||||
type (
|
||||
claimsValidator interface {
|
||||
ValidateWithLeeway(e jwt.Expected, leeway time.Duration) error
|
||||
}
|
||||
claimsAlternativeValidator interface { // to keep iris-contrib/jwt MapClaims compatible.
|
||||
Validate() error
|
||||
}
|
||||
claimsContextValidator interface {
|
||||
Validate(ctx *context.Context) error
|
||||
}
|
||||
)
|
||||
|
||||
// IsValidated reports whether a token is already validated through
|
||||
// `VerifyToken`. It returns true when the claims are compatible
|
||||
// validators: a `Claims` value or a value that implements the `Validate() error` method.
|
||||
func IsValidated(ctx *context.Context) bool { // see the `ReadClaims`.
|
||||
return ctx.Values().Get(needsValidationContextKey) == nil
|
||||
}
|
||||
|
||||
func validateClaims(ctx *context.Context, claims interface{}) (err error) {
|
||||
switch c := claims.(type) {
|
||||
case claimsValidator:
|
||||
err = c.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 0)
|
||||
case claimsAlternativeValidator:
|
||||
err = c.Validate()
|
||||
case claimsContextValidator:
|
||||
err = c.Validate(ctx)
|
||||
case *json.RawMessage:
|
||||
// if the data type is raw message (json []byte)
|
||||
// then it should contain exp (and iat and nbf) keys.
|
||||
// Unmarshal raw message to validate against.
|
||||
v := new(Claims)
|
||||
err = json.Unmarshal(*c, v)
|
||||
if err == nil {
|
||||
return validateClaims(ctx, v)
|
||||
}
|
||||
default:
|
||||
ctx.Values().Set(needsValidationContextKey, struct{}{})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
switch err {
|
||||
case jwt.ErrExpired:
|
||||
return ErrExpired
|
||||
case jwt.ErrNotValidYet:
|
||||
return ErrNotValidYet
|
||||
case jwt.ErrIssuedInTheFuture:
|
||||
return ErrIssuedInTheFuture
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// VerifyToken verifies (and decrypts) the request token,
|
||||
// it also validates and binds the parsed token's claims to the "claimsPtr" (destination).
|
||||
// It does return a nil error on success.
|
||||
func (j *JWT) VerifyToken(ctx *context.Context, claimsPtr interface{}) error {
|
||||
var token string
|
||||
//
|
||||
// The last, variadic, input argument is optionally, if provided then the
|
||||
// parsed claims must match the expectations;
|
||||
// e.g. Audience, Issuer, ID, Subject.
|
||||
// See `ExpectXXX` package-functions for details.
|
||||
func (j *JWT) VerifyToken(ctx *context.Context, claimsPtr interface{}, expectations ...Expectation) (*TokenInfo, error) {
|
||||
token := j.RequestToken(ctx)
|
||||
return j.VerifyTokenString(ctx, token, claimsPtr, expectations...)
|
||||
}
|
||||
|
||||
// RequestToken extracts the token from the request.
|
||||
func (j *JWT) RequestToken(ctx *context.Context) (token string) {
|
||||
for _, extract := range j.Extractors {
|
||||
if token = extract(ctx); token != "" {
|
||||
break // ok we found it.
|
||||
}
|
||||
}
|
||||
|
||||
return j.VerifyTokenString(ctx, token, claimsPtr)
|
||||
return
|
||||
}
|
||||
|
||||
// VerifyTokenString verifies and unmarshals an extracted token to "claimsPtr" destination.
|
||||
// The Context is required when the claims validator needs it, otherwise can be nil.
|
||||
func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr interface{}) error {
|
||||
// TokenSetter is an interface which if implemented
|
||||
// the extracted, verified, token is stored to the object.
|
||||
type TokenSetter interface {
|
||||
SetToken(token string)
|
||||
}
|
||||
|
||||
// TokenInfo holds the standard token information may required
|
||||
// for further actions.
|
||||
// This structure is mostly useful when the developer's go structure
|
||||
// does not hold the standard jwt fields (e.g. "exp")
|
||||
// but want access to the parsed token which contains those fields.
|
||||
// Inside the middleware, it is used to invalidate tokens through server-side, see `Invalidate`.
|
||||
type TokenInfo struct {
|
||||
RequestToken string // The request token.
|
||||
Claims Claims // The standard JWT parsed fields from the request Token.
|
||||
Value interface{} // The pointer to the end-developer's custom claims structure (see `Get`).
|
||||
}
|
||||
|
||||
const tokenInfoContextKey = "iris.jwt.token"
|
||||
|
||||
// Get returns the verified developer token claims.
|
||||
//
|
||||
//
|
||||
// Usage:
|
||||
// j := jwt.New(...)
|
||||
// app.Use(j.Verify(func() interface{} { return new(CustomClaims) }))
|
||||
// app.Post("/restricted", func(ctx iris.Context){
|
||||
// claims := jwt.Get(ctx).(*CustomClaims)
|
||||
// [use claims...]
|
||||
// })
|
||||
//
|
||||
// Note that there is one exception, if the value was a pointer
|
||||
// to a map[string]interface{}, it returns the map itself so it can be
|
||||
// accessible directly without the requirement of unwrapping it, e.g.
|
||||
// j.Verify(func() interface{} {
|
||||
// return &iris.Map{}
|
||||
// }
|
||||
// [...]
|
||||
// claims := jwt.Get(ctx).(iris.Map)
|
||||
func Get(ctx *context.Context) interface{} {
|
||||
if tok := GetTokenInfo(ctx); tok != nil {
|
||||
switch v := tok.Value.(type) {
|
||||
case *context.Map:
|
||||
return *v
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTokenInfo returns the verified token's information.
|
||||
func GetTokenInfo(ctx *context.Context) *TokenInfo {
|
||||
if v := ctx.Values().Get(tokenInfoContextKey); v != nil {
|
||||
if t, ok := v.(*TokenInfo); ok {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invalidate invalidates a verified JWT token.
|
||||
// It adds the request token, retrieved by Verify methods, to the block list.
|
||||
// Next request will be blocked, even if the token was not yet expired.
|
||||
// This method can be used when the client-side does not clear the token
|
||||
// on a user logout operation.
|
||||
//
|
||||
// Note: the Blocklist should be initialized before serve-time: j.UseBlocklist().
|
||||
func (j *JWT) Invalidate(ctx *context.Context) {
|
||||
if j.Blocklist == nil {
|
||||
ctx.Application().Logger().Debug("jwt.Invalidate: Blocklist is nil")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo := GetTokenInfo(ctx)
|
||||
if tokenInfo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
j.Blocklist.Set(tokenInfo.RequestToken, tokenInfo.Claims.Expiry.Time())
|
||||
}
|
||||
|
||||
// VerifyTokenString verifies and unmarshals an extracted request token to "dest" destination.
|
||||
// The last variadic input indicates any further validations against the verified token claims.
|
||||
// If the given "dest" is a valid context.User then ctx.User() will return it.
|
||||
// If the token is missing an `ErrMissing` is returned.
|
||||
// If the incoming token was expired an `ErrExpired` is returned.
|
||||
// If the incoming token was blocked by the server an `ErrBlocked` is returned.
|
||||
func (j *JWT) VerifyTokenString(ctx *context.Context, token string, dest interface{}, expectations ...Expectation) (*TokenInfo, error) {
|
||||
if token == "" {
|
||||
return ErrMissing
|
||||
return nil, ErrMissing
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -494,7 +508,7 @@ func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr in
|
||||
if j.DecriptionKey != nil {
|
||||
t, cerr := jwt.ParseSignedAndEncrypted(token)
|
||||
if cerr != nil {
|
||||
return cerr
|
||||
return nil, cerr
|
||||
}
|
||||
|
||||
parsedToken, err = t.Decrypt(j.DecriptionKey)
|
||||
@@ -502,112 +516,163 @@ func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr in
|
||||
parsedToken, err = jwt.ParseSigned(token)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = parsedToken.Claims(j.VerificationKey, claimsPtr); err != nil {
|
||||
return err
|
||||
var claims Claims
|
||||
if err = parsedToken.Claims(j.VerificationKey, dest, &claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return validateClaims(ctx, claimsPtr)
|
||||
}
|
||||
|
||||
const (
|
||||
// ClaimsContextKey is the context key which the jwt claims are stored from the `Verify` method.
|
||||
ClaimsContextKey = "iris.jwt.claims"
|
||||
needsValidationContextKey = "iris.jwt.claims.unvalidated"
|
||||
)
|
||||
|
||||
// Verify is a middleware. It verifies and optionally decrypts an incoming request token.
|
||||
// It does write a 401 unauthorized status code if verification or decryption failed.
|
||||
// It calls the `ctx.Next` on verified requests.
|
||||
//
|
||||
// See `VerifyToken` instead to verify, decrypt, validate and acquire the claims at once.
|
||||
//
|
||||
// A call of `ReadClaims` is required to validate and acquire the jwt claims
|
||||
// on the next request.
|
||||
func (j *JWT) Verify(ctx *context.Context) {
|
||||
var raw json.RawMessage
|
||||
if err := j.VerifyToken(ctx, &raw); err != nil {
|
||||
ctx.StopWithStatus(401)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Values().Set(ClaimsContextKey, raw)
|
||||
ctx.Next()
|
||||
}
|
||||
|
||||
// ReadClaims binds the "claimsPtr" (destination)
|
||||
// to the verified (and decrypted) claims.
|
||||
// The `Verify` method should be called first (registered as middleware).
|
||||
func ReadClaims(ctx *context.Context, claimsPtr interface{}) error {
|
||||
v := ctx.Values().Get(ClaimsContextKey)
|
||||
if v == nil {
|
||||
return ErrMissing
|
||||
}
|
||||
|
||||
raw, ok := v.(json.RawMessage)
|
||||
if !ok {
|
||||
return ErrMissing
|
||||
}
|
||||
|
||||
err := json.Unmarshal(raw, claimsPtr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !IsValidated(ctx) {
|
||||
// If already validated on `Verify/VerifyToken`
|
||||
// then no need to perform the check again.
|
||||
ctx.Values().Remove(needsValidationContextKey)
|
||||
return validateClaims(ctx, claimsPtr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns and validates (if not already) the claims
|
||||
// stored on request context's values storage.
|
||||
//
|
||||
// Should be used instead of the `ReadClaims` method when
|
||||
// a custom verification middleware was registered (see the `Verify` method for an example).
|
||||
//
|
||||
// Usage:
|
||||
// j := jwt.New(...)
|
||||
// [...]
|
||||
// app.Use(func(ctx iris.Context) {
|
||||
// var claims CustomClaims_or_jwt.Claims
|
||||
// if err := j.VerifyToken(ctx, &claims); err != nil {
|
||||
// ctx.StopWithStatus(iris.StatusUnauthorized)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// ctx.Values().Set(jwt.ClaimsContextKey, claims)
|
||||
// ctx.Next()
|
||||
// })
|
||||
// [...]
|
||||
// app.Post("/restricted", func(ctx iris.Context){
|
||||
// v, err := jwt.Get(ctx)
|
||||
// [handle error...]
|
||||
// claims,ok := v.(CustomClaims_or_jwt.Claims)
|
||||
// if !ok {
|
||||
// [do you support more than one type of claims? Handle here]
|
||||
// }
|
||||
// [use claims...]
|
||||
// })
|
||||
func Get(ctx *context.Context) (interface{}, error) {
|
||||
claims := ctx.Values().Get(ClaimsContextKey)
|
||||
if claims == nil {
|
||||
return nil, ErrMissing
|
||||
}
|
||||
|
||||
if !IsValidated(ctx) {
|
||||
ctx.Values().Remove(needsValidationContextKey)
|
||||
err := validateClaims(ctx, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Build the Expected value.
|
||||
expected := Expected{}
|
||||
for _, e := range expectations {
|
||||
if e != nil {
|
||||
// expection can be used as a field validation too (see MeetRequirements).
|
||||
if err = e(&expected, dest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
// For other standard JWT claims fields such as "exp"
|
||||
// The developer can just add a field of Expiry *NumericDate `json:"exp"`
|
||||
// and will be filled by the parsed token automatically.
|
||||
// No need for more interfaces.
|
||||
|
||||
err = validateClaims(ctx, dest, claims, expected)
|
||||
if err != nil {
|
||||
if err == ErrExpired {
|
||||
// If token was expired remove it from the block list.
|
||||
if j.Blocklist != nil {
|
||||
j.Blocklist.Del(token)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if j.Blocklist != nil {
|
||||
// If token exists in the block list, then stop here.
|
||||
if j.Blocklist.Has(token) {
|
||||
return nil, ErrBlocked
|
||||
}
|
||||
}
|
||||
|
||||
if ut, ok := dest.(TokenSetter); ok {
|
||||
// The u.Token is empty even if we set it and export it on JSON structure.
|
||||
// Set it manually.
|
||||
ut.SetToken(token)
|
||||
}
|
||||
|
||||
// Set the information.
|
||||
tokenInfo := &TokenInfo{
|
||||
RequestToken: token,
|
||||
Claims: claims,
|
||||
Value: dest,
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// TokenPair holds the access token and refresh token response.
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
// TokenPair generates a token pair of access and refresh tokens.
|
||||
// The first two arguments required for the refresh token
|
||||
// and the last one is the claims for the access token one.
|
||||
func (j *JWT) TokenPair(refreshMaxAge time.Duration, refreshClaims interface{}, accessClaims interface{}) (TokenPair, error) {
|
||||
accessToken, err := j.Token(accessClaims)
|
||||
if err != nil {
|
||||
return TokenPair{}, err
|
||||
}
|
||||
|
||||
refreshToken, err := j.token(refreshMaxAge, refreshClaims)
|
||||
if err != nil {
|
||||
return TokenPair{}, nil
|
||||
}
|
||||
|
||||
pair := TokenPair{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
|
||||
return pair, nil
|
||||
}
|
||||
|
||||
// Verify returns a middleware which
|
||||
// decrypts an incoming request token to the result of the given "newPtr".
|
||||
// It does write a 401 unauthorized status code if verification or decryption failed.
|
||||
// It calls the `ctx.Next` on verified requests.
|
||||
//
|
||||
// Iit unmarshals the token to the specific type returned from the given "newPtr" function.
|
||||
// It sets the Context User and User's Token too. So the next handler(s)
|
||||
// of the same chain can access the User through a `Context.User()` call.
|
||||
//
|
||||
// Note unlike `VerifyToken`, this method automatically protects
|
||||
// the claims with JSON required tags (see `MeetRequirements` Expection).
|
||||
//
|
||||
// On verified tokens:
|
||||
// - The information can be retrieved through `Get` and `GetTokenInfo` functions.
|
||||
// - User is set if the newPtr returns a valid Context User
|
||||
// - The Context Logout method is set if Blocklist was initialized
|
||||
// Any error is captured to the Context,
|
||||
// which can be retrieved by a `ctx.GetErr()` call.
|
||||
func (j *JWT) Verify(newPtr func() interface{}, expections ...Expectation) context.Handler {
|
||||
expections = append(expections, MeetRequirements(newPtr()))
|
||||
|
||||
return func(ctx *context.Context) {
|
||||
ptr := newPtr()
|
||||
|
||||
tokenInfo, err := j.VerifyToken(ctx, ptr, expections...)
|
||||
if err != nil {
|
||||
ctx.Application().Logger().Debugf("iris.jwt.Verify: %v", err)
|
||||
ctx.StopWithError(401, context.PrivateError(err))
|
||||
return
|
||||
}
|
||||
|
||||
if u, ok := ptr.(context.User); ok {
|
||||
ctx.SetUser(u)
|
||||
}
|
||||
|
||||
if j.Blocklist != nil {
|
||||
ctx.SetLogoutFunc(j.Invalidate)
|
||||
}
|
||||
|
||||
ctx.Values().Set(tokenInfoContextKey, tokenInfo)
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser returns a new User based on the given "opts".
|
||||
// The caller can modify the User until its `GetToken` is called.
|
||||
func (j *JWT) NewUser(opts ...UserOption) *User {
|
||||
u := &User{
|
||||
j: j,
|
||||
SimpleUser: &context.SimpleUser{
|
||||
Authorization: "IRIS_JWT_USER", // Used to separate a refresh token with a user/access one too.
|
||||
Features: []context.UserFeature{
|
||||
context.TokenFeature,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(u)
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// VerifyUser works like the `Verify` method but instead
|
||||
// it unmarshals the token to the specific User type.
|
||||
// It sets the Context User too. So the next handler(s)
|
||||
// of the same chain can access the User through a `Context.User()` call.
|
||||
func (j *JWT) VerifyUser() context.Handler {
|
||||
return j.Verify(func() interface{} {
|
||||
return new(User)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,11 +12,15 @@ import (
|
||||
)
|
||||
|
||||
type userClaims struct {
|
||||
jwt.Claims
|
||||
Username string
|
||||
// Optionally:
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Audience jwt.Audience `json:"aud"`
|
||||
//
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
const testMaxAge = 3 * time.Second
|
||||
const testMaxAge = 7 * time.Second
|
||||
|
||||
// Random RSA verification and encryption.
|
||||
func TestRSA(t *testing.T) {
|
||||
@@ -25,13 +29,13 @@ func TestRSA(t *testing.T) {
|
||||
os.Remove(jwt.DefaultSignFilename)
|
||||
os.Remove(jwt.DefaultEncFilename)
|
||||
})
|
||||
testWriteVerifyToken(t, j)
|
||||
testWriteVerifyBlockToken(t, j)
|
||||
}
|
||||
|
||||
// HMAC verification and encryption.
|
||||
func TestHMAC(t *testing.T) {
|
||||
j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret")
|
||||
testWriteVerifyToken(t, j)
|
||||
testWriteVerifyBlockToken(t, j)
|
||||
}
|
||||
|
||||
func TestNew_HMAC(t *testing.T) {
|
||||
@@ -44,7 +48,7 @@ func TestNew_HMAC(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testWriteVerifyToken(t, j)
|
||||
testWriteVerifyBlockToken(t, j)
|
||||
}
|
||||
|
||||
// HMAC verification only (unecrypted).
|
||||
@@ -53,54 +57,60 @@ func TestVerify(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testWriteVerifyToken(t, j)
|
||||
testWriteVerifyBlockToken(t, j)
|
||||
}
|
||||
|
||||
func testWriteVerifyToken(t *testing.T, j *jwt.JWT) {
|
||||
func testWriteVerifyBlockToken(t *testing.T, j *jwt.JWT) {
|
||||
t.Helper()
|
||||
|
||||
j.UseBlocklist()
|
||||
j.Extractors = append(j.Extractors, jwt.FromJSON("access_token"))
|
||||
standardClaims := jwt.Claims{Issuer: "an-issuer", Audience: jwt.Audience{"an-audience"}}
|
||||
expectedClaims := userClaims{
|
||||
Claims: j.Expiry(standardClaims),
|
||||
|
||||
customClaims := &userClaims{
|
||||
Issuer: "an-issuer",
|
||||
Audience: jwt.Audience{"an-audience"},
|
||||
Subject: "user",
|
||||
Username: "kataras",
|
||||
}
|
||||
|
||||
app := iris.New()
|
||||
app.OnErrorCode(iris.StatusUnauthorized, func(ctx iris.Context) {
|
||||
if err := ctx.GetErr(); err != nil {
|
||||
// Test accessing the private error and set this as the response body.
|
||||
ctx.WriteString(err.Error())
|
||||
} else { // Else the default behavior
|
||||
ctx.WriteString(iris.StatusText(iris.StatusUnauthorized))
|
||||
}
|
||||
})
|
||||
|
||||
app.Get("/auth", func(ctx iris.Context) {
|
||||
j.WriteToken(ctx, expectedClaims)
|
||||
j.WriteToken(ctx, customClaims)
|
||||
})
|
||||
|
||||
app.Post("/restricted", func(ctx iris.Context) {
|
||||
app.Post("/protected", func(ctx iris.Context) {
|
||||
var claims userClaims
|
||||
if err := j.VerifyToken(ctx, &claims); err != nil {
|
||||
ctx.StopWithStatus(iris.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(claims)
|
||||
})
|
||||
|
||||
app.Post("/restricted_middleware_readclaims", j.Verify, func(ctx iris.Context) {
|
||||
var claims userClaims
|
||||
if err := jwt.ReadClaims(ctx, &claims); err != nil {
|
||||
ctx.StopWithStatus(iris.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(claims)
|
||||
})
|
||||
|
||||
app.Post("/restricted_middleware_get", j.Verify, func(ctx iris.Context) {
|
||||
claims, err := jwt.Get(ctx)
|
||||
_, err := j.VerifyToken(ctx, &claims)
|
||||
if err != nil {
|
||||
ctx.StopWithStatus(iris.StatusUnauthorized)
|
||||
// t.Logf("%s: %v", ctx.Path(), err)
|
||||
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(claims)
|
||||
})
|
||||
|
||||
m := app.Party("/middleware")
|
||||
m.Use(j.Verify(func() interface{} {
|
||||
return new(userClaims)
|
||||
}))
|
||||
m.Post("/protected", func(ctx iris.Context) {
|
||||
claims := jwt.Get(ctx)
|
||||
ctx.JSON(claims)
|
||||
})
|
||||
m.Post("/invalidate", func(ctx iris.Context) {
|
||||
ctx.Logout() // OR j.Invalidate(ctx)
|
||||
})
|
||||
|
||||
e := httptest.New(t, app)
|
||||
|
||||
// Get token.
|
||||
@@ -109,31 +119,186 @@ func testWriteVerifyToken(t *testing.T, j *jwt.JWT) {
|
||||
t.Fatalf("empty token")
|
||||
}
|
||||
|
||||
restrictedPaths := [...]string{"/restricted", "/restricted_middleware_readclaims", "/restricted_middleware_get"}
|
||||
restrictedPaths := [...]string{"/protected", "/middleware/protected"}
|
||||
|
||||
now := time.Now()
|
||||
for _, path := range restrictedPaths {
|
||||
// Authorization Header.
|
||||
e.POST(path).WithHeader("Authorization", "Bearer "+rawToken).Expect().
|
||||
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
|
||||
Status(httptest.StatusOK).JSON().Equal(customClaims)
|
||||
|
||||
// URL Query.
|
||||
e.POST(path).WithQuery("token", rawToken).Expect().
|
||||
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
|
||||
Status(httptest.StatusOK).JSON().Equal(customClaims)
|
||||
|
||||
// JSON Body.
|
||||
e.POST(path).WithJSON(iris.Map{"access_token": rawToken}).Expect().
|
||||
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
|
||||
Status(httptest.StatusOK).JSON().Equal(customClaims)
|
||||
|
||||
// Missing "Bearer".
|
||||
e.POST(path).WithHeader("Authorization", rawToken).Expect().
|
||||
Status(httptest.StatusUnauthorized)
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("token is missing")
|
||||
}
|
||||
|
||||
// Invalidate the token.
|
||||
e.POST("/middleware/invalidate").WithQuery("token", rawToken).Expect().
|
||||
Status(httptest.StatusOK)
|
||||
// Token is blocked by server.
|
||||
e.POST("/middleware/protected").WithQuery("token", rawToken).Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("token is blocked")
|
||||
|
||||
expireRemDur := testMaxAge - time.Since(now)
|
||||
|
||||
// Expiration.
|
||||
time.Sleep(expireRemDur /* -end */)
|
||||
for _, path := range restrictedPaths {
|
||||
e.POST(path).WithQuery("token", rawToken).Expect().Status(httptest.StatusUnauthorized)
|
||||
e.POST(path).WithQuery("token", rawToken).Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("token is expired (exp)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyMap(t *testing.T) {
|
||||
j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret")
|
||||
expectedClaims := iris.Map{
|
||||
"iss": "tester",
|
||||
"username": "makis",
|
||||
"roles": []string{"admin"},
|
||||
}
|
||||
|
||||
app := iris.New()
|
||||
app.Get("/user/auth", func(ctx iris.Context) {
|
||||
err := j.WriteToken(ctx, expectedClaims)
|
||||
if err != nil {
|
||||
ctx.StopWithError(iris.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
|
||||
if expectedClaims["exp"] == nil || expectedClaims["iat"] == nil {
|
||||
ctx.StopWithText(iris.StatusBadRequest,
|
||||
"exp or/and iat is nil - this means that the expiry was not set")
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
userAPI := app.Party("/user")
|
||||
userAPI.Post("/", func(ctx iris.Context) {
|
||||
var claims iris.Map
|
||||
if _, err := j.VerifyToken(ctx, &claims); err != nil {
|
||||
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(claims)
|
||||
})
|
||||
|
||||
// Test map + Verify middleware.
|
||||
userAPI.Post("/middleware", j.Verify(func() interface{} {
|
||||
return &iris.Map{} // or &map[string]interface{}{}
|
||||
}), func(ctx iris.Context) {
|
||||
claims := jwt.Get(ctx)
|
||||
ctx.JSON(claims)
|
||||
})
|
||||
|
||||
e := httptest.New(t, app, httptest.LogLevel("error"))
|
||||
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
|
||||
if token == "" {
|
||||
t.Fatalf("empty token")
|
||||
}
|
||||
|
||||
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
|
||||
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
|
||||
|
||||
e.POST("/user/middleware").WithHeader("Authorization", "Bearer "+token).Expect().
|
||||
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
|
||||
|
||||
e.POST("/user").Expect().Status(httptest.StatusUnauthorized)
|
||||
}
|
||||
|
||||
type customClaims struct {
|
||||
Username string `json:"username"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func (c *customClaims) SetToken(tok string) {
|
||||
c.Token = tok
|
||||
}
|
||||
|
||||
func TestVerifyStruct(t *testing.T) {
|
||||
maxAge := testMaxAge / 2
|
||||
j := jwt.HMAC(maxAge, "secret", "itsa16bytesecret")
|
||||
|
||||
app := iris.New()
|
||||
app.Get("/user/auth", func(ctx iris.Context) {
|
||||
err := j.WriteToken(ctx, customClaims{
|
||||
Username: "makis",
|
||||
})
|
||||
if err != nil {
|
||||
ctx.StopWithError(iris.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
userAPI := app.Party("/user")
|
||||
userAPI.Post("/", func(ctx iris.Context) {
|
||||
var claims customClaims
|
||||
if _, err := j.VerifyToken(ctx, &claims); err != nil {
|
||||
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(claims)
|
||||
})
|
||||
|
||||
e := httptest.New(t, app)
|
||||
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
|
||||
if token == "" {
|
||||
t.Fatalf("empty token")
|
||||
}
|
||||
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
|
||||
Status(httptest.StatusOK).JSON().Object().ContainsMap(iris.Map{
|
||||
"username": "makis",
|
||||
"token": token, // Test SetToken.
|
||||
})
|
||||
|
||||
e.POST("/user").Expect().Status(httptest.StatusUnauthorized)
|
||||
time.Sleep(maxAge)
|
||||
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().Status(httptest.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func TestVerifyUserAndExpected(t *testing.T) { // Tests the jwt.User struct + context validator + expected.
|
||||
maxAge := testMaxAge / 2
|
||||
j := jwt.HMAC(maxAge, "secret", "itsa16bytesecret")
|
||||
expectedUser := j.NewUser(jwt.Username("makis"), jwt.Roles("admin"), jwt.Fields(iris.Map{
|
||||
"custom": true,
|
||||
})) // only for the sake of the test, we iniitalize it here.
|
||||
expectedUser.Issuer = "tester"
|
||||
|
||||
app := iris.New()
|
||||
app.Get("/user/auth", func(ctx iris.Context) {
|
||||
tok, err := expectedUser.GetToken()
|
||||
if err != nil {
|
||||
ctx.StopWithError(iris.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
ctx.WriteString(tok)
|
||||
})
|
||||
|
||||
userAPI := app.Party("/user")
|
||||
userAPI.Use(jwt.WithExpected(jwt.Expected{Issuer: "tester"}, j.VerifyUser()))
|
||||
userAPI.Post("/", func(ctx iris.Context) {
|
||||
user := ctx.User()
|
||||
ctx.JSON(user)
|
||||
})
|
||||
|
||||
e := httptest.New(t, app)
|
||||
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
|
||||
if token == "" {
|
||||
t.Fatalf("empty token")
|
||||
}
|
||||
|
||||
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
|
||||
Status(httptest.StatusOK).JSON().Equal(expectedUser)
|
||||
|
||||
// Test generic client message if we don't manage the private error by ourselves.
|
||||
e.POST("/user").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
}
|
||||
|
||||
187
middleware/jwt/user.go
Normal file
187
middleware/jwt/user.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12/context"
|
||||
)
|
||||
|
||||
// User a common User structure for JWT.
|
||||
// However, we're not limited to that one;
|
||||
// any Go structure can be generated as a JWT token.
|
||||
//
|
||||
// Look `NewUser` and `VerifyUser` JWT middleware's methods.
|
||||
// Use its `GetToken` method to generate the token when
|
||||
// the User structure is set.
|
||||
type User struct {
|
||||
Claims
|
||||
// Note: we could use a map too as the Token is generated when GetToken is called.
|
||||
*context.SimpleUser
|
||||
|
||||
j *JWT
|
||||
}
|
||||
|
||||
var (
|
||||
_ context.FeaturedUser = (*User)(nil)
|
||||
_ TokenSetter = (*User)(nil)
|
||||
_ ContextValidator = (*User)(nil)
|
||||
)
|
||||
|
||||
// UserOption sets optional fields for a new User
|
||||
// See `NewUser` instance function.
|
||||
type UserOption func(*User)
|
||||
|
||||
// Username sets the Username and the JWT Claim's Subject
|
||||
// to the given "username".
|
||||
func Username(username string) UserOption {
|
||||
return func(u *User) {
|
||||
u.Username = username
|
||||
u.Claims.Subject = username
|
||||
u.Features = append(u.Features, context.UsernameFeature)
|
||||
}
|
||||
}
|
||||
|
||||
// Email sets the Email field for the User field.
|
||||
func Email(email string) UserOption {
|
||||
return func(u *User) {
|
||||
u.Email = email
|
||||
u.Features = append(u.Features, context.EmailFeature)
|
||||
}
|
||||
}
|
||||
|
||||
// Roles upserts to the User's Roles field.
|
||||
func Roles(roles ...string) UserOption {
|
||||
return func(u *User) {
|
||||
u.Roles = roles
|
||||
u.Features = append(u.Features, context.RolesFeature)
|
||||
}
|
||||
}
|
||||
|
||||
// MaxAge sets claims expiration and the AuthorizedAt User field.
|
||||
func MaxAge(maxAge time.Duration) UserOption {
|
||||
return func(u *User) {
|
||||
now := time.Now()
|
||||
u.Claims.Expiry = NewNumericDate(now.Add(maxAge))
|
||||
u.Claims.IssuedAt = NewNumericDate(now)
|
||||
u.AuthorizedAt = now
|
||||
|
||||
u.Features = append(u.Features, context.AuthorizedAtFeature)
|
||||
}
|
||||
}
|
||||
|
||||
// Fields copies the "fields" to the user's Fields field.
|
||||
// This can be used to set custom fields to the User instance.
|
||||
func Fields(fields context.Map) UserOption {
|
||||
return func(u *User) {
|
||||
if len(fields) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if u.Fields == nil {
|
||||
u.Fields = make(context.Map, len(fields))
|
||||
}
|
||||
|
||||
for k, v := range fields {
|
||||
u.Fields[k] = v
|
||||
}
|
||||
|
||||
u.Features = append(u.Features, context.FieldsFeature)
|
||||
}
|
||||
}
|
||||
|
||||
// SetToken is called automaticaly on VerifyUser/VerifyObject.
|
||||
// It sets the extracted from request, and verified from server raw token.
|
||||
func (u *User) SetToken(token string) {
|
||||
u.Token = token
|
||||
}
|
||||
|
||||
// GetToken overrides the SimpleUser's Token
|
||||
// and returns the jwt generated token, among with
|
||||
// a generator error, if any.
|
||||
func (u *User) GetToken() (string, error) {
|
||||
if u.Token != "" {
|
||||
return u.Token, nil
|
||||
}
|
||||
|
||||
if u.j != nil { // it's always not nil.
|
||||
if u.j.MaxAge > 0 {
|
||||
// if the MaxAge option was not manually set, resolve it from the JWT instance.
|
||||
MaxAge(u.j.MaxAge)(u)
|
||||
}
|
||||
|
||||
// we could generate a token here
|
||||
// but let's do it on GetToken
|
||||
// as the user fields may change
|
||||
// by the caller manually until the token
|
||||
// sent to the client.
|
||||
tok, err := u.j.Token(u)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
u.Token = tok
|
||||
}
|
||||
|
||||
if u.Token == "" {
|
||||
return "", ErrMissing
|
||||
}
|
||||
|
||||
return u.Token, nil
|
||||
}
|
||||
|
||||
// Validate validates the current user's claims against
|
||||
// the request. It's called automatically by the JWT instance.
|
||||
func (u *User) Validate(ctx *context.Context, claims Claims, e Expected) error {
|
||||
err := u.Claims.ValidateWithLeeway(e, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if u.SimpleUser.Authorization != "IRIS_JWT_USER" {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
// We could add specific User Expectations (new struct and accept an interface{}),
|
||||
// but for the sake of code simplicity we don't, unless is requested, as the caller
|
||||
// can validate specific fields by its own at the next step.
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json unmarshaler interface.
|
||||
func (u *User) UnmarshalJSON(data []byte) error {
|
||||
err := Unmarshal(data, &u.Claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
simpleUser := new(context.SimpleUser)
|
||||
err = Unmarshal(data, simpleUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.SimpleUser = simpleUser
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json marshaler interface.
|
||||
func (u *User) MarshalJSON() ([]byte, error) {
|
||||
claimsB, err := Marshal(u.Claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userB, err := Marshal(u.SimpleUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(userB) == 0 {
|
||||
return claimsB, nil
|
||||
}
|
||||
|
||||
claimsB = claimsB[0 : len(claimsB)-1] // remove last '}'
|
||||
userB = userB[1:] // remove first '{'
|
||||
|
||||
raw := append(claimsB, ',')
|
||||
raw = append(raw, userB...)
|
||||
return raw, nil
|
||||
}
|
||||
212
middleware/jwt/validation.go
Normal file
212
middleware/jwt/validation.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12/context"
|
||||
|
||||
"github.com/square/go-jose/v3/json"
|
||||
// Use this package instead of the standard encoding/json
|
||||
// to marshal the NumericDate as expected by the implementation (see 'normalize`).
|
||||
"github.com/square/go-jose/v3/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
claimsExpectedContextKey = "iris.jwt.claims.expected"
|
||||
needsValidationContextKey = "iris.jwt.claims.unvalidated"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMissing when token cannot be extracted from the request.
|
||||
ErrMissing = errors.New("token is missing")
|
||||
// ErrMissingKey when token does not contain a required JSON field.
|
||||
ErrMissingKey = errors.New("token is missing a required field")
|
||||
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
|
||||
ErrExpired = errors.New("token is expired (exp)")
|
||||
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
|
||||
ErrNotValidYet = errors.New("token not valid yet (nbf)")
|
||||
// ErrIssuedInTheFuture indicates that the iat field is in the future.
|
||||
ErrIssuedInTheFuture = errors.New("token issued in the future (iat)")
|
||||
// ErrBlocked indicates that the token was not yet expired
|
||||
// but was blocked by the server's Blocklist.
|
||||
ErrBlocked = errors.New("token is blocked")
|
||||
)
|
||||
|
||||
// Expectation option to provide
|
||||
// an extra layer of token validation, a claims type protection.
|
||||
// See `VerifyToken` method.
|
||||
type Expectation func(e *Expected, claims interface{}) error
|
||||
|
||||
// Expect protects the claims with the expected values.
|
||||
func Expect(expected Expected) Expectation {
|
||||
return func(e *Expected, _ interface{}) error {
|
||||
*e = expected
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExpectID protects the claims with an ID validation.
|
||||
func ExpectID(id string) Expectation {
|
||||
return func(e *Expected, _ interface{}) error {
|
||||
e.ID = id
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExpectIssuer protects the claims with an issuer validation.
|
||||
func ExpectIssuer(issuer string) Expectation {
|
||||
return func(e *Expected, _ interface{}) error {
|
||||
e.Issuer = issuer
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExpectSubject protects the claims with a subject validation.
|
||||
func ExpectSubject(sub string) Expectation {
|
||||
return func(e *Expected, _ interface{}) error {
|
||||
e.Subject = sub
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExpectAudience protects the claims with an audience validation.
|
||||
func ExpectAudience(audience ...string) Expectation {
|
||||
return func(e *Expected, _ interface{}) error {
|
||||
e.Audience = audience
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MeetRequirements protects the custom fields of JWT claims
|
||||
// based on the json:required tag; `json:"name,required"`.
|
||||
// It accepts the value type.
|
||||
//
|
||||
// Usage:
|
||||
// Verify/VerifyToken(... MeetRequirements(MyUser{}))
|
||||
func MeetRequirements(claimsType interface{}) Expectation {
|
||||
// pre-calculate if we need to use reflection at serve time to check for required fields,
|
||||
// this can work as an alternative of expections for custom non-standard JWT fields.
|
||||
requireFieldsIndexes := getRequiredFieldIndexes(claimsType)
|
||||
|
||||
return func(e *Expected, claims interface{}) error {
|
||||
if len(requireFieldsIndexes) > 0 {
|
||||
val := reflect.Indirect(reflect.ValueOf(claims))
|
||||
for _, idx := range requireFieldsIndexes {
|
||||
field := val.Field(idx)
|
||||
if field.IsZero() {
|
||||
return ErrMissingKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithExpected is a middleware wrapper. It wraps a VerifyXXX middleware
|
||||
// with expected claims fields protection.
|
||||
// Usage:
|
||||
// jwt.WithExpected(jwt.Expected{Issuer:"app"}, j.VerifyUser)
|
||||
func WithExpected(e Expected, verifyHandler context.Handler) context.Handler {
|
||||
return func(ctx *context.Context) {
|
||||
ctx.Values().Set(claimsExpectedContextKey, e)
|
||||
verifyHandler(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// ContextValidator validates the object based on the given
|
||||
// claims and the expected once. The end-developer
|
||||
// can use this method for advanced validations based on the request Context.
|
||||
type ContextValidator interface {
|
||||
Validate(ctx *context.Context, claims Claims, e Expected) error
|
||||
}
|
||||
|
||||
func validateClaims(ctx *context.Context, dest interface{}, claims Claims, expected Expected) (err error) {
|
||||
// Get any dynamic expectation set by prior middleware.
|
||||
// See `WithExpected` middleware.
|
||||
if v := ctx.Values().Get(claimsExpectedContextKey); v != nil {
|
||||
if e, ok := v.(Expected); ok {
|
||||
expected = e
|
||||
}
|
||||
}
|
||||
// Force-set the time, it's important for expiration.
|
||||
expected.Time = time.Now()
|
||||
switch c := dest.(type) {
|
||||
case Claims:
|
||||
err = c.ValidateWithLeeway(expected, 0)
|
||||
case ContextValidator:
|
||||
err = c.Validate(ctx, claims, expected)
|
||||
case *context.Map:
|
||||
// if the dest is a map then set automatically the expiration settings here,
|
||||
// so the caller can work further with it.
|
||||
err = claims.ValidateWithLeeway(expected, 0)
|
||||
if err == nil {
|
||||
(*c)["exp"] = claims.Expiry
|
||||
(*c)["iat"] = claims.IssuedAt
|
||||
if claims.NotBefore != nil {
|
||||
(*c)["nbf"] = claims.NotBefore
|
||||
}
|
||||
}
|
||||
default:
|
||||
err = claims.ValidateWithLeeway(expected, 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
switch err {
|
||||
case jwt.ErrExpired:
|
||||
return ErrExpired
|
||||
case jwt.ErrNotValidYet:
|
||||
return ErrNotValidYet
|
||||
case jwt.ErrIssuedInTheFuture:
|
||||
return ErrIssuedInTheFuture
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func normalize(i interface{}) (context.Map, error) {
|
||||
if m, ok := i.(context.Map); ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m := make(context.Map)
|
||||
|
||||
raw, err := json.Marshal(i)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewReader(raw))
|
||||
d.UseNumber()
|
||||
|
||||
if err := d.Decode(&m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func getRequiredFieldIndexes(i interface{}) (v []int) {
|
||||
val := reflect.Indirect(reflect.ValueOf(i))
|
||||
typ := val.Type()
|
||||
if typ.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
// Note: for the sake of simplicity we don't lookup for nested objects (FieldByIndex),
|
||||
// we could do that as we do in dependency injection feature but unless requirested we don't.
|
||||
tag := field.Tag.Get("json")
|
||||
if strings.Contains(tag, ",required") {
|
||||
v = append(v, i)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user