1
0
mirror of https://github.com/kataras/iris.git synced 2026-01-10 21:45:57 +00:00

As noticed in my previous commit, the existing jwt libraries added a lot of performance cost between jwt-featured requests and simple requests. That's why a new custom JWT parser was created. This commit adds our custom jwt parser as the underline token signer and verifier

This commit is contained in:
Gerasimos (Makis) Maropoulos
2020-10-30 22:12:16 +02:00
parent d517f36a29
commit 8eea0296a7
21 changed files with 750 additions and 2431 deletions

View File

@@ -1,91 +1,105 @@
package jwt
import (
"github.com/square/go-jose/v3"
"github.com/square/go-jose/v3/json"
"github.com/square/go-jose/v3/jwt"
"github.com/kataras/jwt"
)
// Type alises for the underline jwt package.
type (
// Claims represents public claim values (as specified in RFC 7519).
// Alg is the signature algorithm interface alias.
Alg = jwt.Alg
// Claims represents the standard claim values (as specified in RFC 7519).
Claims = jwt.Claims
// Audience represents the recipients that the token is intended for.
Audience = jwt.Audience
// NumericDate represents date and time as the number of seconds since the
// epoch, including leap seconds. Non-integer values can be represented
// in the serialized format, but we round to the nearest second.
NumericDate = jwt.NumericDate
// Expected defines values used for protected claims validation.
// If field has zero value then validation is skipped.
// Expected is a TokenValidator which performs simple checks
// between standard claims values.
//
// Usage:
// expecteed := jwt.Expected{
// Issuer: "my-app",
// }
// verifiedToken, err := verifier.Verify(..., expected)
Expected = jwt.Expected
)
var (
// NewNumericDate constructs NumericDate from time.Time value.
NewNumericDate = jwt.NewNumericDate
// Marshal returns the JSON encoding of v.
Marshal = json.Marshal
// Unmarshal parses the JSON-encoded data and stores the result
// in the value pointed to by v.
Unmarshal = json.Unmarshal
)
type (
// KeyAlgorithm represents a key management algorithm.
KeyAlgorithm = jose.KeyAlgorithm
// SignatureAlgorithm represents a signature (or MAC) algorithm.
SignatureAlgorithm = jose.SignatureAlgorithm
// ContentEncryption represents a content encryption algorithm.
ContentEncryption = jose.ContentEncryption
)
// Key management algorithms.
const (
ED25519 = jose.ED25519
RSA15 = jose.RSA1_5
RSAOAEP = jose.RSA_OAEP
RSAOAEP256 = jose.RSA_OAEP_256
A128KW = jose.A128KW
A192KW = jose.A192KW
A256KW = jose.A256KW
DIRECT = jose.DIRECT
ECDHES = jose.ECDH_ES
ECDHESA128KW = jose.ECDH_ES_A128KW
ECDHESA192KW = jose.ECDH_ES_A192KW
ECDHESA256KW = jose.ECDH_ES_A256KW
A128GCMKW = jose.A128GCMKW
A192GCMKW = jose.A192GCMKW
A256GCMKW = jose.A256GCMKW
PBES2HS256A128KW = jose.PBES2_HS256_A128KW
PBES2HS384A192KW = jose.PBES2_HS384_A192KW
PBES2HS512A256KW = jose.PBES2_HS512_A256KW
// TokenValidator is the token validator interface alias.
TokenValidator = jwt.TokenValidator
// VerifiedToken is the type alias for the verfieid token type,
// the result of the VerifyToken function.
VerifiedToken = jwt.VerifiedToken
// SignOption used to set signing options at Sign function.
SignOption = jwt.SignOption
// TokenPair is just a helper structure which holds both access and refresh tokens.
TokenPair = jwt.TokenPair
)
// Signature algorithms.
const (
EdDSA = jose.EdDSA
HS256 = jose.HS256
HS384 = jose.HS384
HS512 = jose.HS512
RS256 = jose.RS256
RS384 = jose.RS384
RS512 = jose.RS512
ES256 = jose.ES256
ES384 = jose.ES384
ES512 = jose.ES512
PS256 = jose.PS256
PS384 = jose.PS384
PS512 = jose.PS512
var (
EdDSA = jwt.EdDSA
HS256 = jwt.HS256
HS384 = jwt.HS384
HS512 = jwt.HS512
RS256 = jwt.RS256
RS384 = jwt.RS384
RS512 = jwt.RS512
ES256 = jwt.ES256
ES384 = jwt.ES384
ES512 = jwt.ES512
PS256 = jwt.PS256
PS384 = jwt.PS384
PS512 = jwt.PS512
)
// Content encryption algorithms.
const (
A128CBCHS256 = jose.A128CBC_HS256
A192CBCHS384 = jose.A192CBC_HS384
A256CBCHS512 = jose.A256CBC_HS512
A128GCM = jose.A128GCM
A192GCM = jose.A192GCM
A256GCM = jose.A256GCM
// Encryption algorithms.
var (
GCM = jwt.GCM
// Helper to generate random key,
// can be used to generate hmac signature key and GCM+AES for testing.
MustGenerateRandom = jwt.MustGenerateRandom
)
var (
// Leeway adds validation for a leeway expiration time.
// If the token was not expired then a comparison between
// this "leeway" and the token's "exp" one is expected to pass instead (now+leeway > exp).
// Example of use case: disallow tokens that are going to be expired in 3 seconds from now,
// this is useful to make sure that the token is valid when the when the user fires a database call for example.
// Usage:
// verifiedToken, err := verifier.Verify(..., jwt.Leeway(5*time.Second))
Leeway = jwt.Leeway
// MaxAge is a SignOption to set the expiration "exp", "iat" JWT standard claims.
// Can be passed as last input argument of the `Sign` function.
//
// If maxAge > second then sets expiration to the token.
// It's a helper field to set the "exp" and "iat" claim values.
// Usage:
// signer.Sign(..., jwt.MaxAge(15*time.Minute))
MaxAge = jwt.MaxAge
)
// Shortcuts for Signing and Verifying.
var (
VerifyToken = jwt.Verify
VerifyEncryptedToken = jwt.VerifyEncrypted
Sign = jwt.Sign
SignEncrypted = jwt.SignEncrypted
)
// Signature algorithm helpers.
var (
MustLoadHMAC = jwt.MustLoadHMAC
LoadHMAC = jwt.LoadHMAC
MustLoadRSA = jwt.MustLoadRSA
LoadPrivateKeyRSA = jwt.LoadPrivateKeyRSA
LoadPublicKeyRSA = jwt.LoadPublicKeyRSA
ParsePrivateKeyRSA = jwt.ParsePrivateKeyRSA
ParsePublicKeyRSA = jwt.ParsePublicKeyRSA
MustLoadECDSA = jwt.MustLoadECDSA
LoadPrivateKeyECDSA = jwt.LoadPrivateKeyECDSA
LoadPublicKeyECDSA = jwt.LoadPublicKeyECDSA
ParsePrivateKeyECDSA = jwt.ParsePrivateKeyECDSA
ParsePublicKeyECDSA = jwt.ParsePublicKeyECDSA
MustLoadEdDSA = jwt.MustLoadEdDSA
LoadPrivateKeyEdDSA = jwt.LoadPrivateKeyEdDSA
LoadPublicKeyEdDSA = jwt.LoadPublicKeyEdDSA
ParsePrivateKeyEdDSA = jwt.ParsePrivateKeyEdDSA
ParsePublicKeyEdDSA = jwt.ParsePublicKeyEdDSA
)

View File

@@ -1,9 +1,7 @@
package jwt
import (
stdContext "context"
"sync"
"time"
"github.com/kataras/jwt"
)
// Blocklist should hold and manage invalidated-by-server tokens.
@@ -15,136 +13,14 @@ import (
// e.g. a redis one to keep persistence of invalidated tokens on server restarts.
// and bind to the JWT middleware's Blocklist field.
type Blocklist interface {
// Set should upsert a token to the storage.
Set(token string, expiresAt time.Time)
jwt.TokenValidator
// InvalidateToken should invalidate a verified JWT token.
InvalidateToken(token []byte, expiry int64)
// Del should remove a token from the storage.
Del(token string)
Del(token []byte)
// Count should return the total amount of tokens stored.
Count() int
// Has should report whether a specific token exists in the storage.
Has(token string) bool
}
// blocklist is an in-memory storage of tokens that should be
// immediately invalidated by the server-side.
// The most common way to invalidate a token, e.g. on user logout,
// is to make the client-side remove the token itself.
// However, if someone else has access to that token,
// it could be still valid for new requests until its expiration.
type blocklist struct {
entries map[string]time.Time // key = token | value = expiration time (to remove expired).
mu sync.RWMutex
}
// NewBlocklist returns a new up and running in-memory Token Blocklist.
// The returned value can be set to the JWT instance's Blocklist field.
func NewBlocklist(gcEvery time.Duration) Blocklist {
return NewBlocklistContext(stdContext.Background(), gcEvery)
}
// NewBlocklistContext same as `NewBlocklist`
// but it also accepts a standard Go Context for GC cancelation.
func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) Blocklist {
b := &blocklist{
entries: make(map[string]time.Time),
}
if gcEvery > 0 {
go b.runGC(ctx, gcEvery)
}
return b
}
// Set upserts a given token, with its expiration time,
// to the block list, so it's immediately invalidated by the server-side.
func (b *blocklist) Set(token string, expiresAt time.Time) {
b.mu.Lock()
b.entries[token] = expiresAt
b.mu.Unlock()
}
// Del removes a "token" from the block list.
func (b *blocklist) Del(token string) {
b.mu.Lock()
delete(b.entries, token)
b.mu.Unlock()
}
// Count returns the total amount of blocked tokens.
func (b *blocklist) Count() int {
b.mu.RLock()
n := len(b.entries)
b.mu.RUnlock()
return n
}
// Has reports whether the given "token" is blocked by the server.
// This method is called before the token verification,
// so even if was expired it is removed from the block list.
func (b *blocklist) Has(token string) bool {
if token == "" {
return false
}
b.mu.RLock()
_, ok := b.entries[token]
b.mu.RUnlock()
/* No, the Blocklist will be used after the token is parsed,
there we can call the Del method if err was ErrExpired.
if ok {
// As an extra step, to keep the list size as small as possible,
// we delete it from list if it's going to be expired
// ~in the next `blockedExpireLeeway` seconds.~
// - Let's keep it easier for testing by not setting a leeway.
// if time.Now().Add(blockedExpireLeeway).After(expiresAt) {
if time.Now().After(expiresAt) {
b.Del(token)
}
}*/
return ok
}
// GC iterates over all entries and removes expired tokens.
// This method is helpful to keep the list size small.
// Depending on the application, the GC method can be scheduled
// to called every half or a whole hour.
// A good value for a GC cron task is the JWT's max age (default).
func (b *blocklist) GC() int {
now := time.Now()
var markedForDeletion []string
b.mu.RLock()
for token, expiresAt := range b.entries {
if now.After(expiresAt) {
markedForDeletion = append(markedForDeletion, token)
}
}
b.mu.RUnlock()
n := len(markedForDeletion)
if n > 0 {
for _, token := range markedForDeletion {
b.Del(token)
}
}
return n
}
func (b *blocklist) runGC(ctx stdContext.Context, every time.Duration) {
t := time.NewTicker(every)
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
b.GC()
}
}
Has(token []byte) bool
}

