mirror of
https://github.com/kataras/iris.git
synced 2025-12-21 20:07:04 +00:00
@@ -3,7 +3,7 @@ package redis
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"strings"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12/sessions"
|
||||
@@ -18,8 +18,6 @@ const (
|
||||
DefaultRedisAddr = "127.0.0.1:6379"
|
||||
// DefaultRedisTimeout the redis idle timeout option, time.Duration(30) * time.Second
|
||||
DefaultRedisTimeout = time.Duration(30) * time.Second
|
||||
// DefaultDelim ths redis delim option, "-".
|
||||
DefaultDelim = "-"
|
||||
)
|
||||
|
||||
// Config the redis configuration used inside sessions
|
||||
@@ -31,31 +29,36 @@ type Config struct {
|
||||
// Defaults to "127.0.0.1:6379".
|
||||
Addr string
|
||||
// Clusters a list of network addresses for clusters.
|
||||
// If not empty "Addr" is ignored.
|
||||
// Currently only Radix() Driver supports it.
|
||||
// If not empty "Addr" is ignored and Redis clusters feature is used instead.
|
||||
Clusters []string
|
||||
// Password string .If no password then no 'AUTH'. Defaults to "".
|
||||
// Use the specified Username to authenticate the current connection
|
||||
// with one of the connections defined in the ACL list when connecting
|
||||
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
|
||||
Username string
|
||||
// Optional password. Must match the password specified in the
|
||||
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
|
||||
// or the User Password when connecting to a Redis 6.0 instance, or greater,
|
||||
// that is using the Redis ACL system.
|
||||
Password string
|
||||
// If Database is empty "" then no 'SELECT'. Defaults to "".
|
||||
Database string
|
||||
// MaxActive. Defaults to 10.
|
||||
// Maximum number of socket connections.
|
||||
// Default is 10 connections per every CPU as reported by runtime.NumCPU.
|
||||
MaxActive int
|
||||
// Timeout for connect, write and read, defaults to 30 seconds, 0 means no timeout.
|
||||
Timeout time.Duration
|
||||
// Prefix "myprefix-for-this-website". Defaults to "".
|
||||
Prefix string
|
||||
// Delim the delimiter for the keys on the sessiondb. Defaults to "-".
|
||||
Delim string
|
||||
|
||||
// TLSConfig will cause Dial to perform a TLS handshake using the provided
|
||||
// config. If is nil then no TLS is used.
|
||||
// See https://golang.org/pkg/crypto/tls/#Config
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// Driver supports `Redigo()` or `Radix()` go clients for redis.
|
||||
// Configure each driver by the return value of their constructors.
|
||||
// A Driver should support be a go client for redis communication.
|
||||
// It can be set to a custom one or a mock one (for testing).
|
||||
//
|
||||
// Defaults to `Redigo()`.
|
||||
// Defaults to `GoRedis()`.
|
||||
Driver Driver
|
||||
}
|
||||
|
||||
@@ -64,14 +67,14 @@ func DefaultConfig() Config {
|
||||
return Config{
|
||||
Network: DefaultRedisNetwork,
|
||||
Addr: DefaultRedisAddr,
|
||||
Username: "",
|
||||
Password: "",
|
||||
Database: "",
|
||||
MaxActive: 10,
|
||||
Timeout: DefaultRedisTimeout,
|
||||
Prefix: "",
|
||||
Delim: DefaultDelim,
|
||||
TLSConfig: nil,
|
||||
Driver: Redigo(),
|
||||
Driver: GoRedis(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,7 +86,7 @@ type Database struct {
|
||||
|
||||
var _ sessions.Database = (*Database)(nil)
|
||||
|
||||
// New returns a new redis database.
|
||||
// New returns a new redis sessions database.
|
||||
func New(cfg ...Config) *Database {
|
||||
c := DefaultConfig()
|
||||
if len(cfg) > 0 {
|
||||
@@ -101,16 +104,8 @@ func New(cfg ...Config) *Database {
|
||||
c.Addr = DefaultRedisAddr
|
||||
}
|
||||
|
||||
if c.MaxActive == 0 {
|
||||
c.MaxActive = 10
|
||||
}
|
||||
|
||||
if c.Delim == "" {
|
||||
c.Delim = DefaultDelim
|
||||
}
|
||||
|
||||
if c.Driver == nil {
|
||||
c.Driver = Redigo()
|
||||
c.Driver = GoRedis()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,93 +122,108 @@ func New(cfg ...Config) *Database {
|
||||
return db
|
||||
}
|
||||
|
||||
// Config returns the configuration for the redis server bridge, you can change them.
|
||||
func (db *Database) Config() *Config {
|
||||
return &db.c // 6 Aug 2019 - keep that for no breaking change.
|
||||
}
|
||||
|
||||
// SetLogger sets the logger once before server ran.
|
||||
// By default the Iris one is injected.
|
||||
func (db *Database) SetLogger(logger *golog.Logger) {
|
||||
db.logger = logger
|
||||
}
|
||||
|
||||
func (db *Database) makeSID(sid string) string {
|
||||
return db.c.Prefix + sid
|
||||
}
|
||||
|
||||
// SessionIDKey the session ID stored to the redis session itself.
|
||||
const SessionIDKey = "session_id"
|
||||
|
||||
// Acquire receives a session's lifetime from the database,
|
||||
// if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration.
|
||||
func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime {
|
||||
key := db.makeKey(sid, "")
|
||||
seconds, hasExpiration, found := db.c.Driver.TTL(key)
|
||||
if !found {
|
||||
// fmt.Printf("db.Acquire expires: %s. Seconds: %v\n", expires, expires.Seconds())
|
||||
// not found, create an entry with ttl and return an empty lifetime, session manager will do its job.
|
||||
if err := db.c.Driver.Set(key, sid, int64(expires.Seconds())); err != nil {
|
||||
sidKey := db.makeSID(sid)
|
||||
if !db.c.Driver.Exists(sidKey) {
|
||||
if err := db.Set(sidKey, SessionIDKey, sid, 0, false); err != nil {
|
||||
db.logger.Debug(err)
|
||||
} else if expires > 0 {
|
||||
if err := db.c.Driver.UpdateTTL(sidKey, expires); err != nil {
|
||||
db.logger.Debug(err)
|
||||
}
|
||||
}
|
||||
|
||||
return sessions.LifeTime{} // session manager will handle the rest.
|
||||
}
|
||||
|
||||
if !hasExpiration {
|
||||
return sessions.LifeTime{}
|
||||
}
|
||||
|
||||
return sessions.LifeTime{Time: time.Now().Add(time.Duration(seconds) * time.Second)}
|
||||
untilExpire := db.c.Driver.TTL(sidKey)
|
||||
return sessions.LifeTime{Time: time.Now().Add(untilExpire)}
|
||||
}
|
||||
|
||||
// OnUpdateExpiration will re-set the database's session's entry ttl.
|
||||
// https://redis.io/commands/expire#refreshing-expires
|
||||
func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error {
|
||||
return db.c.Driver.UpdateTTLMany(db.makeKey(sid, ""), int64(newExpires.Seconds()))
|
||||
}
|
||||
|
||||
func (db *Database) makeKey(sid, key string) string {
|
||||
if key == "" {
|
||||
return db.c.Prefix + sid
|
||||
}
|
||||
return db.c.Prefix + sid + db.c.Delim + key
|
||||
return db.c.Driver.UpdateTTL(db.makeSID(sid), newExpires)
|
||||
}
|
||||
|
||||
// Set sets a key value of a specific session.
|
||||
// Ignore the "immutable".
|
||||
func (db *Database) Set(sid string, lifetime *sessions.LifeTime, key string, value interface{}, immutable bool) {
|
||||
func (db *Database) Set(sid string, key string, value interface{}, _ time.Duration, _ bool) error {
|
||||
valueBytes, err := sessions.DefaultTranscoder.Marshal(value)
|
||||
if err != nil {
|
||||
db.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// fmt.Println("database.Set")
|
||||
// fmt.Printf("lifetime.DurationUntilExpiration(): %s. Seconds: %v\n", lifetime.DurationUntilExpiration(), lifetime.DurationUntilExpiration().Seconds())
|
||||
if err = db.c.Driver.Set(db.makeKey(sid, key), valueBytes, int64(lifetime.DurationUntilExpiration().Seconds())); err != nil {
|
||||
db.logger.Debug(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a session value based on the key.
|
||||
func (db *Database) Get(sid string, key string) (value interface{}) {
|
||||
db.get(db.makeKey(sid, key), &value)
|
||||
return
|
||||
}
|
||||
|
||||
func (db *Database) get(key string, outPtr interface{}) error {
|
||||
data, err := db.c.Driver.Get(key)
|
||||
if err != nil {
|
||||
// not found.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = sessions.DefaultTranscoder.Unmarshal(data.([]byte), outPtr); err != nil {
|
||||
db.logger.Debugf("unable to unmarshal value of key: '%s': %v", key, err)
|
||||
if err = db.c.Driver.Set(db.makeSID(sid), key, valueBytes); err != nil {
|
||||
db.logger.Debug(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) keys(sid string) []string {
|
||||
keys, err := db.c.Driver.GetKeys(db.makeKey(sid, ""))
|
||||
// Get retrieves a session value based on the key.
|
||||
func (db *Database) Get(fullSID string, key string) (value interface{}) {
|
||||
if err := db.Decode(fullSID, key, &value); err == nil {
|
||||
return value
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode binds the "outPtr" to the value associated to the provided "key".
|
||||
func (db *Database) Decode(sid, key string, outPtr interface{}) error {
|
||||
data, err := db.c.Driver.Get(sid, key)
|
||||
if err != nil {
|
||||
db.logger.Debugf("unable to get all redis keys of session '%s': %v", sid, err)
|
||||
// not found.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = db.decodeValue(data, outPtr); err != nil {
|
||||
db.logger.Debugf("unable to unmarshal value of key: '%s%s': %v", sid, key, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) decodeValue(val interface{}, outPtr interface{}) error {
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch data := val.(type) {
|
||||
case []byte:
|
||||
// this is the most common type, as we save all values as []byte,
|
||||
// the only exception is where the value is string on HGetAll command.
|
||||
return sessions.DefaultTranscoder.Unmarshal(data, outPtr)
|
||||
case string:
|
||||
return sessions.DefaultTranscoder.Unmarshal([]byte(data), outPtr)
|
||||
default:
|
||||
return fmt.Errorf("unknown value type of %T", data)
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) keys(fullSID string) []string {
|
||||
keys, err := db.c.Driver.GetKeys(fullSID)
|
||||
if err != nil {
|
||||
db.logger.Debugf("unable to get all redis keys of session '%s': %v", fullSID, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -221,24 +231,33 @@ func (db *Database) keys(sid string) []string {
|
||||
}
|
||||
|
||||
// Visit loops through all session keys and values.
|
||||
func (db *Database) Visit(sid string, cb func(key string, value interface{})) {
|
||||
keys := db.keys(sid)
|
||||
for _, key := range keys {
|
||||
var value interface{} // new value each time, we don't know what user will do in "cb".
|
||||
db.get(key, &value)
|
||||
key = strings.TrimPrefix(key, db.c.Prefix+sid+db.c.Delim)
|
||||
cb(key, value)
|
||||
func (db *Database) Visit(sid string, cb func(key string, value interface{})) error {
|
||||
kv, err := db.c.Driver.GetAll(db.makeSID(sid))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for k, v := range kv {
|
||||
var value interface{} // new value each time, we don't know what user will do in "cb".
|
||||
if err = db.decodeValue(v, &value); err != nil {
|
||||
db.logger.Debugf("unable to decode %s:%s: %v", sid, k, err)
|
||||
return err
|
||||
}
|
||||
|
||||
cb(k, value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Len returns the length of the session's entries (keys).
|
||||
func (db *Database) Len(sid string) (n int) {
|
||||
return len(db.keys(sid))
|
||||
func (db *Database) Len(sid string) int {
|
||||
return db.c.Driver.Len(sid)
|
||||
}
|
||||
|
||||
// Delete removes a session key value based on its key.
|
||||
func (db *Database) Delete(sid string, key string) (deleted bool) {
|
||||
err := db.c.Driver.Delete(db.makeKey(sid, key))
|
||||
err := db.c.Driver.Delete(db.makeSID(sid), key)
|
||||
if err != nil {
|
||||
db.logger.Error(err)
|
||||
}
|
||||
@@ -246,25 +265,30 @@ func (db *Database) Delete(sid string, key string) (deleted bool) {
|
||||
}
|
||||
|
||||
// Clear removes all session key values but it keeps the session entry.
|
||||
func (db *Database) Clear(sid string) {
|
||||
keys := db.keys(sid)
|
||||
func (db *Database) Clear(sid string) error {
|
||||
keys := db.keys(db.makeSID(sid))
|
||||
for _, key := range keys {
|
||||
if err := db.c.Driver.Delete(key); err != nil {
|
||||
if key == SessionIDKey {
|
||||
continue
|
||||
}
|
||||
if err := db.c.Driver.Delete(sid, key); err != nil {
|
||||
db.logger.Debugf("unable to delete session '%s' value of key: '%s': %v", sid, key, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Release destroys the session, it clears and removes the session entry,
|
||||
// session manager will create a new session ID on the next request after this call.
|
||||
func (db *Database) Release(sid string) {
|
||||
// clear all $sid-$key.
|
||||
db.Clear(sid)
|
||||
// and remove the $sid.
|
||||
err := db.c.Driver.Delete(db.c.Prefix + sid)
|
||||
func (db *Database) Release(sid string) error {
|
||||
err := db.c.Driver.Delete(db.makeSID(sid), "")
|
||||
if err != nil {
|
||||
db.logger.Debugf("Database.Release.Driver.Delete: %s: %v", sid, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Close terminates the redis connection.
|
||||
|
||||
Reference in New Issue
Block a user