1
0
mirror of https://github.com/kataras/iris.git synced 2026-01-01 17:27:03 +00:00

New JWT features and changes (examples updated). Improvements on the Context User and Private Error features

TODO: Write the new e-book JWT section and the HISTORY entry of the chnages and  add a simple example on site docs
This commit is contained in:
Gerasimos (Makis) Maropoulos
2020-10-17 06:40:17 +03:00
parent b816156e77
commit 1864f99145
19 changed files with 1749 additions and 493 deletions

View File

@@ -32,7 +32,6 @@ Most of the experimental handlers are ported to work with _iris_'s handler form,
| [casbin](https://github.com/iris-contrib/middleware/tree/master/casbin)| An authorization library that supports access control models like ACL, RBAC, ABAC | [iris-contrib/middleware/casbin/_examples](https://github.com/iris-contrib/middleware/tree/master/casbin/_examples) |
| [sentry-go (ex. raven)](https://github.com/getsentry/sentry-go/tree/master/iris)| Sentry client in Go | [sentry-go/example/iris](https://github.com/getsentry/sentry-go/blob/master/example/iris/main.go) | <!-- raven was deprecated by its company, the successor is sentry-go, they contain an Iris middleware. -->
| [csrf](https://github.com/iris-contrib/middleware/tree/master/csrf)| Cross-Site Request Forgery Protection | [iris-contrib/middleware/csrf/_example](https://github.com/iris-contrib/middleware/blob/master/csrf/_example/main.go) |
| [go-i18n](https://github.com/iris-contrib/middleware/tree/master/go-i18n)| i18n Iris Loader for nicksnyder/go-i18n | [iris-contrib/middleware/go-i18n/_example](https://github.com/iris-contrib/middleware/blob/master/go-i18n/_example/main.go) |
| [throttler](https://github.com/iris-contrib/middleware/tree/master/throttler)| Rate limiting access to HTTP endpoints | [iris-contrib/middleware/throttler/_example](https://github.com/iris-contrib/middleware/blob/master/throttler/_example/main.go) |
Third-Party Handlers

View File

@@ -2,6 +2,7 @@ package jwt
import (
"github.com/square/go-jose/v3"
"github.com/square/go-jose/v3/json"
"github.com/square/go-jose/v3/jwt"
)
@@ -14,11 +15,19 @@ type (
// epoch, including leap seconds. Non-integer values can be represented
// in the serialized format, but we round to the nearest second.
NumericDate = jwt.NumericDate
// Expected defines values used for protected claims validation.
// If field has zero value then validation is skipped.
Expected = jwt.Expected
)
var (
// NewNumericDate constructs NumericDate from time.Time value.
NewNumericDate = jwt.NewNumericDate
// Marshal returns the JSON encoding of v.
Marshal = json.Marshal
// Unmarshal parses the JSON-encoded data and stores the result
// in the value pointed to by v.
Unmarshal = json.Unmarshal
)
type (

131
middleware/jwt/blocklist.go Normal file
View File

@@ -0,0 +1,131 @@
package jwt
import (
stdContext "context"
"sync"
"time"
)
// Blocklist is an in-memory storage of tokens that should be
// immediately invalidated by the server-side.
// The most common way to invalidate a token, e.g. on user logout,
// is to make the client-side remove the token itself.
// However, if someone else has access to that token,
// it could be still valid for new requests until its expiration.
type Blocklist struct {
entries map[string]time.Time // key = token | value = expiration time (to remove expired).
mu sync.RWMutex
}
// NewBlocklist returns a new up and running in-memory Token Blocklist.
// The returned value can be set to the JWT instance's Blocklist field.
func NewBlocklist(gcEvery time.Duration) *Blocklist {
return NewBlocklistContext(stdContext.Background(), gcEvery)
}
// NewBlocklistContext same as `NewBlocklist`
// but it also accepts a standard Go Context for GC cancelation.
func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) *Blocklist {
b := &Blocklist{
entries: make(map[string]time.Time),
}
if gcEvery > 0 {
go b.runGC(ctx, gcEvery)
}
return b
}
// Set upserts a given token, with its expiration time,
// to the block list, so it's immediately invalidated by the server-side.
func (b *Blocklist) Set(token string, expiresAt time.Time) {
b.mu.Lock()
b.entries[token] = expiresAt
b.mu.Unlock()
}
// Del removes a "token" from the block list.
func (b *Blocklist) Del(token string) {
b.mu.Lock()
delete(b.entries, token)
b.mu.Unlock()
}
// Count returns the total amount of blocked tokens.
func (b *Blocklist) Count() int {
b.mu.RLock()
n := len(b.entries)
b.mu.RUnlock()
return n
}
// Has reports whether the given "token" is blocked by the server.
// This method is called before the token verification,
// so even if was expired it is removed from the block list.
func (b *Blocklist) Has(token string) bool {
if token == "" {
return false
}
b.mu.RLock()
_, ok := b.entries[token]
b.mu.RUnlock()
/* No, the Blocklist will be used after the token is parsed,
there we can call the Del method if err was ErrExpired.
if ok {
// As an extra step, to keep the list size as small as possible,
// we delete it from list if it's going to be expired
// ~in the next `blockedExpireLeeway` seconds.~
// - Let's keep it easier for testing by not setting a leeway.
// if time.Now().Add(blockedExpireLeeway).After(expiresAt) {
if time.Now().After(expiresAt) {
b.Del(token)
}
}*/
return ok
}
// GC iterates over all entries and removes expired tokens.
// This method is helpful to keep the list size small.
// Depending on the application, the GC method can be scheduled
// to called every half or a whole hour.
// A good value for a GC cron task is the JWT's max age (default).
func (b *Blocklist) GC() int {
now := time.Now()
var markedForDeletion []string
b.mu.RLock()
for token, expiresAt := range b.entries {
if now.After(expiresAt) {
markedForDeletion = append(markedForDeletion, token)
}
}
b.mu.RUnlock()
n := len(markedForDeletion)
if n > 0 {
for _, token := range markedForDeletion {
b.Del(token)
}
}
return n
}
func (b *Blocklist) runGC(ctx stdContext.Context, every time.Duration) {
t := time.NewTicker(every)
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
b.GC()
}
}
}

View File

@@ -2,8 +2,6 @@ package jwt
import (
"crypto"
"encoding/json"
"errors"
"os"
"strings"
"time"
@@ -85,6 +83,9 @@ func FromJSON(jsonKey string) TokenExtractor {
//
// The `RSA(privateFile, publicFile, password)` package-level helper function
// can be used to decode the SignKey and VerifyKey.
//
// For an easy use look the `HMAC` package-level function
// and the its `NewUser` and `VerifyUser` methods.
type JWT struct {
// MaxAge is the expiration duration of the generated tokens.
MaxAge time.Duration
@@ -109,6 +110,17 @@ type JWT struct {
Encrypter jose.Encrypter
// DecriptionKey is used to decrypt the token (private key)
DecriptionKey interface{}
// Blocklist holds the invalidated-by-server tokens (that are not yet expired).
// It is not initialized by default.
// Initialization Usage:
// j.UseBlocklist()
// OR
// j.Blocklist = jwt.NewBlocklist(gcEveryDuration)
// Usage:
// - ctx.Logout()
// - j.Invalidate(ctx)
Blocklist *Blocklist
}
type privateKey interface{ Public() crypto.PublicKey }
@@ -284,64 +296,68 @@ func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorit
return nil
}
// Expiry returns a new standard Claims with
// the `Expiry` and `IssuedAt` fields of the "claims" filled
// based on the given "maxAge" duration.
//
// See the `JWT.Expiry` method too.
func Expiry(maxAge time.Duration, claims Claims) Claims {
now := time.Now()
claims.Expiry = NewNumericDate(now.Add(maxAge))
claims.IssuedAt = NewNumericDate(now)
return claims
// UseBlocklist initializes the Blocklist.
// Should be called on jwt middleware creation-time,
// after this, the developer can use the Context.Logout method
// to invalidate a verified token by the server-side.
func (j *JWT) UseBlocklist() {
gcEvery := 30 * time.Minute
if j.MaxAge > 0 {
gcEvery = j.MaxAge
}
j.Blocklist = NewBlocklist(gcEvery)
}
// Expiry method same as `Expiry` package-level function,
// it returns a Claims with the expiration fields of the "claims"
// filled based on the JWT's `MaxAge` field.
// Only use it when this standard "claims"
// is embedded on a custom claims structure.
// Usage:
// type UserClaims struct {
// jwt.Claims
// Username string
// }
// [...]
// standardClaims := j.Expiry(jwt.Claims{...})
// customClaims := UserClaims{
// Claims: standardClaims,
// Username: "kataras",
// }
// j.WriteToken(ctx, customClaims)
func (j *JWT) Expiry(claims Claims) Claims {
return Expiry(j.MaxAge, claims)
// ExpiryMap adds the expiration based on the "maxAge" to the "claims" map.
// It's called automatically on `Token` method.
func ExpiryMap(maxAge time.Duration, claims context.Map) {
now := time.Now()
if claims["exp"] == nil {
claims["exp"] = NewNumericDate(now.Add(maxAge))
}
if claims["iat"] == nil {
claims["iat"] = NewNumericDate(now)
}
}
// Token generates and returns a new token string.
// See `VerifyToken` too.
func (j *JWT) Token(claims interface{}) (string, error) {
// switch c := claims.(type) {
// case Claims:
// claims = Expiry(j.MaxAge, c)
// case map[string]interface{}: let's not support map.
// now := time.Now()
// c["iat"] = now.Unix()
// c["exp"] = now.Add(j.MaxAge).Unix()
// }
if c, ok := claims.(Claims); ok {
claims = Expiry(j.MaxAge, c)
return j.token(j.MaxAge, claims)
}
func (j *JWT) token(maxAge time.Duration, claims interface{}) (string, error) {
if claims == nil {
return "", ErrInvalidKey
}
c, nErr := normalize(claims)
if nErr != nil {
return "", nErr
}
// Set expiration, if missing.
ExpiryMap(maxAge, c)
var (
token string
err error
)
// jwt.Builder and jwt.NestedBuilder contain same methods but they are not the same.
//
// Note that the .Claims method there, converts a Struct to a map under the hoods.
// That means that we will not have any performance cost
// if we do it by ourselves and pass always a Map there.
// That gives us the option to allow user to pass ANY go struct
// and we can add the "exp", "nbf", "iat" map values by ourselves
// based on the j.MaxAge.
// (^ done, see normalize, all methods are
// changed to accept totally custom types, no need to embed the standard Claims anymore).
if j.DecriptionKey != nil {
token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(claims).CompactSerialize()
token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(c).CompactSerialize()
} else {
token, err = jwt.Signed(j.Signer).Claims(claims).CompactSerialize()
token, err = jwt.Signed(j.Signer).Claims(c).CompactSerialize()
}
if err != nil {
@@ -351,39 +367,6 @@ func (j *JWT) Token(claims interface{}) (string, error) {
return token, nil
}
/* Let's no support maps, typed claim is the way to go.
// validateMapClaims validates claims of map type.
func validateMapClaims(m map[string]interface{}, e jwt.Expected, leeway time.Duration) error {
if !e.Time.IsZero() {
if v, ok := m["nbf"]; ok {
if notBefore, ok := v.(NumericDate); ok {
if e.Time.Add(leeway).Before(notBefore.Time()) {
return ErrNotValidYet
}
}
}
if v, ok := m["exp"]; ok {
if exp, ok := v.(int64); ok {
if e.Time.Add(-leeway).Before(time.Unix(exp, 0)) {
return ErrExpired
}
}
}
if v, ok := m["iat"]; ok {
if issuedAt, ok := v.(int64); ok {
if e.Time.Add(leeway).Before(time.Unix(issuedAt, 0)) {
return ErrIssuedInTheFuture
}
}
}
}
return nil
}
*/
// WriteToken is a helper which just generates(calls the `Token` method) and writes
// a new token to the client in plain text format.
//
@@ -399,91 +382,122 @@ func (j *JWT) WriteToken(ctx *context.Context, claims interface{}) error {
return err
}
var (
// ErrMissing when token cannot be extracted from the request.
ErrMissing = errors.New("token is missing")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("token is expired (exp)")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
ErrNotValidYet = errors.New("token not valid yet (nbf)")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
ErrIssuedInTheFuture = errors.New("token issued in the future (iat)")
)
type (
claimsValidator interface {
ValidateWithLeeway(e jwt.Expected, leeway time.Duration) error
}
claimsAlternativeValidator interface { // to keep iris-contrib/jwt MapClaims compatible.
Validate() error
}
claimsContextValidator interface {
Validate(ctx *context.Context) error
}
)
// IsValidated reports whether a token is already validated through
// `VerifyToken`. It returns true when the claims are compatible
// validators: a `Claims` value or a value that implements the `Validate() error` method.
func IsValidated(ctx *context.Context) bool { // see the `ReadClaims`.
return ctx.Values().Get(needsValidationContextKey) == nil
}
func validateClaims(ctx *context.Context, claims interface{}) (err error) {
switch c := claims.(type) {
case claimsValidator:
err = c.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 0)
case claimsAlternativeValidator:
err = c.Validate()
case claimsContextValidator:
err = c.Validate(ctx)
case *json.RawMessage:
// if the data type is raw message (json []byte)
// then it should contain exp (and iat and nbf) keys.
// Unmarshal raw message to validate against.
v := new(Claims)
err = json.Unmarshal(*c, v)
if err == nil {
return validateClaims(ctx, v)
}
default:
ctx.Values().Set(needsValidationContextKey, struct{}{})
}
if err != nil {
switch err {
case jwt.ErrExpired:
return ErrExpired
case jwt.ErrNotValidYet:
return ErrNotValidYet
case jwt.ErrIssuedInTheFuture:
return ErrIssuedInTheFuture
}
}
return err
}
// VerifyToken verifies (and decrypts) the request token,
// it also validates and binds the parsed token's claims to the "claimsPtr" (destination).
// It does return a nil error on success.
func (j *JWT) VerifyToken(ctx *context.Context, claimsPtr interface{}) error {
var token string
//
// The last, variadic, input argument is optionally, if provided then the
// parsed claims must match the expectations;
// e.g. Audience, Issuer, ID, Subject.
// See `ExpectXXX` package-functions for details.
func (j *JWT) VerifyToken(ctx *context.Context, claimsPtr interface{}, expectations ...Expectation) (*TokenInfo, error) {
token := j.RequestToken(ctx)
return j.VerifyTokenString(ctx, token, claimsPtr, expectations...)
}
// RequestToken extracts the token from the request.
func (j *JWT) RequestToken(ctx *context.Context) (token string) {
for _, extract := range j.Extractors {
if token = extract(ctx); token != "" {
break // ok we found it.
}
}
return j.VerifyTokenString(ctx, token, claimsPtr)
return
}
// VerifyTokenString verifies and unmarshals an extracted token to "claimsPtr" destination.
// The Context is required when the claims validator needs it, otherwise can be nil.
func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr interface{}) error {
// TokenSetter is an interface which if implemented
// the extracted, verified, token is stored to the object.
type TokenSetter interface {
SetToken(token string)
}
// TokenInfo holds the standard token information may required
// for further actions.
// This structure is mostly useful when the developer's go structure
// does not hold the standard jwt fields (e.g. "exp")
// but want access to the parsed token which contains those fields.
// Inside the middleware, it is used to invalidate tokens through server-side, see `Invalidate`.
type TokenInfo struct {
RequestToken string // The request token.
Claims Claims // The standard JWT parsed fields from the request Token.
Value interface{} // The pointer to the end-developer's custom claims structure (see `Get`).
}
const tokenInfoContextKey = "iris.jwt.token"
// Get returns the verified developer token claims.
//
//
// Usage:
// j := jwt.New(...)
// app.Use(j.Verify(func() interface{} { return new(CustomClaims) }))
// app.Post("/restricted", func(ctx iris.Context){
// claims := jwt.Get(ctx).(*CustomClaims)
// [use claims...]
// })
//
// Note that there is one exception, if the value was a pointer
// to a map[string]interface{}, it returns the map itself so it can be
// accessible directly without the requirement of unwrapping it, e.g.
// j.Verify(func() interface{} {
// return &iris.Map{}
// }
// [...]
// claims := jwt.Get(ctx).(iris.Map)
func Get(ctx *context.Context) interface{} {
if tok := GetTokenInfo(ctx); tok != nil {
switch v := tok.Value.(type) {
case *context.Map:
return *v
default:
return v
}
}
return nil
}
// GetTokenInfo returns the verified token's information.
func GetTokenInfo(ctx *context.Context) *TokenInfo {
if v := ctx.Values().Get(tokenInfoContextKey); v != nil {
if t, ok := v.(*TokenInfo); ok {
return t
}
}
return nil
}
// Invalidate invalidates a verified JWT token.
// It adds the request token, retrieved by Verify methods, to the block list.
// Next request will be blocked, even if the token was not yet expired.
// This method can be used when the client-side does not clear the token
// on a user logout operation.
//
// Note: the Blocklist should be initialized before serve-time: j.UseBlocklist().
func (j *JWT) Invalidate(ctx *context.Context) {
if j.Blocklist == nil {
ctx.Application().Logger().Debug("jwt.Invalidate: Blocklist is nil")
return
}
tokenInfo := GetTokenInfo(ctx)
if tokenInfo == nil {
return
}
j.Blocklist.Set(tokenInfo.RequestToken, tokenInfo.Claims.Expiry.Time())
}
// VerifyTokenString verifies and unmarshals an extracted request token to "dest" destination.
// The last variadic input indicates any further validations against the verified token claims.
// If the given "dest" is a valid context.User then ctx.User() will return it.
// If the token is missing an `ErrMissing` is returned.
// If the incoming token was expired an `ErrExpired` is returned.
// If the incoming token was blocked by the server an `ErrBlocked` is returned.
func (j *JWT) VerifyTokenString(ctx *context.Context, token string, dest interface{}, expectations ...Expectation) (*TokenInfo, error) {
if token == "" {
return ErrMissing
return nil, ErrMissing
}
var (
@@ -494,7 +508,7 @@ func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr in
if j.DecriptionKey != nil {
t, cerr := jwt.ParseSignedAndEncrypted(token)
if cerr != nil {
return cerr
return nil, cerr
}
parsedToken, err = t.Decrypt(j.DecriptionKey)
@@ -502,112 +516,163 @@ func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr in
parsedToken, err = jwt.ParseSigned(token)
}
if err != nil {
return err
return nil, err
}
if err = parsedToken.Claims(j.VerificationKey, claimsPtr); err != nil {
return err
var claims Claims
if err = parsedToken.Claims(j.VerificationKey, dest, &claims); err != nil {
return nil, err
}
return validateClaims(ctx, claimsPtr)
}
const (
// ClaimsContextKey is the context key which the jwt claims are stored from the `Verify` method.
ClaimsContextKey = "iris.jwt.claims"
needsValidationContextKey = "iris.jwt.claims.unvalidated"
)
// Verify is a middleware. It verifies and optionally decrypts an incoming request token.
// It does write a 401 unauthorized status code if verification or decryption failed.
// It calls the `ctx.Next` on verified requests.
//
// See `VerifyToken` instead to verify, decrypt, validate and acquire the claims at once.
//
// A call of `ReadClaims` is required to validate and acquire the jwt claims
// on the next request.
func (j *JWT) Verify(ctx *context.Context) {
var raw json.RawMessage
if err := j.VerifyToken(ctx, &raw); err != nil {
ctx.StopWithStatus(401)
return
}
ctx.Values().Set(ClaimsContextKey, raw)
ctx.Next()
}
// ReadClaims binds the "claimsPtr" (destination)
// to the verified (and decrypted) claims.
// The `Verify` method should be called first (registered as middleware).
func ReadClaims(ctx *context.Context, claimsPtr interface{}) error {
v := ctx.Values().Get(ClaimsContextKey)
if v == nil {
return ErrMissing
}
raw, ok := v.(json.RawMessage)
if !ok {
return ErrMissing
}
err := json.Unmarshal(raw, claimsPtr)
if err != nil {
return err
}
if !IsValidated(ctx) {
// If already validated on `Verify/VerifyToken`
// then no need to perform the check again.
ctx.Values().Remove(needsValidationContextKey)
return validateClaims(ctx, claimsPtr)
}
return nil
}
// Get returns and validates (if not already) the claims
// stored on request context's values storage.
//
// Should be used instead of the `ReadClaims` method when
// a custom verification middleware was registered (see the `Verify` method for an example).
//
// Usage:
// j := jwt.New(...)
// [...]
// app.Use(func(ctx iris.Context) {
// var claims CustomClaims_or_jwt.Claims
// if err := j.VerifyToken(ctx, &claims); err != nil {
// ctx.StopWithStatus(iris.StatusUnauthorized)
// return
// }
//
// ctx.Values().Set(jwt.ClaimsContextKey, claims)
// ctx.Next()
// })
// [...]
// app.Post("/restricted", func(ctx iris.Context){
// v, err := jwt.Get(ctx)
// [handle error...]
// claims,ok := v.(CustomClaims_or_jwt.Claims)
// if !ok {
// [do you support more than one type of claims? Handle here]
// }
// [use claims...]
// })
func Get(ctx *context.Context) (interface{}, error) {
claims := ctx.Values().Get(ClaimsContextKey)
if claims == nil {
return nil, ErrMissing
}
if !IsValidated(ctx) {
ctx.Values().Remove(needsValidationContextKey)
err := validateClaims(ctx, claims)
if err != nil {
return nil, err
// Build the Expected value.
expected := Expected{}
for _, e := range expectations {
if e != nil {
// expection can be used as a field validation too (see MeetRequirements).
if err = e(&expected, dest); err != nil {
return nil, err
}
}
}
return claims, nil
// For other standard JWT claims fields such as "exp"
// The developer can just add a field of Expiry *NumericDate `json:"exp"`
// and will be filled by the parsed token automatically.
// No need for more interfaces.
err = validateClaims(ctx, dest, claims, expected)
if err != nil {
if err == ErrExpired {
// If token was expired remove it from the block list.
if j.Blocklist != nil {
j.Blocklist.Del(token)
}
}
return nil, err
}
if j.Blocklist != nil {
// If token exists in the block list, then stop here.
if j.Blocklist.Has(token) {
return nil, ErrBlocked
}
}
if ut, ok := dest.(TokenSetter); ok {
// The u.Token is empty even if we set it and export it on JSON structure.
// Set it manually.
ut.SetToken(token)
}
// Set the information.
tokenInfo := &TokenInfo{
RequestToken: token,
Claims: claims,
Value: dest,
}
return tokenInfo, nil
}
// TokenPair holds the access token and refresh token response.
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// TokenPair generates a token pair of access and refresh tokens.
// The first two arguments required for the refresh token
// and the last one is the claims for the access token one.
func (j *JWT) TokenPair(refreshMaxAge time.Duration, refreshClaims interface{}, accessClaims interface{}) (TokenPair, error) {
accessToken, err := j.Token(accessClaims)
if err != nil {
return TokenPair{}, err
}
refreshToken, err := j.token(refreshMaxAge, refreshClaims)
if err != nil {
return TokenPair{}, nil
}
pair := TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
}
return pair, nil
}
// Verify returns a middleware which
// decrypts an incoming request token to the result of the given "newPtr".
// It does write a 401 unauthorized status code if verification or decryption failed.
// It calls the `ctx.Next` on verified requests.
//
// Iit unmarshals the token to the specific type returned from the given "newPtr" function.
// It sets the Context User and User's Token too. So the next handler(s)
// of the same chain can access the User through a `Context.User()` call.
//
// Note unlike `VerifyToken`, this method automatically protects
// the claims with JSON required tags (see `MeetRequirements` Expection).
//
// On verified tokens:
// - The information can be retrieved through `Get` and `GetTokenInfo` functions.
// - User is set if the newPtr returns a valid Context User
// - The Context Logout method is set if Blocklist was initialized
// Any error is captured to the Context,
// which can be retrieved by a `ctx.GetErr()` call.
func (j *JWT) Verify(newPtr func() interface{}, expections ...Expectation) context.Handler {
expections = append(expections, MeetRequirements(newPtr()))
return func(ctx *context.Context) {
ptr := newPtr()
tokenInfo, err := j.VerifyToken(ctx, ptr, expections...)
if err != nil {
ctx.Application().Logger().Debugf("iris.jwt.Verify: %v", err)
ctx.StopWithError(401, context.PrivateError(err))
return
}
if u, ok := ptr.(context.User); ok {
ctx.SetUser(u)
}
if j.Blocklist != nil {
ctx.SetLogoutFunc(j.Invalidate)
}
ctx.Values().Set(tokenInfoContextKey, tokenInfo)
ctx.Next()
}
}
// NewUser returns a new User based on the given "opts".
// The caller can modify the User until its `GetToken` is called.
func (j *JWT) NewUser(opts ...UserOption) *User {
u := &User{
j: j,
SimpleUser: &context.SimpleUser{
Authorization: "IRIS_JWT_USER", // Used to separate a refresh token with a user/access one too.
Features: []context.UserFeature{
context.TokenFeature,
},
},
}
for _, opt := range opts {
opt(u)
}
return u
}
// VerifyUser works like the `Verify` method but instead
// it unmarshals the token to the specific User type.
// It sets the Context User too. So the next handler(s)
// of the same chain can access the User through a `Context.User()` call.
func (j *JWT) VerifyUser() context.Handler {
return j.Verify(func() interface{} {
return new(User)
})
}

View File

@@ -12,11 +12,15 @@ import (
)
type userClaims struct {
jwt.Claims
Username string
// Optionally:
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience jwt.Audience `json:"aud"`
//
Username string `json:"username"`
}
const testMaxAge = 3 * time.Second
const testMaxAge = 7 * time.Second
// Random RSA verification and encryption.
func TestRSA(t *testing.T) {
@@ -25,13 +29,13 @@ func TestRSA(t *testing.T) {
os.Remove(jwt.DefaultSignFilename)
os.Remove(jwt.DefaultEncFilename)
})
testWriteVerifyToken(t, j)
testWriteVerifyBlockToken(t, j)
}
// HMAC verification and encryption.
func TestHMAC(t *testing.T) {
j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret")
testWriteVerifyToken(t, j)
testWriteVerifyBlockToken(t, j)
}
func TestNew_HMAC(t *testing.T) {
@@ -44,7 +48,7 @@ func TestNew_HMAC(t *testing.T) {
t.Fatal(err)
}
testWriteVerifyToken(t, j)
testWriteVerifyBlockToken(t, j)
}
// HMAC verification only (unecrypted).
@@ -53,54 +57,60 @@ func TestVerify(t *testing.T) {
if err != nil {
t.Fatal(err)
}
testWriteVerifyToken(t, j)
testWriteVerifyBlockToken(t, j)
}
func testWriteVerifyToken(t *testing.T, j *jwt.JWT) {
func testWriteVerifyBlockToken(t *testing.T, j *jwt.JWT) {
t.Helper()
j.UseBlocklist()
j.Extractors = append(j.Extractors, jwt.FromJSON("access_token"))
standardClaims := jwt.Claims{Issuer: "an-issuer", Audience: jwt.Audience{"an-audience"}}
expectedClaims := userClaims{
Claims: j.Expiry(standardClaims),
customClaims := &userClaims{
Issuer: "an-issuer",
Audience: jwt.Audience{"an-audience"},
Subject: "user",
Username: "kataras",
}
app := iris.New()
app.OnErrorCode(iris.StatusUnauthorized, func(ctx iris.Context) {
if err := ctx.GetErr(); err != nil {
// Test accessing the private error and set this as the response body.
ctx.WriteString(err.Error())
} else { // Else the default behavior
ctx.WriteString(iris.StatusText(iris.StatusUnauthorized))
}
})
app.Get("/auth", func(ctx iris.Context) {
j.WriteToken(ctx, expectedClaims)
j.WriteToken(ctx, customClaims)
})
app.Post("/restricted", func(ctx iris.Context) {
app.Post("/protected", func(ctx iris.Context) {
var claims userClaims
if err := j.VerifyToken(ctx, &claims); err != nil {
ctx.StopWithStatus(iris.StatusUnauthorized)
return
}
ctx.JSON(claims)
})
app.Post("/restricted_middleware_readclaims", j.Verify, func(ctx iris.Context) {
var claims userClaims
if err := jwt.ReadClaims(ctx, &claims); err != nil {
ctx.StopWithStatus(iris.StatusUnauthorized)
return
}
ctx.JSON(claims)
})
app.Post("/restricted_middleware_get", j.Verify, func(ctx iris.Context) {
claims, err := jwt.Get(ctx)
_, err := j.VerifyToken(ctx, &claims)
if err != nil {
ctx.StopWithStatus(iris.StatusUnauthorized)
// t.Logf("%s: %v", ctx.Path(), err)
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
return
}
ctx.JSON(claims)
})
m := app.Party("/middleware")
m.Use(j.Verify(func() interface{} {
return new(userClaims)
}))
m.Post("/protected", func(ctx iris.Context) {
claims := jwt.Get(ctx)
ctx.JSON(claims)
})
m.Post("/invalidate", func(ctx iris.Context) {
ctx.Logout() // OR j.Invalidate(ctx)
})
e := httptest.New(t, app)
// Get token.
@@ -109,31 +119,186 @@ func testWriteVerifyToken(t *testing.T, j *jwt.JWT) {
t.Fatalf("empty token")
}
restrictedPaths := [...]string{"/restricted", "/restricted_middleware_readclaims", "/restricted_middleware_get"}
restrictedPaths := [...]string{"/protected", "/middleware/protected"}
now := time.Now()
for _, path := range restrictedPaths {
// Authorization Header.
e.POST(path).WithHeader("Authorization", "Bearer "+rawToken).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
Status(httptest.StatusOK).JSON().Equal(customClaims)
// URL Query.
e.POST(path).WithQuery("token", rawToken).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
Status(httptest.StatusOK).JSON().Equal(customClaims)
// JSON Body.
e.POST(path).WithJSON(iris.Map{"access_token": rawToken}).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
Status(httptest.StatusOK).JSON().Equal(customClaims)
// Missing "Bearer".
e.POST(path).WithHeader("Authorization", rawToken).Expect().
Status(httptest.StatusUnauthorized)
Status(httptest.StatusUnauthorized).Body().Equal("token is missing")
}
// Invalidate the token.
e.POST("/middleware/invalidate").WithQuery("token", rawToken).Expect().
Status(httptest.StatusOK)
// Token is blocked by server.
e.POST("/middleware/protected").WithQuery("token", rawToken).Expect().
Status(httptest.StatusUnauthorized).Body().Equal("token is blocked")
expireRemDur := testMaxAge - time.Since(now)
// Expiration.
time.Sleep(expireRemDur /* -end */)
for _, path := range restrictedPaths {
e.POST(path).WithQuery("token", rawToken).Expect().Status(httptest.StatusUnauthorized)
e.POST(path).WithQuery("token", rawToken).Expect().
Status(httptest.StatusUnauthorized).Body().Equal("token is expired (exp)")
}
}
func TestVerifyMap(t *testing.T) {
j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret")
expectedClaims := iris.Map{
"iss": "tester",
"username": "makis",
"roles": []string{"admin"},
}
app := iris.New()
app.Get("/user/auth", func(ctx iris.Context) {
err := j.WriteToken(ctx, expectedClaims)
if err != nil {
ctx.StopWithError(iris.StatusUnauthorized, err)
return
}
if expectedClaims["exp"] == nil || expectedClaims["iat"] == nil {
ctx.StopWithText(iris.StatusBadRequest,
"exp or/and iat is nil - this means that the expiry was not set")
return
}
})
userAPI := app.Party("/user")
userAPI.Post("/", func(ctx iris.Context) {
var claims iris.Map
if _, err := j.VerifyToken(ctx, &claims); err != nil {
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
return
}
ctx.JSON(claims)
})
// Test map + Verify middleware.
userAPI.Post("/middleware", j.Verify(func() interface{} {
return &iris.Map{} // or &map[string]interface{}{}
}), func(ctx iris.Context) {
claims := jwt.Get(ctx)
ctx.JSON(claims)
})
e := httptest.New(t, app, httptest.LogLevel("error"))
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if token == "" {
t.Fatalf("empty token")
}
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
e.POST("/user/middleware").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
e.POST("/user").Expect().Status(httptest.StatusUnauthorized)
}
type customClaims struct {
Username string `json:"username"`
Token string `json:"token"`
}
func (c *customClaims) SetToken(tok string) {
c.Token = tok
}
func TestVerifyStruct(t *testing.T) {
maxAge := testMaxAge / 2
j := jwt.HMAC(maxAge, "secret", "itsa16bytesecret")
app := iris.New()
app.Get("/user/auth", func(ctx iris.Context) {
err := j.WriteToken(ctx, customClaims{
Username: "makis",
})
if err != nil {
ctx.StopWithError(iris.StatusUnauthorized, err)
return
}
})
userAPI := app.Party("/user")
userAPI.Post("/", func(ctx iris.Context) {
var claims customClaims
if _, err := j.VerifyToken(ctx, &claims); err != nil {
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
return
}
ctx.JSON(claims)
})
e := httptest.New(t, app)
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if token == "" {
t.Fatalf("empty token")
}
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Object().ContainsMap(iris.Map{
"username": "makis",
"token": token, // Test SetToken.
})
e.POST("/user").Expect().Status(httptest.StatusUnauthorized)
time.Sleep(maxAge)
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().Status(httptest.StatusUnauthorized)
}
func TestVerifyUserAndExpected(t *testing.T) { // Tests the jwt.User struct + context validator + expected.
maxAge := testMaxAge / 2
j := jwt.HMAC(maxAge, "secret", "itsa16bytesecret")
expectedUser := j.NewUser(jwt.Username("makis"), jwt.Roles("admin"), jwt.Fields(iris.Map{
"custom": true,
})) // only for the sake of the test, we iniitalize it here.
expectedUser.Issuer = "tester"
app := iris.New()
app.Get("/user/auth", func(ctx iris.Context) {
tok, err := expectedUser.GetToken()
if err != nil {
ctx.StopWithError(iris.StatusInternalServerError, err)
return
}
ctx.WriteString(tok)
})
userAPI := app.Party("/user")
userAPI.Use(jwt.WithExpected(jwt.Expected{Issuer: "tester"}, j.VerifyUser()))
userAPI.Post("/", func(ctx iris.Context) {
user := ctx.User()
ctx.JSON(user)
})
e := httptest.New(t, app)
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if token == "" {
t.Fatalf("empty token")
}
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedUser)
// Test generic client message if we don't manage the private error by ourselves.
e.POST("/user").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
}

187
middleware/jwt/user.go Normal file
View File

@@ -0,0 +1,187 @@
package jwt
import (
"time"
"github.com/kataras/iris/v12/context"
)
// User a common User structure for JWT.
// However, we're not limited to that one;
// any Go structure can be generated as a JWT token.
//
// Look `NewUser` and `VerifyUser` JWT middleware's methods.
// Use its `GetToken` method to generate the token when
// the User structure is set.
type User struct {
Claims
// Note: we could use a map too as the Token is generated when GetToken is called.
*context.SimpleUser
j *JWT
}
var (
_ context.FeaturedUser = (*User)(nil)
_ TokenSetter = (*User)(nil)
_ ContextValidator = (*User)(nil)
)
// UserOption sets optional fields for a new User
// See `NewUser` instance function.
type UserOption func(*User)
// Username sets the Username and the JWT Claim's Subject
// to the given "username".
func Username(username string) UserOption {
return func(u *User) {
u.Username = username
u.Claims.Subject = username
u.Features = append(u.Features, context.UsernameFeature)
}
}
// Email sets the Email field for the User field.
func Email(email string) UserOption {
return func(u *User) {
u.Email = email
u.Features = append(u.Features, context.EmailFeature)
}
}
// Roles upserts to the User's Roles field.
func Roles(roles ...string) UserOption {
return func(u *User) {
u.Roles = roles
u.Features = append(u.Features, context.RolesFeature)
}
}
// MaxAge sets claims expiration and the AuthorizedAt User field.
func MaxAge(maxAge time.Duration) UserOption {
return func(u *User) {
now := time.Now()
u.Claims.Expiry = NewNumericDate(now.Add(maxAge))
u.Claims.IssuedAt = NewNumericDate(now)
u.AuthorizedAt = now
u.Features = append(u.Features, context.AuthorizedAtFeature)
}
}
// Fields copies the "fields" to the user's Fields field.
// This can be used to set custom fields to the User instance.
func Fields(fields context.Map) UserOption {
return func(u *User) {
if len(fields) == 0 {
return
}
if u.Fields == nil {
u.Fields = make(context.Map, len(fields))
}
for k, v := range fields {
u.Fields[k] = v
}
u.Features = append(u.Features, context.FieldsFeature)
}
}
// SetToken is called automaticaly on VerifyUser/VerifyObject.
// It sets the extracted from request, and verified from server raw token.
func (u *User) SetToken(token string) {
u.Token = token
}
// GetToken overrides the SimpleUser's Token
// and returns the jwt generated token, among with
// a generator error, if any.
func (u *User) GetToken() (string, error) {
if u.Token != "" {
return u.Token, nil
}
if u.j != nil { // it's always not nil.
if u.j.MaxAge > 0 {
// if the MaxAge option was not manually set, resolve it from the JWT instance.
MaxAge(u.j.MaxAge)(u)
}
// we could generate a token here
// but let's do it on GetToken
// as the user fields may change
// by the caller manually until the token
// sent to the client.
tok, err := u.j.Token(u)
if err != nil {
return "", err
}
u.Token = tok
}
if u.Token == "" {
return "", ErrMissing
}
return u.Token, nil
}
// Validate validates the current user's claims against
// the request. It's called automatically by the JWT instance.
func (u *User) Validate(ctx *context.Context, claims Claims, e Expected) error {
err := u.Claims.ValidateWithLeeway(e, 0)
if err != nil {
return err
}
if u.SimpleUser.Authorization != "IRIS_JWT_USER" {
return ErrInvalidKey
}
// We could add specific User Expectations (new struct and accept an interface{}),
// but for the sake of code simplicity we don't, unless is requested, as the caller
// can validate specific fields by its own at the next step.
return nil
}
// UnmarshalJSON implements the json unmarshaler interface.
func (u *User) UnmarshalJSON(data []byte) error {
err := Unmarshal(data, &u.Claims)
if err != nil {
return err
}
simpleUser := new(context.SimpleUser)
err = Unmarshal(data, simpleUser)
if err != nil {
return err
}
u.SimpleUser = simpleUser
return nil
}
// MarshalJSON implements the json marshaler interface.
func (u *User) MarshalJSON() ([]byte, error) {
claimsB, err := Marshal(u.Claims)
if err != nil {
return nil, err
}
userB, err := Marshal(u.SimpleUser)
if err != nil {
return nil, err
}
if len(userB) == 0 {
return claimsB, nil
}
claimsB = claimsB[0 : len(claimsB)-1] // remove last '}'
userB = userB[1:] // remove first '{'
raw := append(claimsB, ',')
raw = append(raw, userB...)
return raw, nil
}

View File

@@ -0,0 +1,212 @@
package jwt
import (
"bytes"
"errors"
"reflect"
"strings"
"time"
"github.com/kataras/iris/v12/context"
"github.com/square/go-jose/v3/json"
// Use this package instead of the standard encoding/json
// to marshal the NumericDate as expected by the implementation (see 'normalize`).
"github.com/square/go-jose/v3/jwt"
)
const (
claimsExpectedContextKey = "iris.jwt.claims.expected"
needsValidationContextKey = "iris.jwt.claims.unvalidated"
)
var (
// ErrMissing when token cannot be extracted from the request.
ErrMissing = errors.New("token is missing")
// ErrMissingKey when token does not contain a required JSON field.
ErrMissingKey = errors.New("token is missing a required field")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("token is expired (exp)")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
ErrNotValidYet = errors.New("token not valid yet (nbf)")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
ErrIssuedInTheFuture = errors.New("token issued in the future (iat)")
// ErrBlocked indicates that the token was not yet expired
// but was blocked by the server's Blocklist.
ErrBlocked = errors.New("token is blocked")
)
// Expectation option to provide
// an extra layer of token validation, a claims type protection.
// See `VerifyToken` method.
type Expectation func(e *Expected, claims interface{}) error
// Expect protects the claims with the expected values.
func Expect(expected Expected) Expectation {
return func(e *Expected, _ interface{}) error {
*e = expected
return nil
}
}
// ExpectID protects the claims with an ID validation.
func ExpectID(id string) Expectation {
return func(e *Expected, _ interface{}) error {
e.ID = id
return nil
}
}
// ExpectIssuer protects the claims with an issuer validation.
func ExpectIssuer(issuer string) Expectation {
return func(e *Expected, _ interface{}) error {
e.Issuer = issuer
return nil
}
}
// ExpectSubject protects the claims with a subject validation.
func ExpectSubject(sub string) Expectation {
return func(e *Expected, _ interface{}) error {
e.Subject = sub
return nil
}
}
// ExpectAudience protects the claims with an audience validation.
func ExpectAudience(audience ...string) Expectation {
return func(e *Expected, _ interface{}) error {
e.Audience = audience
return nil
}
}
// MeetRequirements protects the custom fields of JWT claims
// based on the json:required tag; `json:"name,required"`.
// It accepts the value type.
//
// Usage:
// Verify/VerifyToken(... MeetRequirements(MyUser{}))
func MeetRequirements(claimsType interface{}) Expectation {
// pre-calculate if we need to use reflection at serve time to check for required fields,
// this can work as an alternative of expections for custom non-standard JWT fields.
requireFieldsIndexes := getRequiredFieldIndexes(claimsType)
return func(e *Expected, claims interface{}) error {
if len(requireFieldsIndexes) > 0 {
val := reflect.Indirect(reflect.ValueOf(claims))
for _, idx := range requireFieldsIndexes {
field := val.Field(idx)
if field.IsZero() {
return ErrMissingKey
}
}
}
return nil
}
}
// WithExpected is a middleware wrapper. It wraps a VerifyXXX middleware
// with expected claims fields protection.
// Usage:
// jwt.WithExpected(jwt.Expected{Issuer:"app"}, j.VerifyUser)
func WithExpected(e Expected, verifyHandler context.Handler) context.Handler {
return func(ctx *context.Context) {
ctx.Values().Set(claimsExpectedContextKey, e)
verifyHandler(ctx)
}
}
// ContextValidator validates the object based on the given
// claims and the expected once. The end-developer
// can use this method for advanced validations based on the request Context.
type ContextValidator interface {
Validate(ctx *context.Context, claims Claims, e Expected) error
}
func validateClaims(ctx *context.Context, dest interface{}, claims Claims, expected Expected) (err error) {
// Get any dynamic expectation set by prior middleware.
// See `WithExpected` middleware.
if v := ctx.Values().Get(claimsExpectedContextKey); v != nil {
if e, ok := v.(Expected); ok {
expected = e
}
}
// Force-set the time, it's important for expiration.
expected.Time = time.Now()
switch c := dest.(type) {
case Claims:
err = c.ValidateWithLeeway(expected, 0)
case ContextValidator:
err = c.Validate(ctx, claims, expected)
case *context.Map:
// if the dest is a map then set automatically the expiration settings here,
// so the caller can work further with it.
err = claims.ValidateWithLeeway(expected, 0)
if err == nil {
(*c)["exp"] = claims.Expiry
(*c)["iat"] = claims.IssuedAt
if claims.NotBefore != nil {
(*c)["nbf"] = claims.NotBefore
}
}
default:
err = claims.ValidateWithLeeway(expected, 0)
}
if err != nil {
switch err {
case jwt.ErrExpired:
return ErrExpired
case jwt.ErrNotValidYet:
return ErrNotValidYet
case jwt.ErrIssuedInTheFuture:
return ErrIssuedInTheFuture
}
}
return err
}
func normalize(i interface{}) (context.Map, error) {
if m, ok := i.(context.Map); ok {
return m, nil
}
m := make(context.Map)
raw, err := json.Marshal(i)
if err != nil {
return nil, err
}
d := json.NewDecoder(bytes.NewReader(raw))
d.UseNumber()
if err := d.Decode(&m); err != nil {
return nil, err
}
return m, nil
}
func getRequiredFieldIndexes(i interface{}) (v []int) {
val := reflect.Indirect(reflect.ValueOf(i))
typ := val.Type()
if typ.Kind() != reflect.Struct {
return nil
}
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
// Note: for the sake of simplicity we don't lookup for nested objects (FieldByIndex),
// we could do that as we do in dependency injection feature but unless requirested we don't.
tag := field.Tag.Get("json")
if strings.Contains(tag, ",required") {
v = append(v, i)
}
}
return
}