View File

@@ -0,0 +1,71 @@
package jwt
import (
"strings"
"github.com/kataras/iris/v12/context"
)
// TokenExtractor is a function that takes a context as input and returns
// a token. An empty string should be returned if no token found
// without additional information.
type TokenExtractor func(*context.Context) string
// FromHeader is a token extractor.
// It reads the token from the Authorization request header of form:
// Authorization: "Bearer {token}".
func FromHeader(ctx *context.Context) string {
authHeader := ctx.GetHeader("Authorization")
if authHeader == "" {
return ""
}
// pure check: authorization header format must be Bearer {token}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return ""
}
return authHeaderParts[1]
}
// FromQuery is a token extractor.
// It reads the token from the "token" url query parameter.
func FromQuery(ctx *context.Context) string {
return ctx.URLParam("token")
}
// FromJSON is a token extractor.
// Reads a json request body and extracts the json based on the given field.
// The request content-type should contain the: application/json header value, otherwise
// this method will not try to read and consume the body.
func FromJSON(jsonKey string) TokenExtractor {
return func(ctx *context.Context) string {
if ctx.GetContentTypeRequested() != context.ContentJSONHeaderValue {
return ""
}
var m context.Map
ctx.RecordRequestBody(true)
defer ctx.RecordRequestBody(false)
if err := ctx.ReadJSON(&m); err != nil {
return ""
}
if m == nil {
return ""
}
v, ok := m[jsonKey]
if !ok {
return ""
}
tok, ok := v.(string)
if !ok {
return ""
}
return tok
}
}

View File

