1
0
mirror of https://github.com/kataras/iris.git synced 2025-12-20 03:17:04 +00:00

New builtin JWT middleware - this one supports encryption and ed25519

Former-commit-id: ca20d256b766e3e8717e91de7a3f3b5f213af0bc
This commit is contained in:
Gerasimos (Makis) Maropoulos
2020-05-27 12:02:17 +03:00
parent c866709acc
commit d556cfc39a
15 changed files with 930 additions and 19 deletions

73
middleware/jwt/alises.go Normal file
View File

@@ -0,0 +1,73 @@
package jwt
import (
"github.com/square/go-jose/v3"
"github.com/square/go-jose/v3/jwt"
)
type (
// Claims represents public claim values (as specified in RFC 7519).
Claims = jwt.Claims
// Audience represents the recipients that the token is intended for.
Audience = jwt.Audience
)
type (
// KeyAlgorithm represents a key management algorithm.
KeyAlgorithm = jose.KeyAlgorithm
// SignatureAlgorithm represents a signature (or MAC) algorithm.
SignatureAlgorithm = jose.SignatureAlgorithm
// ContentEncryption represents a content encryption algorithm.
ContentEncryption = jose.ContentEncryption
)
// Key management algorithms.
const (
ED25519 = jose.ED25519
RSA15 = jose.RSA1_5
RSAOAEP = jose.RSA_OAEP
RSAOAEP256 = jose.RSA_OAEP_256
A128KW = jose.A128KW
A192KW = jose.A192KW
A256KW = jose.A256KW
DIRECT = jose.DIRECT
ECDHES = jose.ECDH_ES
ECDHESA128KW = jose.ECDH_ES_A128KW
ECDHESA192KW = jose.ECDH_ES_A192KW
ECDHESA256KW = jose.ECDH_ES_A256KW
A128GCMKW = jose.A128GCMKW
A192GCMKW = jose.A192GCMKW
A256GCMKW = jose.A256GCMKW
PBES2HS256A128KW = jose.PBES2_HS256_A128KW
PBES2HS384A192KW = jose.PBES2_HS384_A192KW
PBES2HS512A256KW = jose.PBES2_HS512_A256KW
)
// Signature algorithms.
const (
EdDSA = jose.EdDSA
HS256 = jose.HS256
HS384 = jose.HS384
HS512 = jose.HS512
RS256 = jose.RS256
RS384 = jose.RS384
RS512 = jose.RS512
ES256 = jose.ES256
ES384 = jose.ES384
ES512 = jose.ES512
PS256 = jose.PS256
PS384 = jose.PS384
PS512 = jose.PS512
)
// Content encryption algorithms.
const (
A128CBCHS256 = jose.A128CBC_HS256
A192CBCHS384 = jose.A192CBC_HS384
A256CBCHS512 = jose.A256CBC_HS512
A128GCM = jose.A128GCM
A192GCM = jose.A192GCM
A256GCM = jose.A256GCM
)

424
middleware/jwt/jwt.go Normal file
View File

