mirror of
https://github.com/kataras/iris.git
synced 2026-01-07 20:17:05 +00:00
jwt: add redis blocklist
This commit is contained in:
188
middleware/jwt/blocklist/redis/blocklist.go
Normal file
188
middleware/jwt/blocklist/redis/blocklist.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12/core/host"
|
||||
"github.com/kataras/iris/v12/middleware/jwt"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var defaultContext = context.Background()
|
||||
|
||||
type (
|
||||
// Options is just a type alias for the go-redis Client Options.
|
||||
Options = redis.Options
|
||||
// ClusterOptions is just a type alias for the go-redis Cluster Client Options.
|
||||
ClusterOptions = redis.ClusterOptions
|
||||
)
|
||||
|
||||
// Client is the interface which both
|
||||
// go-redis Client and Cluster Client implements.
|
||||
type Client interface {
|
||||
redis.Cmdable // Commands.
|
||||
io.Closer // CloseConnection.
|
||||
}
|
||||
|
||||
// Blocklist is a jwt.Blocklist backed by Redis.
|
||||
type Blocklist struct {
|
||||
Clock func() time.Time // Required. Defaults to time.Now.
|
||||
// GetKey is a function which can be used how to extract
|
||||
// the unique identifier for a token.
|
||||
// Required. By default the token key is extracted through the claims.ID ("jti").
|
||||
GetKey func(token []byte, claims jwt.Claims) string
|
||||
// Prefix the token key into the redis database.
|
||||
// Note that if you can also select a different database
|
||||
// through ClientOptions (or ClusterOptions).
|
||||
// Defaults to empty string (no prefix).
|
||||
Prefix string
|
||||
// Both Client and ClusterClient implements this interface.
|
||||
client Client
|
||||
connected uint32
|
||||
// Customize any go-redis fields manually
|
||||
// before Connect.
|
||||
ClientOptions Options
|
||||
ClusterOptions ClusterOptions
|
||||
}
|
||||
|
||||
var _ jwt.Blocklist = (*Blocklist)(nil)
|
||||
|
||||
// NewBlocklist returns a new redis-based Blocklist.
|
||||
// Modify its ClientOptions or ClusterOptions depending the application needs
|
||||
// and call its Connect.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// blocklist := NewBlocklist()
|
||||
// blocklist.ClientOptions.Addr = ...
|
||||
// err := blocklist.Connect()
|
||||
//
|
||||
// And register it:
|
||||
//
|
||||
// verifier := jwt.NewVerifier(...)
|
||||
// verifier.Blocklist = blocklist
|
||||
func NewBlocklist() *Blocklist {
|
||||
return &Blocklist{
|
||||
Clock: time.Now,
|
||||
GetKey: defaultGetKey,
|
||||
Prefix: "",
|
||||
ClientOptions: Options{
|
||||
Addr: "127.0.0.1:6379",
|
||||
// The rest are defaulted to good values already.
|
||||
},
|
||||
// If its Addrs > 0 before connect then cluster client is used instead.
|
||||
ClusterOptions: ClusterOptions{},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultGetKey(_ []byte, claims jwt.Claims) string {
|
||||
return claims.ID
|
||||
}
|
||||
|
||||
// Connect prepares the redis client and fires a ping response to it.
|
||||
func (b *Blocklist) Connect() error {
|
||||
if b.Prefix != "" {
|
||||
getKey := b.GetKey
|
||||
b.GetKey = func(token []byte, claims jwt.Claims) string {
|
||||
return b.Prefix + getKey(token, claims)
|
||||
}
|
||||
}
|
||||
|
||||
if len(b.ClusterOptions.Addrs) > 0 {
|
||||
// Use cluster client.
|
||||
b.client = redis.NewClusterClient(&b.ClusterOptions)
|
||||
} else {
|
||||
b.client = redis.NewClient(&b.ClientOptions)
|
||||
}
|
||||
|
||||
_, err := b.client.Ping(defaultContext).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
host.RegisterOnInterrupt(func() {
|
||||
atomic.StoreUint32(&b.connected, 0)
|
||||
b.client.Close()
|
||||
})
|
||||
atomic.StoreUint32(&b.connected, 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected reports whether the Connect function was called.
|
||||
func (b *Blocklist) IsConnected() bool {
|
||||
return atomic.LoadUint32(&b.connected) > 0
|
||||
}
|
||||
|
||||
// ValidateToken checks if the token exists and
|
||||
func (b *Blocklist) ValidateToken(token []byte, c jwt.Claims, err error) error {
|
||||
if err != nil {
|
||||
if err == jwt.ErrExpired {
|
||||
b.Del(b.GetKey(token, c))
|
||||
}
|
||||
|
||||
return err // respect the previous error.
|
||||
}
|
||||
|
||||
has, err := b.Has(b.GetKey(token, c))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return jwt.ErrBlocked
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateToken invalidates a verified JWT token.
|
||||
func (b *Blocklist) InvalidateToken(token []byte, c jwt.Claims) error {
|
||||
key := b.GetKey(token, c)
|
||||
return b.client.SetEX(defaultContext, key, token, c.Timeleft()).Err()
|
||||
}
|
||||
|
||||
// Del removes a token from the storage.
|
||||
func (b *Blocklist) Del(key string) error {
|
||||
return b.client.Del(defaultContext, key).Err()
|
||||
}
|
||||
|
||||
// Has reports whether a specific token exists in the storage.
|
||||
func (b *Blocklist) Has(key string) (bool, error) {
|
||||
n, err := b.client.Exists(defaultContext, key).Result()
|
||||
return n > 0, err
|
||||
}
|
||||
|
||||
// Count returns the total amount of tokens stored.
|
||||
func (b *Blocklist) Count() (int64, error) {
|
||||
if b.Prefix == "" {
|
||||
return b.client.DBSize(defaultContext).Result()
|
||||
}
|
||||
|
||||
keys, err := b.getKeys(0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(keys)), nil
|
||||
}
|
||||
|
||||
func (b *Blocklist) getKeys(cursor uint64) ([]string, error) {
|
||||
keys, cursor, err := b.client.Scan(defaultContext, cursor, b.Prefix+"*", 300000).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cursor != 0 {
|
||||
moreKeys, err := b.getKeys(cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys = append(keys, moreKeys...)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
Reference in New Issue
Block a user