@@ -1,861 +1,7 @@
package jwt
import (
"crypto"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"github.com/kataras/iris/v12/context"
"github.com/square/go-jose/v3"
"github.com/square/go-jose/v3/jwt"
)
import "github.com/kataras/iris/v12/context"
func init() {
context.SetHandlerName("iris/middleware/jwt.*", "iris.jwt")
}
// TokenExtractor is a function that takes a context as input and returns
// a token. An empty string should be returned if no token found
// without additional information.
type TokenExtractor func(*context.Context) string
// FromHeader is a token extractor.
// It reads the token from the Authorization request header of form:
// Authorization: "Bearer {token}".
func FromHeader(ctx *context.Context) string {
authHeader := ctx.GetHeader("Authorization")
if authHeader == "" {
return ""
}
// pure check: authorization header format must be Bearer {token}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return ""
}
return authHeaderParts[1]
}
// FromQuery is a token extractor.
// It reads the token from the "token" url query parameter.
func FromQuery(ctx *context.Context) string {
return ctx.URLParam("token")
}
// FromJSON is a token extractor.
// Reads a json request body and extracts the json based on the given field.
// The request content-type should contain the: application/json header value, otherwise
// this method will not try to read and consume the body.
func FromJSON(jsonKey string) TokenExtractor {
return func(ctx *context.Context) string {
if ctx.GetContentTypeRequested() != context.ContentJSONHeaderValue {
return ""
}
var m context.Map
ctx.RecordRequestBody(true)
defer ctx.RecordRequestBody(false)
if err := ctx.ReadJSON(&m); err != nil {
return ""
}
if m == nil {
return ""
}
v, ok := m[jsonKey]
if !ok {
return ""
}
tok, ok := v.(string)
if !ok {
return ""
}
return tok
}
}
// JWT holds the necessary information the middleware need
// to sign and verify tokens.
//
// The `RSA(privateFile, publicFile, password)` package-level helper function
// can be used to decode the SignKey and VerifyKey.
//
// For an easy use look the `HMAC` package-level function
// and the its `NewUser` and `VerifyUser` methods.
type JWT struct {
// MaxAge is the expiration duration of the generated tokens.
MaxAge time.Duration
// Extractors are used to extract a raw token string value
// from the request.
// Builtin extractors:
// * FromHeader
// * FromQuery
// * FromJSON
// Defaults to a slice of `FromHeader` and `FromQuery`.
Extractors []TokenExtractor
// Signer is used to sign the token.
// It is set on `New` and `Default` package-level functions.
Signer jose.Signer
// VerificationKey is used to verify the token (public key).
VerificationKey interface{}
// Encrypter is used to, optionally, encrypt the token.
// It is set on `WithEncryption` method.
Encrypter jose.Encrypter
// DecriptionKey is used to decrypt the token (private key)
DecriptionKey interface{}
// Blocklist holds the invalidated-by-server tokens (that are not yet expired).
// It is not initialized by default.
// Initialization Usage:
// j.InitDefaultBlocklist()
// OR
// j.Blocklist = jwt.NewBlocklist(gcEveryDuration)
// Usage:
// - ctx.Logout()
// - j.Invalidate(ctx)
Blocklist Blocklist
}
type privateKey interface{ Public() crypto.PublicKey }
// New returns a new JWT instance.
// It accepts a maximum time duration for token expiration
// and the algorithm among with its key for signing and verification.
//
// See `WithEncryption` method to add token encryption too.
// Use `Token` method to generate a new token string
// and `VerifyToken` method to decrypt, verify and bind claims of an incoming request token.
// Token, by default, is extracted by "Authorization: Bearer {token}" request header and
// url query parameter of "token". Token extractors can be modified through the `Extractors` field.
//
// For example, if you want to sign and verify using RSA-256 key:
// 1. Generate key file, e.g:
// $ openssl genrsa -des3 -out private.pem 2048
// 2. Read file contents with io.ReadFile("./private.pem")
// 3. Pass the []byte result to the `ParseRSAPrivateKey(contents, password)` package-level helper
// 4. Use the result *rsa.PrivateKey as "key" input parameter of this `New` function.
//
// See aliases.go file for available algorithms.
func New(maxAge time.Duration, alg SignatureAlgorithm, key interface{}) (*JWT, error) {
sig, err := jose.NewSigner(jose.SigningKey{
Algorithm: alg,
Key: key,
}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
return nil, err
}
j := &JWT{
Signer: sig,
VerificationKey: key,
MaxAge: maxAge,
Extractors: []TokenExtractor{FromHeader, FromQuery},
}
if s, ok := key.(privateKey); ok {
j.VerificationKey = s.Public()
}
return j, nil
}
// Default key filenames for `RSA`.
const (
DefaultSignFilename = "jwt_sign.key"
DefaultEncFilename = "jwt_enc.key"
)
// RSA returns a new `JWT` instance.
// It tries to parse RSA256 keys from "filenames[0]" (defaults to "jwt_sign.key") and
// "filenames[1]" (defaults to "jwt_enc.key") files or generates and exports new random keys.
//
// It panics on errors.
// Use the `New` package-level function instead for more options.
func RSA(maxAge time.Duration, filenames ...string) *JWT {
var (
signFilename = DefaultSignFilename
encFilename = DefaultEncFilename
)
switch len(filenames) {
case 1:
signFilename = filenames[0]
case 2:
encFilename = filenames[1]
}
// Do not try to create or load enc key if only sign key already exists.
withEncryption := true
if fileExists(signFilename) {
withEncryption = fileExists(encFilename)
}
sigKey, err := LoadRSA(signFilename, 2048)
if err != nil {
panic(err)
}
j, err := New(maxAge, RS256, sigKey)
if err != nil {
panic(err)
}
if withEncryption {
encKey, err := LoadRSA(encFilename, 2048)
if err != nil {
panic(err)
}
err = j.WithEncryption(A128CBCHS256, RSA15, encKey)
if err != nil {
panic(err)
}
}
return j
}
const (
signEnv = "JWT_SECRET"
encEnv = "JWT_SECRET_ENC"
)
func getenv(key string, def string) string {
v := os.Getenv(key)
if v == "" {
return def
}
return v
}
// HMAC returns a new `JWT` instance.
// It tries to read hmac256 secret keys from system environment variables:
// * JWT_SECRET for signing and verification key and
// * JWT_SECRET_ENC for encryption and decryption key
// and defaults them to the given "keys" respectfully.
//
// It panics on errors.
// Use the `New` package-level function instead for more options.
//
// Example at:
// https://github.com/kataras/iris/tree/master/_examples/auth/jwt/overview/main.go
func HMAC(maxAge time.Duration, keys ...string) *JWT {
var defaultSignSecret, defaultEncSecret string
switch len(keys) {
case 1:
defaultSignSecret = keys[0]
case 2:
defaultEncSecret = keys[1]
}
signSecret := getenv(signEnv, defaultSignSecret)
encSecret := getenv(encEnv, defaultEncSecret)
j, err := New(maxAge, HS256, []byte(signSecret))
if err != nil {
panic(err)
}
if encSecret != "" {
err = j.WithEncryption(A128GCM, DIRECT, []byte(encSecret))
if err != nil {
panic(err)
}
}
return j
}
// WithEncryption method enables encryption and decryption of the token.
// It sets an appropriate encrypter(`Encrypter` and the `DecriptionKey` fields) based on the key type.
func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorithm, key interface{}) error {
var publicKey interface{} = key
if s, ok := key.(privateKey); ok {
publicKey = s.Public()
}
enc, err := jose.NewEncrypter(contentEncryption, jose.Recipient{
Algorithm: alg,
Key: publicKey,
},
(&jose.EncrypterOptions{}).WithType("JWT").WithContentType("JWT"),
)
if err != nil {
return err
}
j.Encrypter = enc
j.DecriptionKey = key
return nil
}
// InitDefaultBlocklist initializes the Blocklist field with the default in-memory implementation.
// Should be called on jwt middleware creation-time,
// after this, the developer can use the Context.Logout method
// to invalidate a verified token by the server-side.
func (j *JWT) InitDefaultBlocklist() {
gcEvery := 30 * time.Minute
if j.MaxAge > 0 {
gcEvery = j.MaxAge
}
j.Blocklist = NewBlocklist(gcEvery)
}
// ExpiryMap adds the expiration based on the "maxAge" to the "claims" map.
// It's called automatically on `Token` method.
func ExpiryMap(maxAge time.Duration, claims context.Map) {
now := time.Now()
if claims["exp"] == nil {
claims["exp"] = NewNumericDate(now.Add(maxAge))
}
if claims["iat"] == nil {
claims["iat"] = NewNumericDate(now)
}
}
// Token generates and returns a new token string.
// See `VerifyToken` too.
func (j *JWT) Token(claims interface{}) (string, error) {
return j.token(j.MaxAge, claims)
}
func (j *JWT) token(maxAge time.Duration, claims interface{}) (string, error) {
if claims == nil {
return "", ErrInvalidKey
}
c, nErr := normalize(claims)
if nErr != nil {
return "", nErr
}
ExpiryMap(maxAge, c)
var (
token string
err error
)
// jwt.Builder and jwt.NestedBuilder contain same methods but they are not the same.
//
// Note that the .Claims method there, converts a Struct to a map under the hoods.
// That means that we will not have any performance cost
// if we do it by ourselves and pass always a Map there.
// That gives us the option to allow user to pass ANY go struct
// and we can add the "exp", "nbf", "iat" map values by ourselves
// based on the j.MaxAge.
// (^ done, see normalize, all methods are
// changed to accept totally custom types, no need to embed the standard Claims anymore).
if j.DecriptionKey != nil {
token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(c).CompactSerialize()
} else {
token, err = jwt.Signed(j.Signer).Claims(c).CompactSerialize()
// payload, pErr := Marshal(c)
// if pErr != nil {
// return "", pErr
// }
// sign, sErr := j.Signer.Sign(payload)
// if sErr != nil {
// return "", sErr
// }
// token, err = sign.CompactSerialize()
}
if err != nil {
return "", err
}
return token, nil
}
// WriteToken is a helper which just generates(calls the `Token` method) and writes
// a new token to the client in plain text format.
//
// Use the `Token` method to get a new generated token raw string value.
func (j *JWT) WriteToken(ctx *context.Context, claims interface{}) error {
token, err := j.Token(claims)
if err != nil {
ctx.StatusCode(500)
return err
}
_, err = ctx.WriteString(token)
return err
}
// VerifyToken verifies (and decrypts) the request token,
// it also validates and binds the parsed token's claims to the "claimsPtr" (destination).
//
// The last, variadic, input argument is optionally, if provided then the
// parsed claims must match the expectations;
// e.g. Audience, Issuer, ID, Subject.
// See `ExpectXXX` package-functions for details.
func (j *JWT) VerifyToken(ctx *context.Context, claimsPtr interface{}, expectations ...Expectation) (*TokenInfo, error) {
token := j.RequestToken(ctx)
return j.VerifyTokenString(ctx, token, claimsPtr, expectations...)
}
// VerifyRefreshToken like the `VerifyToken` but it verifies a refresh token one instead.
// If the implementation does not fill the application's requirements,
// you can ignore this method and still use the `VerifyToken` for refresh tokens too.
//
// This method adds the ExpectRefreshToken expectation and it
// tries to read the refresh token from raw body or,
// if content type was application/json, then it extracts the token
// from the JSON request body's {"refresh_token": "$token"} key.
func (j *JWT) VerifyRefreshToken(ctx *context.Context, claimsPtr interface{}, expectations ...Expectation) (*TokenInfo, error) {
token := j.RequestToken(ctx)
if token == "" {
ctx.RecordRequestBody(true)
defer ctx.RecordRequestBody(false)
var tokenPair TokenPair // read "refresh_token" from JSON.
if ctx.GetContentTypeRequested() == context.ContentJSONHeaderValue {
ctx.ReadJSON(&tokenPair) // ignore error.
token = tokenPair.RefreshToken
if token == "" {
return nil, ErrMissing
}
} else {
ctx.ReadBody(&token)
}
}
return j.VerifyTokenString(ctx, token, claimsPtr, append(expectations, ExpectRefreshToken)...)
}
// RequestToken extracts the token from the request.
func (j *JWT) RequestToken(ctx *context.Context) (token string) {
for _, extract := range j.Extractors {
if token = extract(ctx); token != "" {
break // ok we found it.
}
}
return
}
// TokenSetter is an interface which if implemented
// the extracted, verified, token is stored to the object.
type TokenSetter interface {
SetToken(token string)
}
// TokenInfo holds the standard token information may required
// for further actions.
// This structure is mostly useful when the developer's go structure
// does not hold the standard jwt fields (e.g. "exp")
// but want access to the parsed token which contains those fields.
// Inside the middleware, it is used to invalidate tokens through server-side, see `Invalidate`.
type TokenInfo struct {
RequestToken string // The request token.
Claims Claims // The standard JWT parsed fields from the request Token.
Value interface{} // The pointer to the end-developer's custom claims structure (see `Get`).
}
const tokenInfoContextKey = "iris.jwt.token"
// Get returns the verified developer token claims.
//
//
// Usage:
// j := jwt.New(...)
// app.Use(j.Verify(func() interface{} { return new(CustomClaims) }))
// app.Post("/restricted", func(ctx iris.Context){
// claims := jwt.Get(ctx).(*CustomClaims)
// [use claims...]
// })
//
// Note that there is one exception, if the value was a pointer
// to a map[string]interface{}, it returns the map itself so it can be
// accessible directly without the requirement of unwrapping it, e.g.
// j.Verify(func() interface{} {
// return &iris.Map{}
// }
// [...]
// claims := jwt.Get(ctx).(iris.Map)
func Get(ctx *context.Context) interface{} {
if tok := GetTokenInfo(ctx); tok != nil {
switch v := tok.Value.(type) {
case *context.Map:
return *v
case *json.RawMessage:
// This is useful when we can accept more than one
// type of JWT token in the same request path,
// but we also want to keep type safety.
// Usage:
// type myClaims struct { Roles []string `json:"roles"`}
// v := jwt.Get(ctx)
// var claims myClaims
// jwt.Unmarshal(v, &claims)
// [...claims.Roles]
return *v
default:
return v
}
}
return nil
}
// GetTokenInfo returns the verified token's information.
func GetTokenInfo(ctx *context.Context) *TokenInfo {
if v := ctx.Values().Get(tokenInfoContextKey); v != nil {
if t, ok := v.(*TokenInfo); ok {
return t
}
}
return nil
}
// Invalidate invalidates a verified JWT token.
// It adds the request token, retrieved by Verify methods, to the block list.
// Next request will be blocked, even if the token was not yet expired.
// This method can be used when the client-side does not clear the token
// on a user logout operation.
//
// Note: the Blocklist should be initialized before serve-time: j.InitDefaultBlocklist().
func (j *JWT) Invalidate(ctx *context.Context) {
if j.Blocklist == nil {
ctx.Application().Logger().Debug("jwt.Invalidate: Blocklist is nil")
return
}
tokenInfo := GetTokenInfo(ctx)
if tokenInfo == nil {
return
}
j.Blocklist.Set(tokenInfo.RequestToken, tokenInfo.Claims.Expiry.Time())
}
// VerifyTokenString verifies and unmarshals an extracted request token to "dest" destination.
// The last variadic input indicates any further validations against the verified token claims.
// If the given "dest" is a valid context.User then ctx.User() will return it.
// If the token is missing an `ErrMissing` is returned.
// If the incoming token was expired an `ErrExpired` is returned.
// If the incoming token was blocked by the server an `ErrBlocked` is returned.
func (j *JWT) VerifyTokenString(ctx *context.Context, token string, dest interface{}, expectations ...Expectation) (*TokenInfo, error) {
if token == "" {
return nil, ErrMissing
}
var (
parsedToken *jwt.JSONWebToken
err error
)
if j.DecriptionKey != nil {
t, cerr := jwt.ParseSignedAndEncrypted(token)
if cerr != nil {
return nil, cerr
}
parsedToken, err = t.Decrypt(j.DecriptionKey)
} else {
parsedToken, err = jwt.ParseSigned(token)
}
if err != nil {
return nil, err
}
var (
claims Claims
tokenMaxAger tokenWithMaxAge
)
var (
ignoreDest = dest == nil
ignoreVarClaims bool
)
if !ignoreDest { // if dest was not nil, check if the dest is already a standard claims pointer.
_, ignoreVarClaims = dest.(*Claims)
}
// Ensure read the standard claims one if dest was Claims or was nil.
// (it wont break anything if we unmarshal them twice though, we just do it for performance reasons).
var pointers = []interface{}{&tokenMaxAger}
if !ignoreDest {
pointers = append(pointers, dest)
}
if !ignoreVarClaims {
pointers = append(pointers, &claims)
}
if err = parsedToken.Claims(j.VerificationKey, pointers...); err != nil {
return nil, err
}
// Set the std claims, if missing from receiver so the expectations and validation still work.
if ignoreVarClaims {
claims = *dest.(*Claims)
} else if ignoreDest {
dest = &claims
}
expectMaxAge := j.MaxAge
// Build the Expected value.
expected := Expected{}
for _, e := range expectations {
if e != nil {
// expection can be used as a field validation too (see MeetRequirements).
if err = e(&expected, dest); err != nil {
if err == ErrExpectRefreshToken {
if tokenMaxAger.MaxAge > 0 {
// If max age exists, grab it and compare it later.
// Otherwise fire the ErrExpectRefreshToken.
expectMaxAge = tokenMaxAger.MaxAge
continue
}
}
return nil, err
}
}
}
gotMaxAge := getMaxAge(claims)
if !compareMaxAge(expectMaxAge, gotMaxAge) {
// Additional check to automatically invalidate
// any previous jwt maxAge setting change.
// In-short, if the time.Now().Add j.MaxAge
// does not match the "iat" (issued at) then we invalidate the token.
return nil, ErrInvalidMaxAge
}
// For other standard JWT claims fields such as "exp"
// The developer can just add a field of Expiry *NumericDate `json:"exp"`
// and will be filled by the parsed token automatically.
// No need for more interfaces.
err = validateClaims(ctx, dest, claims, expected)
if err != nil {
if err == ErrExpired {
// If token was expired remove it from the block list.
if j.Blocklist != nil {
j.Blocklist.Del(token)
}
}
return nil, err
}
if j.Blocklist != nil {
// If token exists in the block list, then stop here.
if j.Blocklist.Has(token) {
return nil, ErrBlocked
}
}
if !ignoreDest {
if ut, ok := dest.(TokenSetter); ok {
// The u.Token is empty even if we set it and export it on JSON structure.
// Set it manually.
ut.SetToken(token)
}
}
// Set the information.
tokenInfo := &TokenInfo{
RequestToken: token,
Claims: claims,
Value: dest,
}
return tokenInfo, nil
}
// TokenPair holds the access token and refresh token response.
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
type tokenWithMaxAge struct {
// Useful to separate access from refresh tokens.
// Can be used to by-pass the internal check of expected
// MaxAge setting to match the token's received max age too.
MaxAge time.Duration `json:"tokenMaxAge"`
}
// TokenPair generates a token pair of access and refresh tokens.
// The first two arguments required for the refresh token
// and the last one is the claims for the access token one.
func (j *JWT) TokenPair(refreshMaxAge time.Duration, refreshClaims interface{}, accessClaims interface{}) (TokenPair, error) {
if refreshMaxAge <= j.MaxAge {
return TokenPair{}, fmt.Errorf("refresh max age should be bigger than access token's one[%d - %d]", refreshMaxAge, j.MaxAge)
}
accessToken, err := j.Token(accessClaims)
if err != nil {
return TokenPair{}, err
}
c, err := normalize(refreshClaims)
if err != nil {
return TokenPair{}, err
}
if c == nil {
c = make(context.Map)
}
// need to validate against its value instead of the setting's one (see `VerifyTokenString`).
c["tokenMaxAge"] = refreshMaxAge
refreshToken, err := j.token(refreshMaxAge, c)
if err != nil {
return TokenPair{}, nil
}
pair := TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
}
return pair, nil
}
// Verify returns a middleware which
// decrypts an incoming request token to the result of the given "newPtr".
// It does write a 401 unauthorized status code if verification or decryption failed.
// It calls the `ctx.Next` on verified requests.
//
// Iit unmarshals the token to the specific type returned from the given "newPtr" function.
// It sets the Context User and User's Token too. So the next handler(s)
// of the same chain can access the User through a `Context.User()` call.
//
// Note unlike `VerifyToken`, this method automatically protects
// the claims with JSON required tags (see `MeetRequirements` Expection).
//
// On verified tokens:
// - The information can be retrieved through `Get` and `GetTokenInfo` functions.
// - User is set if the newPtr returns a valid Context User
// - The Context Logout method is set if Blocklist was initialized
// Any error is captured to the Context,
// which can be retrieved by a `ctx.GetErr()` call.
//
// See `VerifyJSON` too.
func (j *JWT) Verify(newPtr func() interface{}, expections ...Expectation) context.Handler {
if newPtr == nil {
newPtr = func() interface{} {
// Return a map here as the default type one,
// as it does allow .Get callers to access its fields with ease
// (although, I always recommend using structs for type-safety and
// also they can accept a required tag option too).
return &context.Map{}
}
}
expections = append(expections, MeetRequirements(newPtr()))
return func(ctx *context.Context) {
ptr := newPtr()
tokenInfo, err := j.VerifyToken(ctx, ptr, expections...)
if err != nil {
ctx.Application().Logger().Debugf("iris.jwt.Verify: %v", err)
ctx.StopWithError(401, context.PrivateError(err))
return
}
if u, ok := ptr.(context.User); ok {
ctx.SetUser(u)
}
if j.Blocklist != nil {
ctx.SetLogoutFunc(j.Invalidate)
}
ctx.Values().Set(tokenInfoContextKey, tokenInfo)
ctx.Next()
}
}
// VerifyMap is a shortcut of Verify with a function which will bind
// the claims to a standard Go map[string]interface{}.
func (j *JWT) VerifyMap(expections ...Expectation) context.Handler {
return j.Verify(func() interface{} {
return &context.Map{}
}, expections...)
}
// VerifyJSON works like `Verify` but instead it
// binds its "newPtr" function to return a raw JSON message.
// It does NOT read the token from JSON by itself,
// to do that add the `FromJSON` to the Token Extractors.
// It's used to bind the claims in any value type on the next handler.
//
// This allows the caller to bind this JSON message to any Go structure (or map).
// This is useful when we can accept more than one
// type of JWT token in the same request path,
// but we also want to keep type safety.
// Usage:
// app.Use(jwt.VerifyJSON())
// Inside a route Handler:
// claims := struct { Roles []string `json:"roles"`}{}
// jwt.ReadJSON(ctx, &claims)
// ...access to claims.Roles as []string
func (j *JWT) VerifyJSON(expections ...Expectation) context.Handler {
return j.Verify(func() interface{} {
return new(json.RawMessage)
}, expections...)
}
// ReadJSON is a helper which binds "claimsPtr" to the
// raw JSON token claims.
// Use inside the handlers when `VerifyJSON()` middleware was registered.
func ReadJSON(ctx *context.Context, claimsPtr interface{}) error {
v := Get(ctx)
if v == nil {
return ErrMissing
}
data, ok := v.(json.RawMessage)
if !ok {
return ErrMissing
}
return Unmarshal(data, claimsPtr)
}
// NewUser returns a new User based on the given "opts".
// The caller can modify the User until its `GetToken` is called.
func (j *JWT) NewUser(opts ...UserOption) *User {
u := &User{
j: j,
SimpleUser: &context.SimpleUser{
Authorization: "IRIS_JWT_USER", // Used to separate a refresh token with a user/access one too.
Features: []context.UserFeature{
context.TokenFeature,
},
},
}
for _, opt := range opts {
opt(u)
}
return u
}
// VerifyUser works like the `Verify` method but instead
// it unmarshals the token to the specific User type.
// It sets the Context User too. So the next handler(s)
// of the same chain can access the User through a `Context.User()` call.
func (j *JWT) VerifyUser(expectations ...Expectation) context.Handler {
return j.Verify(func() interface{} {
return new(User)
}, expectations...)
}