@@ -0,0 +1,424 @@
package jwt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"strings"
"time"
"github.com/kataras/iris/v12/context"
"github.com/square/go-jose/v3"
"github.com/square/go-jose/v3/jwt"
)
func init() {
context.SetHandlerName("iris/middleware/jwt.*", "iris.jwt")
}
// TokenExtractor is a function that takes a context as input and returns
// a token. An empty string should be returned if no token found
// without additional information.
type TokenExtractor func(context.Context) string
// FromHeader is a token extractor.
// It reads the token from the Authorization request header of form:
// Authorization: "Bearer {token}".
func FromHeader(ctx context.Context) string {
authHeader := ctx.GetHeader("Authorization")
if authHeader == "" {
return ""
}
// pure check: authorization header format must be Bearer {token}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return ""
}
return authHeaderParts[1]
}
// FromQuery is a token extractor.
// It reads the token from the "token" url query parameter.
func FromQuery(ctx context.Context) string {
return ctx.URLParam("token")
}
// FromJSON is a token extractor.
// Reads a json request body and extracts the json based on the given field.
// The request content-type should contain the: application/json header value, otherwise
// this method will not try to read and consume the body.
func FromJSON(jsonKey string) TokenExtractor {
return func(ctx context.Context) string {
if ctx.GetContentTypeRequested() != context.ContentJSONHeaderValue {
return ""
}
var m context.Map
if err := ctx.ReadJSON(&m); err != nil {
return ""
}
if m == nil {
return ""
}
v, ok := m[jsonKey]
if !ok {
return ""
}
tok, ok := v.(string)
if !ok {
return ""
}
return tok
}
}
// JWT holds the necessary information the middleware need
// to sign and verify tokens.
//
// The `RSA(privateFile, publicFile, password)` package-level helper function
// can be used to decode the SignKey and VerifyKey.
type JWT struct {
// MaxAge is the expiration duration of the generated tokens.
MaxAge time.Duration
// Extractors are used to extract a raw token string value
// from the request.
// Builtin extractors:
// * FromHeader
// * FromQuery
// * FromJSON
// Defaults to a slice of `FromHeader` and `FromQuery`.
Extractors []TokenExtractor
// Signer is used to sign the token.
// It is set on `New` and `Default` package-level functions.
Signer jose.Signer
// VerificationKey is used to verify the token (public key).
VerificationKey interface{}
// Encrypter is used to, optionally, encrypt the token.
// It is set on `WithExpiration` method.
Encrypter jose.Encrypter
// DecriptionKey is used to decrypt the token (private key)
DecriptionKey interface{}
}
// Random returns a new `JWT` instance
// with in-memory generated rsa256 signing and encryption keys (development).
// It panics on errors. Next server ran will invalidate all request tokens.
//
// Use the `New` package-level function for production use.
func Random(maxAge time.Duration) *JWT {
sigKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
j, err := New(maxAge, RS256, sigKey)
if err != nil {
panic(err)
}
encKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
err = j.WithEncryption(A128CBCHS256, RSA15, encKey)
if err != nil {
panic(err)
}
return j
}
type privateKey interface{ Public() crypto.PublicKey }
// New returns a new JWT instance.
// It accepts a maximum time duration for token expiration
// and the algorithm among with its key for signing and verification.
//
// See `WithEncryption` method to add token encryption too.
// Use `Token` method to generate a new token string
// and `VerifyToken` method to decrypt, verify and bind claims of an incoming request token.
// Token, by default, is extracted by "Authorization: Bearer {token}" request header and
// url query parameter of "token". Token extractors can be modified through the `Extractors` field.
//
// For example, if you want to sign and verify using RSA-256 key:
// 1. Generate key file, e.g:
// $ openssl genrsa -des3 -out private.pem 2048
// 2. Read file contents with io.ReadFile("./private.pem")
// 3. Pass the []byte result to the `MustParseRSAPrivateKey(contents, password)` package-level helper
// 4. Use the result *rsa.PrivateKey as "key" input parameter of this `New` function.
//
// See aliases.go file for available algorithms.
func New(maxAge time.Duration, alg SignatureAlgorithm, key interface{}) (*JWT, error) {
sig, err := jose.NewSigner(jose.SigningKey{
Algorithm: alg,
Key: key,
}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
return nil, err
}
j := &JWT{
Signer: sig,
VerificationKey: key,
MaxAge: maxAge,
Extractors: []TokenExtractor{FromHeader, FromQuery},
}
if s, ok := key.(privateKey); ok {
j.VerificationKey = s.Public()
}
return j, nil
}
// WithEncryption method enables encryption and decryption of the token.
// It sets an appropriate encrypter(`Encrypter` and the `DecriptionKey` fields) based on the key type.
func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorithm, key interface{}) error {
var publicKey interface{} = key
if s, ok := key.(privateKey); ok {
publicKey = s.Public()
}
enc, err := jose.NewEncrypter(contentEncryption, jose.Recipient{
Algorithm: alg,
Key: publicKey,
},
(&jose.EncrypterOptions{}).WithType("JWT").WithContentType("JWT"),
)
if err != nil {
return err
}
j.Encrypter = enc
j.DecriptionKey = key
return nil
}
// Expiry returns a new standard Claims with
// the `Expiry` and `IssuedAt` fields of the "claims" filled
// based on the given "maxAge" duration.
//
// See the `JWT.Expiry` method too.
func Expiry(maxAge time.Duration, claims Claims) Claims {
now := time.Now()
claims.Expiry = jwt.NewNumericDate(now.Add(maxAge))
claims.IssuedAt = jwt.NewNumericDate(now)
return claims
}
// Expiry method same as `Expiry` package-level function,
// it returns a Claims with the expiration fields of the "claims"
// filled based on the JWT's `MaxAge` field.
// Only use it when this standard "claims"
// is embedded on a custom claims structure.
// Usage:
// type UserClaims struct {
// jwt.Claims
// Username string
// }
// [...]
// standardClaims := j.Expiry(jwt.Claims{...})
// customClaims := UserClaims{
// Claims: standardClaims,
// Username: "kataras",
// }
// j.WriteToken(ctx, customClaims)
func (j *JWT) Expiry(claims Claims) Claims {
return Expiry(j.MaxAge, claims)
}
// Token generates and returns a new token string.
// See `VerifyToken` too.
func (j *JWT) Token(claims interface{}) (string, error) {
if c, ok := claims.(Claims); ok {
claims = Expiry(j.MaxAge, c)
}
var (
token string
err error
)
// jwt.Builder and jwt.NestedBuilder contain same methods but they are not the same.
if j.DecriptionKey != nil {
token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(claims).CompactSerialize()
} else {
token, err = jwt.Signed(j.Signer).Claims(claims).CompactSerialize()
}
if err != nil {
return "", err
}
return token, nil
}
// WriteToken is a helper which just generates(calls the `Token` method) and writes
// a new token to the client in plain text format.
//
// Use the `Token` method to get a new generated token raw string value.
func (j *JWT) WriteToken(ctx context.Context, claims interface{}) error {
token, err := j.Token(claims)
if err != nil {
ctx.StatusCode(500)
return err
}
_, err = ctx.WriteString(token)
return err
}
var (
// ErrTokenMissing when token cannot be extracted from the request.
ErrTokenMissing = errors.New("token is missing")
// ErrTokenInvalid when incoming token is invalid.
ErrTokenInvalid = errors.New("token is invalid")
// ErrTokenExpired when incoming token has expired.
ErrTokenExpired = errors.New("token has expired")
)
type (
claimsValidator interface {
ValidateWithLeeway(e jwt.Expected, leeway time.Duration) error
}
claimsAlternativeValidator interface {
Validate() error
}
)
// IsValidated reports whether a token is already validated through
// `VerifyToken`. It returns true when the claims are compatible
// validators: a `Claims` value or a value that implements the `Validate() error` method.
func IsValidated(ctx context.Context) bool { // see the `ReadClaims`.
return ctx.Values().Get(needsValidationContextKey) == nil
}
func validateClaims(ctx context.Context, claimsPtr interface{}) (err error) {
switch claims := claimsPtr.(type) {
case claimsValidator:
err = claims.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 0)
case claimsAlternativeValidator:
err = claims.Validate()
default:
ctx.Values().Set(needsValidationContextKey, struct{}{})
}
if err != nil {
if err == jwt.ErrExpired {
return ErrTokenExpired
}
}
return err
}
// VerifyToken verifies (and decrypts) the request token,
// it also validates and binds the parsed token's claims to the "claimsPtr" (destination).
// It does return a nil error on success.
func (j *JWT) VerifyToken(ctx context.Context, claimsPtr interface{}) error {
var token string
for _, extract := range j.Extractors {
if token = extract(ctx); token != "" {
break // ok we found it.
}
}
if token == "" {
return ErrTokenMissing
}
var (
parsedToken *jwt.JSONWebToken
err error
)
if j.DecriptionKey != nil {
t, cerr := jwt.ParseSignedAndEncrypted(token)
if cerr != nil {
return cerr
}
parsedToken, err = t.Decrypt(j.DecriptionKey)
} else {
parsedToken, err = jwt.ParseSigned(token)
}
if err != nil {
return ErrTokenInvalid
}
if err = parsedToken.Claims(j.VerificationKey, claimsPtr); err != nil {
return ErrTokenInvalid
}
return validateClaims(ctx, claimsPtr)
}
const (
// ClaimsContextKey is the context key which the jwt claims are stored from the `Verify` method.
ClaimsContextKey = "iris.jwt.claims"
needsValidationContextKey = "iris.jwt.claims.unvalidated"
)
// Verify is a middleware. It verifies and optionally decrypts an incoming request token.
// It does write a 401 unauthorized status code if verification or decryption failed.
// It calls the `ctx.Next` on verified requests.
//
// See `VerifyToken` instead to verify, decrypt, validate and acquire the claims at once.
//
// A call of `ReadClaims` is required to validate and acquire the jwt claims
// on the next request.
func (j *JWT) Verify(ctx context.Context) {
var raw json.RawMessage
if err := j.VerifyToken(ctx, &raw); err != nil {
ctx.StopWithStatus(401)
return
}
ctx.Values().Set(ClaimsContextKey, raw)
ctx.Next()
}
// ReadClaims binds the "claimsPtr" (destination)
// to the verified (and decrypted) claims.
// The `Verify` method should be called first (registered as middleware).
func ReadClaims(ctx context.Context, claimsPtr interface{}) error {
v := ctx.Values().Get(ClaimsContextKey)
if v == nil {
return ErrTokenMissing
}
raw, ok := v.(json.RawMessage)
if !ok {
return ErrTokenMissing
}
err := json.Unmarshal(raw, claimsPtr)
if err != nil {
return err
}
// If already validated on VerifyToken (a claimsValidator/claimsAlternativeValidator)
// then no need to perform the check again.
if !IsValidated(ctx) {
ctx.Values().Remove(needsValidationContextKey)
return validateClaims(ctx, claimsPtr)
}
return nil
}

