mirror of
https://github.com/kataras/iris.git
synced 2026-01-06 03:27:27 +00:00
first release of SSO package and more examples
This commit is contained in:
162
sso/configuration.go
Normal file
162
sso/configuration.go
Normal file
@@ -0,0 +1,162 @@
|
||||
//go:build go1.18
|
||||
|
||||
package sso
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/kataras/jwt"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
KIDAccess = "IRIS_SSO_ACCESS"
|
||||
KIDRefresh = "IRIS_SSO_REFRESH"
|
||||
)
|
||||
|
||||
type (
|
||||
Configuration struct {
|
||||
Cookie CookieConfiguration `json:"cookie" yaml:"Cookie" toml:"Cookie" ini:"cookie"`
|
||||
// keep it to always renew the refresh token. RefreshStrategy string `json:"refresh_strategy" yaml:"RefreshStrategy" toml:"RefreshStrategy" ini:"refresh_strategy"`
|
||||
Keys jwt.KeysConfiguration `json:"keys" yaml:"Keys" toml:"Keys" ini:"keys"`
|
||||
}
|
||||
|
||||
CookieConfiguration struct {
|
||||
Name string `json:"cookie" yaml:"Name" toml:"Name" ini:"name"`
|
||||
Hash string `json:"hash" yaml:"Hash" toml:"Hash" ini:"hash"`
|
||||
Block string `json:"block" yaml:"Block" toml:"Block" ini:"block"`
|
||||
}
|
||||
)
|
||||
|
||||
func (c *Configuration) validate() (jwt.Keys, error) {
|
||||
if c.Cookie.Name != "" {
|
||||
if c.Cookie.Hash == "" || c.Cookie.Block == "" {
|
||||
return nil, fmt.Errorf("cookie block and cookie hash are required for security reasons when cookie is used")
|
||||
}
|
||||
}
|
||||
|
||||
keys, err := c.Keys.Load()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sso: %w", err)
|
||||
}
|
||||
|
||||
if _, ok := keys[KIDAccess]; !ok {
|
||||
return nil, fmt.Errorf("sso: %s access token is missing from the configuration", KIDAccess)
|
||||
}
|
||||
|
||||
// Let's keep refresh optional.
|
||||
// if _, ok := keys[KIDRefresh]; !ok {
|
||||
// return nil, fmt.Errorf("sso: %s refresh token is missing from the configuration", KIDRefresh)
|
||||
// }
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// BindRandom binds the "c" configuration to random values for keys and cookie security.
|
||||
// Keys will not be persisted between restarts,
|
||||
// a more persistent storage should be considered for production applications.
|
||||
func (c *Configuration) BindRandom() error {
|
||||
accessPublic, accessPrivate, err := jwt.GenerateEdDSA()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
refreshPublic, refreshPrivate, err := jwt.GenerateEdDSA()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = Configuration{
|
||||
Cookie: CookieConfiguration{
|
||||
Name: "iris_sso",
|
||||
Hash: string(securecookie.GenerateRandomKey(64)),
|
||||
Block: string(securecookie.GenerateRandomKey(32)),
|
||||
},
|
||||
Keys: jwt.KeysConfiguration{
|
||||
{
|
||||
ID: KIDAccess,
|
||||
Alg: jwt.EdDSA.Name(),
|
||||
MaxAge: 2 * time.Hour,
|
||||
Public: string(accessPublic),
|
||||
Private: string(accessPrivate),
|
||||
},
|
||||
{
|
||||
ID: KIDRefresh,
|
||||
Alg: jwt.EdDSA.Name(),
|
||||
MaxAge: 720 * time.Hour,
|
||||
Public: string(refreshPublic),
|
||||
Private: string(refreshPrivate),
|
||||
EncryptionKey: string(jwt.MustGenerateRandom(32)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Configuration) BindFile(filename string) error {
|
||||
switch filepath.Ext(filename) {
|
||||
case ".json":
|
||||
contents, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
generatedConfig := MustGenerateConfiguration()
|
||||
if generatedYAML, gErr := generatedConfig.ToJSON(); gErr == nil {
|
||||
err = fmt.Errorf("%w: example:\n\n%s", err, generatedYAML)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(contents, c)
|
||||
default:
|
||||
contents, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
generatedConfig := MustGenerateConfiguration()
|
||||
if generatedYAML, gErr := generatedConfig.ToYAML(); gErr == nil {
|
||||
err = fmt.Errorf("%w: example:\n\n%s", err, generatedYAML)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return yaml.Unmarshal(contents, c)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *Configuration) ToYAML() ([]byte, error) {
|
||||
return yaml.Marshal(c)
|
||||
}
|
||||
|
||||
func (c *Configuration) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(c)
|
||||
}
|
||||
|
||||
func MustGenerateConfiguration() (c Configuration) {
|
||||
if err := c.BindRandom(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func LoadConfiguration(filename string) (c Configuration, err error) {
|
||||
err = c.BindFile(filename)
|
||||
return
|
||||
}
|
||||
|
||||
func MustLoadConfiguration(filename string) Configuration {
|
||||
c, err := LoadConfiguration(filename)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
83
sso/provider.go
Normal file
83
sso/provider.go
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build go1.18
|
||||
|
||||
package sso
|
||||
|
||||
import (
|
||||
stdContext "context"
|
||||
"fmt"
|
||||
|
||||
"github.com/kataras/iris/v12/context"
|
||||
"github.com/kataras/iris/v12/middleware/jwt"
|
||||
"github.com/kataras/iris/v12/x/errors"
|
||||
)
|
||||
|
||||
type VerifiedToken = jwt.VerifiedToken
|
||||
|
||||
type Provider[T User] interface { // A provider can implement Transformer and ErrorHandler as well.
|
||||
Signin(ctx stdContext.Context, username, password string) (T, error)
|
||||
|
||||
// We could do this instead of transformer below but let's keep separated logic methods:
|
||||
// ValidateToken(ctx context.Context, tok *VerifiedToken, t *T) error
|
||||
ValidateToken(ctx stdContext.Context, standardClaims StandardClaims, t T) error
|
||||
|
||||
InvalidateToken(ctx stdContext.Context, standardClaims StandardClaims, t T) error
|
||||
InvalidateTokens(ctx stdContext.Context, t T) error
|
||||
}
|
||||
|
||||
// ClaimsProvider is an optional interface, which may not be used at all.
|
||||
// If completed by a Provider, it signs the jwt token
|
||||
// using these claims to each of the following token types.
|
||||
type ClaimsProvider interface {
|
||||
GetAccessTokenClaims() StandardClaims
|
||||
GetRefreshTokenClaims(accessClaims StandardClaims) StandardClaims
|
||||
}
|
||||
|
||||
type Transformer[T User] interface {
|
||||
Transform(ctx stdContext.Context, tok *VerifiedToken) (T, error)
|
||||
}
|
||||
|
||||
type TransformerFunc[T User] func(ctx stdContext.Context, tok *VerifiedToken) (T, error)
|
||||
|
||||
func (fn TransformerFunc[T]) Transform(ctx stdContext.Context, tok *VerifiedToken) (T, error) {
|
||||
return fn(ctx, tok)
|
||||
}
|
||||
|
||||
type ErrorHandler interface {
|
||||
InvalidArgument(ctx *context.Context, err error)
|
||||
Unauthenticated(ctx *context.Context, err error)
|
||||
}
|
||||
|
||||
type DefaultErrorHandler struct{}
|
||||
|
||||
func (e *DefaultErrorHandler) InvalidArgument(ctx *context.Context, err error) {
|
||||
errors.InvalidArgument.Details(ctx, "unable to parse body", err.Error())
|
||||
}
|
||||
|
||||
func (e *DefaultErrorHandler) Unauthenticated(ctx *context.Context, err error) {
|
||||
errors.Unauthenticated.Err(ctx, err)
|
||||
}
|
||||
|
||||
type provider[T User] struct{}
|
||||
|
||||
func newProvider[T User]() *provider[T] {
|
||||
return new(provider[T])
|
||||
}
|
||||
|
||||
func (p *provider[T]) Signin(ctx stdContext.Context, username, password string) (T, error) { // fired on SigninHandler.
|
||||
// your database...
|
||||
var t T
|
||||
return t, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
func (p *provider[T]) ValidateToken(ctx stdContext.Context, standardClaims StandardClaims, t T) error { // fired on VerifyHandler.
|
||||
// your database and checks of blocked tokens...
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *provider[T]) InvalidateToken(ctx stdContext.Context, standardClaims StandardClaims, t T) error { // fired on SignoutHandler.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *provider[T]) InvalidateTokens(ctx stdContext.Context, t T) error { // fired on SignoutAllHandler.
|
||||
return nil
|
||||
}
|
||||
568
sso/sso.go
Normal file
568
sso/sso.go
Normal file
@@ -0,0 +1,568 @@
|
||||
//go:build go1.18
|
||||
|
||||
package sso
|
||||
|
||||
import (
|
||||
stdContext "context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12/context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/kataras/jwt"
|
||||
)
|
||||
|
||||
type (
|
||||
SSO[T User] struct {
|
||||
config Configuration
|
||||
|
||||
keys jwt.Keys
|
||||
securecookie context.SecureCookie
|
||||
|
||||
providers []Provider[T] // at least one.
|
||||
errorHandler ErrorHandler
|
||||
transformer Transformer[T]
|
||||
claimsProvider ClaimsProvider
|
||||
refreshEnabled bool // if KIDRefresh exists in keys.
|
||||
}
|
||||
|
||||
TVerify[T User] func(t T) error
|
||||
|
||||
SigninRequest struct {
|
||||
Username string `json:"username" form:"username,omitempty"` // username OR email, username has priority over email.
|
||||
Email string `json:"email" form:"email,omitempty"` // email OR username.
|
||||
Password string `json:"password" form:"password"`
|
||||
}
|
||||
|
||||
SigninResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
RefreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
)
|
||||
|
||||
func MustLoad[T User](filename string) *SSO[T] {
|
||||
var config Configuration
|
||||
if err := config.BindFile(filename); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s, err := New[T](config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func Must[T User](s *SSO[T], err error) *SSO[T] {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func New[T User](config Configuration) (*SSO[T], error) {
|
||||
keys, err := config.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, refreshEnabled := keys[KIDRefresh]
|
||||
|
||||
s := &SSO[T]{
|
||||
config: config,
|
||||
keys: keys,
|
||||
securecookie: securecookie.New([]byte(config.Cookie.Hash), []byte(config.Cookie.Block)),
|
||||
refreshEnabled: refreshEnabled,
|
||||
// providers: []Provider[T]{newProvider[T]()},
|
||||
errorHandler: new(DefaultErrorHandler),
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *SSO[T]) WithProviderAndErrorHandler(provider Provider[T], errHandler ErrorHandler) *SSO[T] {
|
||||
if provider != nil {
|
||||
for i := range s.providers {
|
||||
s.providers[i] = nil
|
||||
}
|
||||
s.providers = nil
|
||||
|
||||
s.providers = make([]Provider[T], 0, 1)
|
||||
s.AddProvider(provider)
|
||||
}
|
||||
|
||||
if errHandler != nil {
|
||||
s.SetErrorHandler(errHandler)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SSO[T]) AddProvider(providers ...Provider[T]) *SSO[T] {
|
||||
// defaultProviderTypename := strings.Replace(fmt.Sprintf("%T", s), "SSO", "provider", 1)
|
||||
// if len(s.providers) == 1 && fmt.Sprintf("%T", s.providers[0]) == defaultProviderTypename {
|
||||
// s.providers = append(s.providers[1:], p...)
|
||||
|
||||
// A provider can also implement both transformer and
|
||||
// error handler if that's the design option of the end-developer.
|
||||
for _, p := range providers {
|
||||
if s.transformer == nil {
|
||||
if transformer, ok := p.(Transformer[T]); ok {
|
||||
s.SetTransformer(transformer)
|
||||
}
|
||||
}
|
||||
|
||||
if errHandler, ok := p.(ErrorHandler); ok {
|
||||
s.SetErrorHandler(errHandler)
|
||||
}
|
||||
|
||||
if s.claimsProvider == nil {
|
||||
if claimsProvider, ok := p.(ClaimsProvider); ok {
|
||||
s.claimsProvider = claimsProvider
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.providers = append(s.providers, providers...)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SSO[T]) SetErrorHandler(errHandler ErrorHandler) *SSO[T] {
|
||||
s.errorHandler = errHandler
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SSO[T]) SetTransformer(transformer Transformer[T]) *SSO[T] {
|
||||
s.transformer = transformer
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SSO[T]) SetTransformerFunc(transfermerFunc func(ctx stdContext.Context, tok *VerifiedToken) (T, error)) *SSO[T] {
|
||||
s.transformer = TransformerFunc[T](transfermerFunc)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SSO[T]) Signin(ctx stdContext.Context, username, password string) ([]byte, []byte, error) {
|
||||
var t T
|
||||
|
||||
// get "t" from a valid provider.
|
||||
if n := len(s.providers); n > 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
p := s.providers[i]
|
||||
|
||||
v, err := p.Signin(ctx, username, password)
|
||||
if err != nil {
|
||||
if i == n-1 { // last provider errored.
|
||||
return nil, nil, fmt.Errorf("sso: signin: %w", err)
|
||||
}
|
||||
// keep searching.
|
||||
continue
|
||||
}
|
||||
|
||||
// found.
|
||||
t = v
|
||||
break
|
||||
}
|
||||
} else {
|
||||
return nil, nil, fmt.Errorf("sso: signin: no provider")
|
||||
}
|
||||
|
||||
// sign the tokens.
|
||||
accessToken, refreshToken, err := s.sign(t)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("sso: signin: %w", err)
|
||||
}
|
||||
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
func (s *SSO[T]) sign(t T) ([]byte, []byte, error) {
|
||||
// sign the tokens.
|
||||
var (
|
||||
accessStdClaims StandardClaims
|
||||
refreshStdClaims StandardClaims
|
||||
)
|
||||
|
||||
if s.claimsProvider != nil {
|
||||
accessStdClaims = s.claimsProvider.GetAccessTokenClaims()
|
||||
refreshStdClaims = s.claimsProvider.GetRefreshTokenClaims(accessStdClaims)
|
||||
}
|
||||
|
||||
iat := jwt.Clock().Unix()
|
||||
|
||||
if accessStdClaims.IssuedAt == 0 {
|
||||
accessStdClaims.IssuedAt = iat
|
||||
}
|
||||
|
||||
if accessStdClaims.ID == "" {
|
||||
accessStdClaims.ID = uuid.NewString()
|
||||
}
|
||||
|
||||
if refreshStdClaims.IssuedAt == 0 {
|
||||
refreshStdClaims.IssuedAt = iat
|
||||
}
|
||||
|
||||
if refreshStdClaims.ID == "" {
|
||||
refreshStdClaims.ID = uuid.NewString()
|
||||
}
|
||||
|
||||
if refreshStdClaims.OriginID == "" {
|
||||
// keep a reference of the access token the refresh token is created,
|
||||
// if that access token is invalidated then
|
||||
// its refresh token should be too so the user can force-login.
|
||||
refreshStdClaims.OriginID = accessStdClaims.ID
|
||||
}
|
||||
|
||||
accessToken, err := s.keys.SignToken(KIDAccess, t, accessStdClaims)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("access: %w", err)
|
||||
}
|
||||
|
||||
var refreshToken []byte
|
||||
if s.refreshEnabled {
|
||||
refreshToken, err = s.keys.SignToken(KIDRefresh, t, refreshStdClaims)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("refresh: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
func (s *SSO[T]) SigninHandler(ctx *context.Context) {
|
||||
// No, let the developer decide it based on a middleware, e.g. iris.LimitRequestBodySize.
|
||||
// ctx.SetMaxRequestBodySize(s.maxRequestBodySize)
|
||||
|
||||
var (
|
||||
req SigninRequest
|
||||
err error
|
||||
)
|
||||
|
||||
switch ctx.GetContentTypeRequested() {
|
||||
case context.ContentFormHeaderValue, context.ContentFormMultipartHeaderValue:
|
||||
err = ctx.ReadForm(&req)
|
||||
default:
|
||||
err = ctx.ReadJSON(&req)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.errorHandler.InvalidArgument(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Username == "" {
|
||||
req.Username = req.Email
|
||||
}
|
||||
|
||||
accessTokenBytes, refreshTokenBytes, err := s.Signin(ctx, req.Username, req.Password)
|
||||
if err != nil {
|
||||
s.tryRemoveCookie(ctx) // remove cookie on invalidated.
|
||||
|
||||
s.errorHandler.Unauthenticated(ctx, err)
|
||||
return
|
||||
}
|
||||
accessToken := jwt.BytesToString(accessTokenBytes)
|
||||
refreshToken := jwt.BytesToString(refreshTokenBytes)
|
||||
|
||||
s.trySetCookie(ctx, accessToken)
|
||||
|
||||
resp := SigninResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
ctx.JSON(resp)
|
||||
}
|
||||
|
||||
func (s *SSO[T]) Verify(ctx stdContext.Context, token []byte) (T, StandardClaims, error) {
|
||||
t, claims, err := s.verify(ctx, token)
|
||||
if err != nil {
|
||||
return t, StandardClaims{}, fmt.Errorf("sso: verify: %w", err)
|
||||
}
|
||||
|
||||
return t, claims, nil
|
||||
}
|
||||
|
||||
func (s *SSO[T]) verify(ctx stdContext.Context, token []byte) (T, StandardClaims, error) {
|
||||
var t T
|
||||
|
||||
if len(token) == 0 { // should never happen at this state.
|
||||
return t, StandardClaims{}, jwt.ErrMissing
|
||||
}
|
||||
|
||||
verifiedToken, err := jwt.VerifyWithHeaderValidator(nil, nil, token, s.keys.ValidateHeader, jwt.Leeway(time.Minute))
|
||||
if err != nil {
|
||||
return t, StandardClaims{}, err
|
||||
}
|
||||
|
||||
if s.transformer != nil {
|
||||
if t, err = s.transformer.Transform(ctx, verifiedToken); err != nil {
|
||||
return t, StandardClaims{}, err
|
||||
}
|
||||
} else {
|
||||
if err = verifiedToken.Claims(&t); err != nil {
|
||||
return t, StandardClaims{}, err
|
||||
}
|
||||
}
|
||||
|
||||
standardClaims := verifiedToken.StandardClaims
|
||||
|
||||
if n := len(s.providers); n > 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
p := s.providers[i]
|
||||
|
||||
err := p.ValidateToken(ctx, standardClaims, t)
|
||||
if err != nil {
|
||||
if i == n-1 { // last provider errored.
|
||||
return t, StandardClaims{}, err
|
||||
}
|
||||
// keep searching.
|
||||
continue
|
||||
}
|
||||
|
||||
// token is allowed.
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// return t, StandardClaims{}, fmt.Errorf("no provider")
|
||||
}
|
||||
|
||||
return t, standardClaims, nil
|
||||
}
|
||||
|
||||
/* Good idea but not practical.
|
||||
func Transform[T User, V User](transformer Transformer[T, V]) context.Handler {
|
||||
return func(ctx *context.Context) {
|
||||
existingUserValue := GetUser[T](ctx)
|
||||
newUserValue, err := transformer.Transform(ctx, existingUserValue)
|
||||
if err != nil {
|
||||
ctx.SetErr(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Values().Set(userContextKey, newUserValue)
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func (s *SSO[T]) VerifyHandler(verifyFuncs ...TVerify[T]) context.Handler {
|
||||
return func(ctx *context.Context) {
|
||||
accessToken := s.extractAccessToken(ctx)
|
||||
|
||||
if accessToken == "" { // if empty, fire 401.
|
||||
s.errorHandler.Unauthenticated(ctx, jwt.ErrMissing)
|
||||
return
|
||||
}
|
||||
|
||||
t, claims, err := s.Verify(ctx, []byte(accessToken))
|
||||
if err != nil {
|
||||
s.errorHandler.Unauthenticated(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, verify := range verifyFuncs {
|
||||
if verify == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err = verify(t); err != nil {
|
||||
err = fmt.Errorf("sso: verify: %v", err)
|
||||
s.errorHandler.Unauthenticated(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx.SetUser(t)
|
||||
|
||||
// store the user to the request.
|
||||
ctx.Values().Set(accessTokenContextKey, accessToken)
|
||||
|
||||
ctx.Values().Set(userContextKey, t)
|
||||
ctx.Values().Set(standardClaimsContextKey, claims)
|
||||
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSO[T]) extractAccessToken(ctx *context.Context) string {
|
||||
// first try from authorization: bearer header.
|
||||
accessToken := extractTokenFromHeader(ctx)
|
||||
|
||||
// then if no header, try try extract from cookie.
|
||||
if accessToken == "" {
|
||||
if cookieName := s.config.Cookie.Name; cookieName != "" {
|
||||
accessToken = ctx.GetCookie(cookieName, context.CookieEncoding(s.securecookie))
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken
|
||||
}
|
||||
|
||||
func (s *SSO[T]) Refresh(ctx stdContext.Context, refreshToken []byte) ([]byte, []byte, error) {
|
||||
if !s.refreshEnabled {
|
||||
return nil, nil, fmt.Errorf("sso: refresh: disabled")
|
||||
}
|
||||
|
||||
t, _, err := s.verify(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("sso: refresh: %w", err)
|
||||
}
|
||||
|
||||
// refresh the tokens, both refresh & access tokens will be renew to prevent
|
||||
// malicious 😈 users that may hold a refresh token.
|
||||
accessTok, refreshTok, err := s.sign(t)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("sso: refresh: %w", err)
|
||||
}
|
||||
|
||||
return accessTok, refreshTok, nil
|
||||
}
|
||||
|
||||
func (s *SSO[T]) RefreshHandler(ctx *context.Context) {
|
||||
var req RefreshRequest
|
||||
err := ctx.ReadJSON(&req)
|
||||
if err != nil {
|
||||
s.errorHandler.InvalidArgument(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
accessTokenBytes, refreshTokenBytes, err := s.Refresh(ctx, []byte(req.RefreshToken))
|
||||
if err != nil {
|
||||
// s.tryRemoveCookie(ctx)
|
||||
s.errorHandler.Unauthenticated(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken := jwt.BytesToString(accessTokenBytes)
|
||||
refreshToken := jwt.BytesToString(refreshTokenBytes)
|
||||
|
||||
s.trySetCookie(ctx, accessToken)
|
||||
|
||||
resp := SigninResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
ctx.JSON(resp)
|
||||
}
|
||||
|
||||
func (s *SSO[T]) Signout(ctx stdContext.Context, token []byte, all bool) error {
|
||||
t, standardClaims, err := s.verify(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sso: signout: verify: %w", err)
|
||||
}
|
||||
|
||||
for i, n := 0, len(s.providers)-1; i <= n; i++ {
|
||||
p := s.providers[i]
|
||||
|
||||
if all {
|
||||
err = p.InvalidateTokens(ctx, t)
|
||||
} else {
|
||||
err = p.InvalidateToken(ctx, standardClaims, t)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if i == n { // last provider errored.
|
||||
return err
|
||||
}
|
||||
// keep trying.
|
||||
continue
|
||||
}
|
||||
|
||||
// token is marked as invalidated by a provider.
|
||||
break
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSO[T]) SignoutHandler(ctx *context.Context) {
|
||||
s.signoutHandler(ctx, false)
|
||||
}
|
||||
|
||||
func (s *SSO[T]) SignoutAllHandler(ctx *context.Context) {
|
||||
s.signoutHandler(ctx, true)
|
||||
}
|
||||
|
||||
func (s *SSO[T]) signoutHandler(ctx *context.Context, all bool) {
|
||||
accessToken := s.extractAccessToken(ctx)
|
||||
if accessToken == "" {
|
||||
s.errorHandler.Unauthenticated(ctx, jwt.ErrMissing)
|
||||
return
|
||||
}
|
||||
|
||||
err := s.Signout(ctx, []byte(accessToken), all)
|
||||
if err != nil {
|
||||
s.errorHandler.Unauthenticated(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.tryRemoveCookie(ctx)
|
||||
|
||||
ctx.SetUser(nil)
|
||||
|
||||
ctx.Values().Remove(accessTokenContextKey)
|
||||
ctx.Values().Remove(userContextKey)
|
||||
ctx.Values().Remove(standardClaimsContextKey)
|
||||
}
|
||||
|
||||
var headerKeys = [...]string{
|
||||
"Authorization",
|
||||
"X-Authorization",
|
||||
}
|
||||
|
||||
func extractTokenFromHeader(ctx *context.Context) string {
|
||||
for _, headerKey := range headerKeys {
|
||||
headerValue := ctx.GetHeader(headerKey)
|
||||
if headerValue == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// pure check: authorization header format must be Bearer {token}
|
||||
authHeaderParts := strings.Split(headerValue, " ")
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
||||
continue
|
||||
}
|
||||
|
||||
return authHeaderParts[1]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *SSO[T]) trySetCookie(ctx *context.Context, accessToken string) {
|
||||
if cookieName := s.config.Cookie.Name; cookieName != "" {
|
||||
maxAge := s.keys[KIDAccess].MaxAge
|
||||
if maxAge == 0 {
|
||||
maxAge = context.SetCookieKVExpiration
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Path: "/",
|
||||
Name: cookieName,
|
||||
Value: url.QueryEscape(accessToken),
|
||||
HttpOnly: true,
|
||||
Domain: ctx.Domain(),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Now().Add(maxAge),
|
||||
MaxAge: int(maxAge.Seconds()),
|
||||
}
|
||||
|
||||
ctx.SetCookie(cookie, context.CookieEncoding(s.securecookie))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSO[T]) tryRemoveCookie(ctx *context.Context) {
|
||||
if cookieName := s.config.Cookie.Name; cookieName != "" {
|
||||
ctx.RemoveCookie(cookieName)
|
||||
}
|
||||
}
|
||||
53
sso/user.go
Normal file
53
sso/user.go
Normal file
@@ -0,0 +1,53 @@
|
||||
//go:build go1.18
|
||||
|
||||
package sso
|
||||
|
||||
import (
|
||||
"github.com/kataras/iris/v12/context"
|
||||
|
||||
"github.com/kataras/jwt"
|
||||
)
|
||||
|
||||
type (
|
||||
StandardClaims = jwt.Claims
|
||||
User = interface{} // any type.
|
||||
)
|
||||
|
||||
const accessTokenContextKey = "iris.sso.context.access_token"
|
||||
|
||||
func GetAccessToken(ctx *context.Context) string {
|
||||
return ctx.Values().GetString(accessTokenContextKey)
|
||||
}
|
||||
|
||||
const standardClaimsContextKey = "iris.sso.context.standard_claims"
|
||||
|
||||
func GetStandardClaims(ctx *context.Context) StandardClaims {
|
||||
if v := ctx.Values().Get(standardClaimsContextKey); v != nil {
|
||||
if c, ok := v.(StandardClaims); ok {
|
||||
return c
|
||||
}
|
||||
}
|
||||
|
||||
return StandardClaims{}
|
||||
}
|
||||
|
||||
func (s *SSO[T]) GetStandardClaims(ctx *context.Context) StandardClaims {
|
||||
return GetStandardClaims(ctx)
|
||||
}
|
||||
|
||||
const userContextKey = "iris.sso.context.user"
|
||||
|
||||
func GetUser[T User](ctx *context.Context) T {
|
||||
if v := ctx.Values().Get(userContextKey); v != nil {
|
||||
if t, ok := v.(T); ok {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
var empty T
|
||||
return empty
|
||||
}
|
||||
|
||||
func (s *SSO[T]) GetUser(ctx *context.Context) T {
|
||||
return GetUser[T](ctx)
|
||||
}
|
||||
Reference in New Issue
Block a user