View File

@@ -1,8 +1,7 @@
// Package jwt_test contains simple Iris jwt tests. Most of the jwt functionality is already tested inside the jose package itself.
package jwt_test
import (
"os"
"fmt"
"testing"
"time"
@@ -11,324 +10,56 @@ import (
"github.com/kataras/iris/v12/middleware/jwt"
)
type userClaims struct {
// Optionally:
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience jwt.Audience `json:"aud"`
//
Username string `json:"username"`
var testAlg, testSecret = jwt.HS256, []byte("sercrethatmaycontainch@r$")
type fooClaims struct {
Foo string `json:"foo"`
}
const testMaxAge = 7 * time.Second
// Random RSA verification and encryption.
func TestRSA(t *testing.T) {
j := jwt.RSA(testMaxAge)
t.Cleanup(func() {
os.Remove(jwt.DefaultSignFilename)
os.Remove(jwt.DefaultEncFilename)
})
testWriteVerifyBlockToken(t, j)
}
// HMAC verification and encryption.
func TestHMAC(t *testing.T) {
j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret")
testWriteVerifyBlockToken(t, j)
}
func TestNew_HMAC(t *testing.T) {
j, err := jwt.New(testMaxAge, jwt.HS256, []byte("secret"))
if err != nil {
t.Fatal(err)
}
err = j.WithEncryption(jwt.A128GCM, jwt.DIRECT, []byte("itsa16bytesecret"))
if err != nil {
t.Fatal(err)
}
testWriteVerifyBlockToken(t, j)
}
// HMAC verification only (unecrypted).
func TestVerify(t *testing.T) {
j, err := jwt.New(testMaxAge, jwt.HS256, []byte("another secret"))
if err != nil {
t.Fatal(err)
}
testWriteVerifyBlockToken(t, j)
}
func testWriteVerifyBlockToken(t *testing.T, j *jwt.JWT) {
t.Helper()
j.InitDefaultBlocklist()
j.Extractors = append(j.Extractors, jwt.FromJSON("access_token"))
customClaims := &userClaims{
Issuer: "an-issuer",
Audience: jwt.Audience{"an-audience"},
Subject: "user",
Username: "kataras",
}
// The actual tests are inside the kataras/jwt repository.
// This runs simple checks of just the middleware part.
func TestJWT(t *testing.T) {
app := iris.New()
app.OnErrorCode(iris.StatusUnauthorized, func(ctx iris.Context) {
if err := ctx.GetErr(); err != nil {
// Test accessing the private error and set this as the response body.
ctx.WriteString(err.Error())
} else { // Else the default behavior
ctx.WriteString(iris.StatusText(iris.StatusUnauthorized))
}
})
app.Get("/auth", func(ctx iris.Context) {
j.WriteToken(ctx, customClaims)
})
app.Post("/protected", func(ctx iris.Context) {
var claims userClaims
_, err := j.VerifyToken(ctx, &claims)
if err != nil {
// t.Logf("%s: %v", ctx.Path(), err)
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
return
}
ctx.JSON(claims)
})
m := app.Party("/middleware")
m.Use(j.Verify(func() interface{} {
return new(userClaims)
}))
m.Post("/protected", func(ctx iris.Context) {
claims := jwt.Get(ctx)
ctx.JSON(claims)
})
m.Post("/invalidate", func(ctx iris.Context) {
ctx.Logout() // OR j.Invalidate(ctx)
})
e := httptest.New(t, app)
// Get token.
rawToken := e.GET("/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if rawToken == "" {
t.Fatalf("empty token")
}
restrictedPaths := [...]string{"/protected", "/middleware/protected"}
now := time.Now()
for _, path := range restrictedPaths {
// Authorization Header.
e.POST(path).WithHeader("Authorization", "Bearer "+rawToken).Expect().
Status(httptest.StatusOK).JSON().Equal(customClaims)
// URL Query.
e.POST(path).WithQuery("token", rawToken).Expect().
Status(httptest.StatusOK).JSON().Equal(customClaims)
// JSON Body.
e.POST(path).WithJSON(iris.Map{"access_token": rawToken}).Expect().
Status(httptest.StatusOK).JSON().Equal(customClaims)
// Missing "Bearer".
e.POST(path).WithHeader("Authorization", rawToken).Expect().
Status(httptest.StatusUnauthorized).Body().Equal("token is missing")
}
// Invalidate the token.
e.POST("/middleware/invalidate").WithQuery("token", rawToken).Expect().
Status(httptest.StatusOK)
// Token is blocked by server.
e.POST("/middleware/protected").WithQuery("token", rawToken).Expect().
Status(httptest.StatusUnauthorized).Body().Equal("token is blocked")
expireRemDur := testMaxAge - time.Since(now)
// Expiration.
time.Sleep(expireRemDur /* -end */)
for _, path := range restrictedPaths {
e.POST(path).WithQuery("token", rawToken).Expect().
Status(httptest.StatusUnauthorized).Body().Equal("token is expired (exp)")
}
}
func TestVerifyMap(t *testing.T) {
j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret")
expectedClaims := iris.Map{
"iss": "tester",
"username": "makis",
"roles": []string{"admin"},
}
app := iris.New()
app.Get("/user/auth", func(ctx iris.Context) {
err := j.WriteToken(ctx, expectedClaims)
if err != nil {
ctx.StopWithError(iris.StatusUnauthorized, err)
return
}
if expectedClaims["exp"] == nil || expectedClaims["iat"] == nil {
ctx.StopWithText(iris.StatusBadRequest,
"exp or/and iat is nil - this means that the expiry was not set")
return
}
})
userAPI := app.Party("/user")
userAPI.Post("/", func(ctx iris.Context) {
var claims iris.Map
if _, err := j.VerifyToken(ctx, &claims); err != nil {
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
return
}
ctx.JSON(claims)
})
// Test map + Verify middleware.
userAPI.Post("/middleware", j.Verify(nil), func(ctx iris.Context) {
claims := jwt.Get(ctx)
ctx.JSON(claims)
})
e := httptest.New(t, app, httptest.LogLevel("error"))
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if token == "" {
t.Fatalf("empty token")
}
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
e.POST("/user/middleware").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
e.POST("/user").Expect().Status(httptest.StatusUnauthorized)
}
type customClaims struct {
Username string `json:"username"`
Token string `json:"token"`
}
func (c *customClaims) SetToken(tok string) {
c.Token = tok
}
func TestVerifyStruct(t *testing.T) {
maxAge := testMaxAge / 2
j := jwt.HMAC(maxAge, "secret", "itsa16bytesecret")
app := iris.New()
app.Get("/user/auth", func(ctx iris.Context) {
err := j.WriteToken(ctx, customClaims{
Username: "makis",
})
if err != nil {
ctx.StopWithError(iris.StatusUnauthorized, err)
return
}
})
userAPI := app.Party("/user")
userAPI.Post("/", func(ctx iris.Context) {
var claims customClaims
if _, err := j.VerifyToken(ctx, &claims); err != nil {
ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err))
return
}
ctx.JSON(claims)
})
e := httptest.New(t, app)
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if token == "" {
t.Fatalf("empty token")
}
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Object().ContainsMap(iris.Map{
"username": "makis",
"token": token, // Test SetToken.
})
e.POST("/user").Expect().Status(httptest.StatusUnauthorized)
time.Sleep(maxAge)
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().Status(httptest.StatusUnauthorized)
}
func TestVerifyJSON(t *testing.T) {
j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret")
app := iris.New()
app.Get("/user/auth", func(ctx iris.Context) {
err := j.WriteToken(ctx, iris.Map{"roles": []string{"admin"}})
if err != nil {
ctx.StopWithError(iris.StatusUnauthorized, err)
return
}
})
app.Post("/", j.VerifyJSON(), func(ctx iris.Context) {
claims := struct {
Roles []string `json:"roles"`
}{}
jwt.ReadJSON(ctx, &claims)
ctx.JSON(claims)
})
e := httptest.New(t, app, httptest.LogLevel("error"))
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if token == "" {
t.Fatalf("empty token")
}
e.POST("/").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Equal(iris.Map{"roles": []string{"admin"}})
e.POST("/").Expect().Status(httptest.StatusUnauthorized)
}
func TestVerifyUserAndExpected(t *testing.T) { // Tests the jwt.User struct + context validator + expected.
maxAge := testMaxAge / 2
j := jwt.HMAC(maxAge, "secret", "itsa16bytesecret")
expectedUser := j.NewUser(jwt.Username("makis"), jwt.Roles("admin"), jwt.Fields(iris.Map{
"custom": true,
})) // only for the sake of the test, we iniitalize it here.
expectedUser.Issuer = "tester"
app := iris.New()
app.Get("/user/auth", func(ctx iris.Context) {
tok, err := expectedUser.GetToken()
signer := jwt.NewSigner(testAlg, testSecret, 3*time.Second)
app.Get("/", func(ctx iris.Context) {
claims := fooClaims{Foo: "bar"}
token, err := signer.Sign(claims)
if err != nil {
ctx.StopWithError(iris.StatusInternalServerError, err)
return
}
ctx.WriteString(tok)
ctx.Write(token)
})
userAPI := app.Party("/user")
userAPI.Use(jwt.WithExpected(jwt.Expected{Issuer: "tester"}, j.VerifyUser()))
userAPI.Post("/", func(ctx iris.Context) {
user := ctx.User()
ctx.JSON(user)
verifier := jwt.NewVerifier(testAlg, testSecret)
verifier.ErrorHandler = func(ctx iris.Context, err error) { // app.OnErrorCode(401, ...)
ctx.StopWithError(iris.StatusUnauthorized, err)
}
middleware := verifier.Verify(func() interface{} { return new(fooClaims) })
app.Get("/protected", middleware, func(ctx iris.Context) {
claims := jwt.Get(ctx).(*fooClaims)
ctx.WriteString(claims.Foo)
})
e := httptest.New(t, app)
token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if token == "" {
t.Fatalf("empty token")
}
e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedUser)
// Get generated token.
token := e.GET("/").Expect().Status(iris.StatusOK).Body().Raw()
// Test Header.
headerValue := fmt.Sprintf("Bearer %s", token)
e.GET("/protected").WithHeader("Authorization", headerValue).Expect().
Status(iris.StatusOK).Body().Equal("bar")
// Test URL query.
e.GET("/protected").WithQuery("token", token).Expect().
Status(iris.StatusOK).Body().Equal("bar")
// Test generic client message if we don't manage the private error by ourselves.
e.POST("/user").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
// Test unauthorized.
e.GET("/protected").Expect().Status(iris.StatusUnauthorized)
e.GET("/protected").WithHeader("Authorization", "missing bearer").Expect().Status(iris.StatusUnauthorized)
e.GET("/protected").WithQuery("token", "invalid_token").Expect().Status(iris.StatusUnauthorized)
// Test expired (note checks happen based on second round).
time.Sleep(5 * time.Second)
e.GET("/protected").WithHeader("Authorization", headerValue).Expect().
Status(iris.StatusUnauthorized).Body().Equal("token expired")
}

View File

@@ -1,106 +0,0 @@
package jwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io/ioutil"
"os"
)
// LoadRSA tries to read RSA Private Key from "fname" system file,
// if does not exist then it generates a new random one based on "bits" (e.g. 2048, 4096)
// and exports it to a new "fname" file.
func LoadRSA(fname string, bits int) (key *rsa.PrivateKey, err error) {
exists := fileExists(fname)
if exists {
key, err = importFromFile(fname)
} else {
key, err = rsa.GenerateKey(rand.Reader, bits)
}
if err != nil {
return
}
if !exists {
err = exportToFile(key, fname)
}
return
}
func exportToFile(key *rsa.PrivateKey, filename string) error {
b := x509.MarshalPKCS1PrivateKey(key)
encoded := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: b,
},
)
return ioutil.WriteFile(filename, encoded, 0600)
}
func importFromFile(filename string) (*rsa.PrivateKey, error) {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
return ParseRSAPrivateKey(b, nil)
}
func fileExists(filename string) bool {
info, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return !info.IsDir()
}
var (
// ErrNotPEM is an error type of the `ParseXXX` function(s) fired
// when the data are not PEM-encoded.
ErrNotPEM = errors.New("key must be PEM encoded")
// ErrInvalidKey is an error type of the `ParseXXX` function(s) fired
// when the contents are not type of rsa private key.
ErrInvalidKey = errors.New("key is not of type *rsa.PrivateKey")
)
// ParseRSAPrivateKey encodes a PEM-encoded PKCS1 or PKCS8 private key protected with a password.
func ParseRSAPrivateKey(key, password []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(key)
if block == nil {
return nil, ErrNotPEM
}
var (
parsedKey interface{}
err error
)
var blockDecrypted []byte
if len(password) > 0 {
if blockDecrypted, err = x509.DecryptPEMBlock(block, password); err != nil {
return nil, err
}
} else {
blockDecrypted = block.Bytes
}
if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil {
return nil, err
}
}
privateKey, ok := parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, ErrInvalidKey
}
return privateKey, nil
}