119
middleware/jwt/jwt_test.go Normal file
View File

@@ -0,0 +1,119 @@
// Package jwt_test contains simple Iris jwt tests. Most of the jwt functionality is already tested inside the jose package itself.
package jwt_test
import (
"testing"
"time"
"github.com/kataras/iris/v12"
"github.com/kataras/iris/v12/httptest"
"github.com/kataras/iris/v12/middleware/jwt"
)
type userClaims struct {
jwt.Claims
Username string
}
const testMaxAge = 3 * time.Second
// Random RSA verification and encryption.
func TestRSA(t *testing.T) {
j := jwt.Random(testMaxAge)
testWriteVerifyToken(t, j)
}
// HMAC verification and encryption.
func TestHMAC(t *testing.T) {
j, err := jwt.New(testMaxAge, jwt.HS256, []byte("secret"))
if err != nil {
t.Fatal(err)
}
err = j.WithEncryption(jwt.A128GCM, jwt.DIRECT, []byte("itsa16bytesecret"))
if err != nil {
t.Fatal(err)
}
testWriteVerifyToken(t, j)
}
// HMAC verification only (unecrypted).
func TestVerify(t *testing.T) {
j, err := jwt.New(testMaxAge, jwt.HS256, []byte("another secret"))
if err != nil {
t.Fatal(err)
}
testWriteVerifyToken(t, j)
}
func testWriteVerifyToken(t *testing.T, j *jwt.JWT) {
t.Helper()
j.Extractors = append(j.Extractors, jwt.FromJSON("access_token"))
standardClaims := jwt.Claims{Issuer: "an-issuer", Audience: jwt.Audience{"an-audience"}}
expectedClaims := userClaims{
Claims: j.Expiry(standardClaims),
Username: "kataras",
}
app := iris.New()
app.Get("/auth", func(ctx iris.Context) {
j.WriteToken(ctx, expectedClaims)
})
app.Post("/restricted", func(ctx iris.Context) {
var claims userClaims
if err := j.VerifyToken(ctx, &claims); err != nil {
ctx.StopWithStatus(iris.StatusUnauthorized)
return
}
ctx.JSON(claims)
})
app.Post("/restricted_middleware", j.Verify, func(ctx iris.Context) {
var claims userClaims
if err := jwt.ReadClaims(ctx, &claims); err != nil {
ctx.StopWithStatus(iris.StatusUnauthorized)
return
}
ctx.JSON(claims)
})
e := httptest.New(t, app)
// Get token.
rawToken := e.GET("/auth").Expect().Status(httptest.StatusOK).Body().Raw()
if rawToken == "" {
t.Fatalf("empty token")
}
restrictedPaths := [...]string{"/restricted", "/restricted_middleware"}
now := time.Now()
for _, path := range restrictedPaths {
// Authorization Header.
e.POST(path).WithHeader("Authorization", "Bearer "+rawToken).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
// URL Query.
e.POST(path).WithQuery("token", rawToken).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
// JSON Body.
e.POST(path).WithJSON(iris.Map{"access_token": rawToken}).Expect().
Status(httptest.StatusOK).JSON().Equal(expectedClaims)
// Missing "Bearer".
e.POST(path).WithHeader("Authorization", rawToken).Expect().
Status(httptest.StatusUnauthorized)
}
expireRemDur := testMaxAge - time.Since(now)
// Expiration.
time.Sleep(expireRemDur /* -end */)
for _, path := range restrictedPaths {
e.POST(path).WithQuery("token", rawToken).Expect().Status(httptest.StatusUnauthorized)
}
}