59
middleware/jwt/signer.go Normal file
View File

@@ -0,0 +1,59 @@
package jwt
import (
"fmt"
"time"
"github.com/kataras/jwt"
)
type Signer struct {
Alg Alg
Key interface{}
MaxAge time.Duration
Encrypt func([]byte) ([]byte, error)
}
func NewSigner(signatureAlg Alg, signatureKey interface{}, maxAge time.Duration) *Signer {
return &Signer{
Alg: signatureAlg,
Key: signatureKey,
MaxAge: maxAge,
}
}
// WithGCM enables AES-GCM payload decryption.
func (s *Signer) WithGCM(key, additionalData []byte) *Signer {
encrypt, _, err := jwt.GCM(key, additionalData)
if err != nil {
panic(err) // important error before serve, stop everything.
}
s.Encrypt = encrypt
return s
}
func (s *Signer) Sign(claims interface{}, opts ...SignOption) ([]byte, error) {
return SignEncrypted(s.Alg, s.Key, s.Encrypt, claims, append([]SignOption{MaxAge(s.MaxAge)}, opts...)...)
}
func (s *Signer) NewTokenPair(accessClaims interface{}, refreshClaims interface{}, refreshMaxAge time.Duration, accessOpts ...SignOption) (TokenPair, error) {
if refreshMaxAge <= s.MaxAge {
return TokenPair{}, fmt.Errorf("refresh max age should be bigger than access token's one[%d - %d]", refreshMaxAge, s.MaxAge)
}
accessToken, err := s.Sign(accessClaims, accessOpts...)
if err != nil {
return TokenPair{}, err
}
refreshToken, err := Sign(s.Alg, s.Key, refreshClaims, MaxAge(refreshMaxAge))
if err != nil {
return TokenPair{}, err
}
tokenPair := jwt.NewTokenPair(accessToken, refreshToken)
return tokenPair, nil
}

View File

@@ -1,187 +0,0 @@
package jwt
import (
"time"
"github.com/kataras/iris/v12/context"
)
// User a common User structure for JWT.
// However, we're not limited to that one;
// any Go structure can be generated as a JWT token.
//
// Look `NewUser` and `VerifyUser` JWT middleware's methods.
// Use its `GetToken` method to generate the token when
// the User structure is set.
type User struct {
Claims
// Note: we could use a map too as the Token is generated when GetToken is called.
*context.SimpleUser
j *JWT
}
var (
_ context.FeaturedUser = (*User)(nil)
_ TokenSetter = (*User)(nil)
_ ContextValidator = (*User)(nil)
)
// UserOption sets optional fields for a new User
// See `NewUser` instance function.
type UserOption func(*User)
// Username sets the Username and the JWT Claim's Subject
// to the given "username".
func Username(username string) UserOption {
return func(u *User) {
u.Username = username
u.Claims.Subject = username
u.Features = append(u.Features, context.UsernameFeature)
}
}
// Email sets the Email field for the User field.
func Email(email string) UserOption {
return func(u *User) {
u.Email = email
u.Features = append(u.Features, context.EmailFeature)
}
}
// Roles upserts to the User's Roles field.
func Roles(roles ...string) UserOption {
return func(u *User) {
u.Roles = roles
u.Features = append(u.Features, context.RolesFeature)
}
}
// MaxAge sets claims expiration and the AuthorizedAt User field.
func MaxAge(maxAge time.Duration) UserOption {
return func(u *User) {
now := time.Now()
u.Claims.Expiry = NewNumericDate(now.Add(maxAge))
u.Claims.IssuedAt = NewNumericDate(now)
u.AuthorizedAt = now
u.Features = append(u.Features, context.AuthorizedAtFeature)
}
}
// Fields copies the "fields" to the user's Fields field.
// This can be used to set custom fields to the User instance.
func Fields(fields context.Map) UserOption {
return func(u *User) {
if len(fields) == 0 {
return
}
if u.Fields == nil {
u.Fields = make(context.Map, len(fields))
}
for k, v := range fields {
u.Fields[k] = v
}
u.Features = append(u.Features, context.FieldsFeature)
}
}
// SetToken is called automaticaly on VerifyUser/VerifyObject.
// It sets the extracted from request, and verified from server raw token.
func (u *User) SetToken(token string) {
u.Token = token
}
// GetToken overrides the SimpleUser's Token
// and returns the jwt generated token, among with
// a generator error, if any.
func (u *User) GetToken() (string, error) {
if u.Token != "" {
return u.Token, nil
}
if u.j != nil { // it's always not nil.
if u.j.MaxAge > 0 {
// if the MaxAge option was not manually set, resolve it from the JWT instance.
MaxAge(u.j.MaxAge)(u)
}
// we could generate a token here
// but let's do it on GetToken
// as the user fields may change
// by the caller manually until the token
// sent to the client.
tok, err := u.j.Token(u)
if err != nil {
return "", err
}
u.Token = tok
}
if u.Token == "" {
return "", ErrMissing
}
return u.Token, nil
}
// Validate validates the current user's claims against
// the request. It's called automatically by the JWT instance.
func (u *User) Validate(ctx *context.Context, claims Claims, e Expected) error {
err := u.Claims.ValidateWithLeeway(e, 0)
if err != nil {
return err
}
if u.SimpleUser.Authorization != "IRIS_JWT_USER" {
return ErrInvalidKey
}
// We could add specific User Expectations (new struct and accept an interface{}),
// but for the sake of code simplicity we don't, unless is requested, as the caller
// can validate specific fields by its own at the next step.
return nil
}
// UnmarshalJSON implements the json unmarshaler interface.
func (u *User) UnmarshalJSON(data []byte) error {
err := Unmarshal(data, &u.Claims)
if err != nil {
return err
}
simpleUser := new(context.SimpleUser)
err = Unmarshal(data, simpleUser)
if err != nil {
return err
}
u.SimpleUser = simpleUser
return nil
}
// MarshalJSON implements the json marshaler interface.
func (u *User) MarshalJSON() ([]byte, error) {
claimsB, err := Marshal(u.Claims)
if err != nil {
return nil, err
}
userB, err := Marshal(u.SimpleUser)
if err != nil {
return nil, err
}
if len(userB) == 0 {
return claimsB, nil
}
claimsB = claimsB[0 : len(claimsB)-1] // remove last '}'
userB = userB[1:] // remove first '{'
raw := append(claimsB, ',')
raw = append(raw, userB...)
return raw, nil
}