98
middleware/jwt/util.go Normal file
View File

@@ -0,0 +1,98 @@
package jwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)
// ErrNotPEM is a panic error of the MustParseXXX functions when the data are not PEM-encoded.
var ErrNotPEM = errors.New("key must be PEM encoded")
// MustParseRSAPrivateKey encodes a PEM-encoded PKCS1 or PKCS8 private key protected with a password.
func MustParseRSAPrivateKey(key, password []byte) *rsa.PrivateKey {
block, _ := pem.Decode(key)
if block == nil {
panic(ErrNotPEM)
}
var (
parsedKey interface{}
err error
)
var blockDecrypted []byte
if blockDecrypted, err = x509.DecryptPEMBlock(block, password); err != nil {
panic(err)
}
if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil {
panic(err)
}
}
privateKey, ok := parsedKey.(*rsa.PrivateKey)
if !ok {
panic("key is not of type *rsa.PrivateKey")
}
return privateKey
}
// MustParseRSAPublicKey encodes a PEM encoded PKCS1 or PKCS8 public key.
func MustParseRSAPublicKey(key []byte) *rsa.PublicKey {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
panic(ErrNotPEM)
}
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
parsedKey = cert.PublicKey
} else {
panic(err)
}
}
var pkey *rsa.PublicKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
panic("key is not of type *rsa.PublicKey")
}
return pkey
}
/*
// MustParseEd25519 PEM encoded Ed25519.
func MustParseEd25519(key []byte) ed25519.PrivateKey {
// Parse PEM block
block, _ := pem.Decode(key)
if block == nil {
panic(ErrNotPEM)
}
type ed25519PrivKey struct {
Version int
ObjectIdentifier struct {
ObjectIdentifier asn1.ObjectIdentifier
}
PrivateKey []byte
}
var asn1PrivKey ed25519PrivKey
if _, err := asn1.Unmarshal(block.Bytes, &asn1PrivKey); err != nil {
panic(err)
}
privateKey := ed25519.NewKeyFromSeed(asn1PrivKey.PrivateKey[2:])
return privateKey
}
*/