View File

@@ -1,258 +0,0 @@
package jwt
import (
"bytes"
"errors"
"reflect"
"strings"
"time"
"github.com/kataras/iris/v12/context"
"github.com/square/go-jose/v3/json"
// Use this package instead of the standard encoding/json
// to marshal the NumericDate as expected by the implementation (see 'normalize`).
"github.com/square/go-jose/v3/jwt"
)
const (
claimsExpectedContextKey = "iris.jwt.claims.expected"
needsValidationContextKey = "iris.jwt.claims.unvalidated"
)
var (
// ErrMissing when token cannot be extracted from the request (custm error).
ErrMissing = errors.New("token is missing")
// ErrMissingKey when token does not contain a required JSON field (custom error).
ErrMissingKey = errors.New("token is missing a required field")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("token is expired (exp)")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
ErrNotValidYet = errors.New("token not valid yet (nbf)")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
ErrIssuedInTheFuture = errors.New("token issued in the future (iat)")
// ErrBlocked indicates that the token was not yet expired
// but was blocked by the server's Blocklist (custom error).
ErrBlocked = errors.New("token is blocked")
// ErrInvalidMaxAge indicates that the token is using a different
// max age than the configurated one ( custom error).
ErrInvalidMaxAge = errors.New("token contains invalid max age")
// ErrExpectRefreshToken indicates that the retrieved token
// was not a refresh token one when `ExpectRefreshToken` is set (custome rror).
ErrExpectRefreshToken = errors.New("expect refresh token")
)
// Expectation option to provide
// an extra layer of token validation, a claims type protection.
// See `VerifyToken` method.
type Expectation func(e *Expected, claims interface{}) error
// Expect protects the claims with the expected values.
func Expect(expected Expected) Expectation {
return func(e *Expected, _ interface{}) error {
*e = expected
return nil
}
}
// ExpectID protects the claims with an ID validation.
func ExpectID(id string) Expectation {
return func(e *Expected, _ interface{}) error {
e.ID = id
return nil
}
}
// ExpectIssuer protects the claims with an issuer validation.
func ExpectIssuer(issuer string) Expectation {
return func(e *Expected, _ interface{}) error {
e.Issuer = issuer
return nil
}
}
// ExpectSubject protects the claims with a subject validation.
func ExpectSubject(sub string) Expectation {
return func(e *Expected, _ interface{}) error {
e.Subject = sub
return nil
}
}
// ExpectAudience protects the claims with an audience validation.
func ExpectAudience(audience ...string) Expectation {
return func(e *Expected, _ interface{}) error {
e.Audience = audience
return nil
}
}
// ExpectRefreshToken SHOULD be passed when a token should be verified
// based on the expiration set by `TokenPair` method instead of the JWT instance's MaxAge setting.
// Useful to validate Refresh Tokens and invalidate Access ones when refresh API is fired,
// if that option is missing then refresh tokens are invalidated when an access token was expected.
//
// Usage:
// var refreshClaims jwt.Claims
// _, err := j.VerifyTokenString(ctx, tokenPair.RefreshToken, &refreshClaims, jwt.ExpectRefreshToken)
func ExpectRefreshToken(e *Expected, _ interface{}) error { return ErrExpectRefreshToken }
// MeetRequirements protects the custom fields of JWT claims
// based on the json:required tag; `json:"name,required"`.
// It accepts the value type.
//
// Usage:
// Verify/VerifyToken(... MeetRequirements(MyUser{}))
func MeetRequirements(claimsType interface{}) Expectation {
// pre-calculate if we need to use reflection at serve time to check for required fields,
// this can work as an alternative of expections for custom non-standard JWT fields.
requireFieldsIndexes := getRequiredFieldIndexes(claimsType)
return func(e *Expected, claims interface{}) error {
if len(requireFieldsIndexes) > 0 {
val := reflect.Indirect(reflect.ValueOf(claims))
for _, idx := range requireFieldsIndexes {
field := val.Field(idx)
if field.IsZero() {
return ErrMissingKey
}
}
}
return nil
}
}
// WithExpected is a middleware wrapper. It wraps a VerifyXXX middleware
// with expected claims fields protection.
// Usage:
// jwt.WithExpected(jwt.Expected{Issuer:"app"}, j.VerifyUser)
func WithExpected(e Expected, verifyHandler context.Handler) context.Handler {
return func(ctx *context.Context) {
ctx.Values().Set(claimsExpectedContextKey, e)
verifyHandler(ctx)
}
}
// ContextValidator validates the object based on the given
// claims and the expected once. The end-developer
// can use this method for advanced validations based on the request Context.
type ContextValidator interface {
Validate(ctx *context.Context, claims Claims, e Expected) error
}
func validateClaims(ctx *context.Context, dest interface{}, claims Claims, expected Expected) (err error) {
// Get any dynamic expectation set by prior middleware.
// See `WithExpected` middleware.
if v := ctx.Values().Get(claimsExpectedContextKey); v != nil {
if e, ok := v.(Expected); ok {
expected = e
}
}
// Force-set the time, it's important for expiration.
expected.Time = time.Now()
switch c := dest.(type) {
case Claims:
err = c.ValidateWithLeeway(expected, 0)
case ContextValidator:
err = c.Validate(ctx, claims, expected)
case *context.Map:
// if the dest is a map then set automatically the expiration settings here,
// so the caller can work further with it.
err = claims.ValidateWithLeeway(expected, 0)
if err == nil {
(*c)["exp"] = claims.Expiry
(*c)["iat"] = claims.IssuedAt
if claims.NotBefore != nil {
(*c)["nbf"] = claims.NotBefore
}
}
default:
err = claims.ValidateWithLeeway(expected, 0)
}
if err != nil {
switch err {
case jwt.ErrExpired:
return ErrExpired
case jwt.ErrNotValidYet:
return ErrNotValidYet
case jwt.ErrIssuedInTheFuture:
return ErrIssuedInTheFuture
}
}
return err
}
func normalize(i interface{}) (context.Map, error) {
if m, ok := i.(context.Map); ok {
return m, nil
}
m := make(context.Map)
raw, err := json.Marshal(i)
if err != nil {
return nil, err
}
d := json.NewDecoder(bytes.NewReader(raw))
d.UseNumber()
if err := d.Decode(&m); err != nil {
return nil, err
}
return m, nil
}
func getRequiredFieldIndexes(i interface{}) (v []int) {
val := reflect.Indirect(reflect.ValueOf(i))
typ := val.Type()
if typ.Kind() != reflect.Struct {
return nil
}
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
// Note: for the sake of simplicity we don't lookup for nested objects (FieldByIndex),
// we could do that as we do in dependency injection feature but unless requirested we don't.
tag := field.Tag.Get("json")
if strings.Contains(tag, ",required") {
v = append(v, i)
}
}
return
}
// getMaxAge returns the result of expiry-issued at.
// Note that if in JWT MaxAge's was set to a value like: 3.5 seconds
// this will return 3 on token retreival. Of course this is not a problem
// in real world apps as they don't invalidate tokens in seconds
// based on a division result like 2/7.
func getMaxAge(claims Claims) time.Duration {
if issuedAt := claims.IssuedAt.Time(); !issuedAt.IsZero() {
gotMaxAge := claims.Expiry.Time().Sub(issuedAt)
return gotMaxAge
}
return 0
}
func compareMaxAge(expected, got time.Duration) bool {
if expected == got {
return true
}
// got is int64, maybe rounded, but the max age setting is precise, may be a float result
// e.g. the result of a division 2/7=3.5,
// try to validate by round of second so similar/or equal max age setting are considered valid.
min, max := expected-time.Second, expected+time.Second
if got < min || got > max {
return false
}
return true
}

210
middleware/jwt/verifier.go Normal file
View File

@@ -0,0 +1,210 @@
package jwt
import (
"reflect"
"time"
"github.com/kataras/iris/v12/context"
"github.com/kataras/jwt"
)
const (
claimsContextKey = "iris.jwt.claims"
verifiedTokenContextKey = "iris.jwt.token"
)
// Get returns the claims decoded by a verifier.
func Get(ctx *context.Context) interface{} {
if v := ctx.Values().Get(claimsContextKey); v != nil {
return v
}
return nil
}
// GetVerifiedToken returns the verified token structure
// which holds information about the decoded token
// and its standard claims.
func GetVerifiedToken(ctx *context.Context) *VerifiedToken {
if v := ctx.Values().Get(verifiedTokenContextKey); v != nil {
if tok, ok := v.(*VerifiedToken); ok {
return tok
}
}
return nil
}
// Verifier holds common options to verify an incoming token.
// Its Verify method can be used as a middleware to allow authorized clients to access an API.
type Verifier struct {
Alg Alg
Key interface{}
Decrypt func([]byte) ([]byte, error)
Extractors []TokenExtractor
Blocklist Blocklist
Validators []TokenValidator
ErrorHandler func(ctx *context.Context, err error)
}
// NewVerifier accepts the algorithm for the token's signature among with its (private) key
// and optionally some token validators for all verify middlewares that may initialized under this Verifier.
//
// See its Verify method.
func NewVerifier(signatureAlg Alg, signatureKey interface{}, validators ...TokenValidator) *Verifier {
return &Verifier{
Alg: signatureAlg,
Key: signatureKey,
Extractors: []TokenExtractor{FromHeader, FromQuery},
ErrorHandler: func(ctx *context.Context, err error) {
ctx.StopWithError(401, context.PrivateError(err))
},
Validators: validators,
}
}
// WithGCM enables AES-GCM payload encryption.
func (v *Verifier) WithGCM(key, additionalData []byte) *Verifier {
_, decrypt, err := jwt.GCM(key, additionalData)
if err != nil {
panic(err) // important error before serve, stop everything.
}
v.Decrypt = decrypt
return v
}
// WithDefaultBlocklist attaches an in-memory blocklist storage
// to invalidate tokens through server-side.
// To invalidate a token simply call the Context.Logout method.
func (v *Verifier) WithDefaultBlocklist() *Verifier {
v.Blocklist = jwt.NewBlocklist(30 * time.Minute)
return v
}
func (v *Verifier) invalidate(ctx *context.Context) {
if verifiedToken := GetVerifiedToken(ctx); verifiedToken != nil {
v.Blocklist.InvalidateToken(verifiedToken.Token, verifiedToken.StandardClaims.Expiry)
ctx.Values().Remove(claimsContextKey)
ctx.Values().Remove(verifiedTokenContextKey)
ctx.SetUser(nil)
ctx.SetLogoutFunc(nil)
}
}
// RequestToken extracts the token from the request.
func (v *Verifier) RequestToken(ctx *context.Context) (token string) {
for _, extract := range v.Extractors {
if token = extract(ctx); token != "" {
break // ok we found it.
}
}
return
}
type (
// ClaimsValidator is a special interface which, if the destination claims
// implements it then the verifier runs its Validate method before return.
ClaimsValidator interface {
Validate() error
}
// ClaimsContextValidator same as ClaimsValidator but it accepts
// a request context which can be used for further checks before
// validating the incoming token's claims.
ClaimsContextValidator interface {
Validate(*context.Context) error
}
)
// VerifyToken simply verifies the given "token" and validates its standard claims (such as expiration).
// Returns a structure which holds the token's information. See the Verify method instead.
func (v *Verifier) VerifyToken(token []byte, validators ...TokenValidator) (*VerifiedToken, error) {
return jwt.VerifyEncrypted(v.Alg, v.Key, v.Decrypt, token, validators...)
}
// Verify is the most important piece of code inside the Verifier.
// It accepts the "claimsType" function which should return a pointer to a custom structure
// which the token's decode claims valuee will be binded and validated to.
// Returns a common Iris handler which can be used as a middleware to protect an API
// from unauthorized client requests. After this, the route handlers can access the claims
// through the jwt.Get package-level function.
//
// Example Code:
func (v *Verifier) Verify(claimsType func() interface{}, validators ...TokenValidator) context.Handler {
unmarshal := jwt.Unmarshal
if claimsType != nil {
c := claimsType()
if hasRequired(c) {
unmarshal = jwt.UnmarshalWithRequired
}
}
if v.Blocklist != nil {
validators = append([]TokenValidator{v.Blocklist}, append(v.Validators, validators...)...)
}
return func(ctx *context.Context) {
token := []byte(v.RequestToken(ctx))
verifiedToken, err := v.VerifyToken(token, validators...)
if err != nil {
v.ErrorHandler(ctx, err)
return
}
if claimsType != nil {
dest := claimsType()
if err = unmarshal(verifiedToken.Payload, dest); err != nil {
v.ErrorHandler(ctx, err)
return
}
if validator, ok := dest.(ClaimsValidator); ok {
if err = validator.Validate(); err != nil {
v.ErrorHandler(ctx, err)
return
}
} else if contextValidator, ok := dest.(ClaimsContextValidator); ok {
if err = contextValidator.Validate(ctx); err != nil {
v.ErrorHandler(ctx, err)
return
}
}
if u, ok := dest.(context.User); ok {
ctx.SetUser(u)
}
ctx.Values().Set(claimsContextKey, dest)
}
if v.Blocklist != nil {
ctx.SetLogoutFunc(v.invalidate)
}
ctx.Values().Set(verifiedTokenContextKey, verifiedToken)
ctx.Next()
}
}
func hasRequired(i interface{}) bool {
val := reflect.Indirect(reflect.ValueOf(i))
typ := val.Type()
if typ.Kind() != reflect.Struct {
return false
}
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
if jwt.HasRequiredJSONTag(field) {
return true
}
}
return false
}