mirror of https://github.com/zhufuyi/sponge
feat: simplify the use of jwt middleware in gin
This commit is contained in:
parent
07f23f7b14
commit
3ff6e6efae
|
@ -1,7 +1,17 @@
|
|||
## middleware
|
||||
|
||||
Common gin middleware libraries.
|
||||
Common gin middleware libraries, including:
|
||||
|
||||
- [Logging](README.md#logging-middleware)
|
||||
- [Cors](README.md#allow-cross-domain-requests-middleware)
|
||||
- [Rate limiter](README.md#rate-limiter-middleware)
|
||||
- [Circuit breaker](README.md#circuit-breaker-middleware)
|
||||
- [JWT authorization](README.md#jwt-authorization-middleware)
|
||||
- [Tracing](README.md#tracing-middleware)
|
||||
- [Metrics](README.md#metrics-middleware)
|
||||
- [Request id](README.md#request-id-middleware)
|
||||
- [Timeout](README.md#timeout-middleware)
|
||||
|
||||
<br>
|
||||
|
||||
## Example of use
|
||||
|
@ -127,113 +137,123 @@ func NewRouter() *gin.Engine {
|
|||
|
||||
### JWT authorization middleware
|
||||
|
||||
```go
|
||||
package main
|
||||
There are two usage examples available:
|
||||
|
||||
import (
|
||||
"time"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-dev-frame/sponge/pkg/gin/middleware"
|
||||
"github.com/go-dev-frame/sponge/pkg/gin/response"
|
||||
"github.com/go-dev-frame/sponge/pkg/jwt"
|
||||
)
|
||||
1. **Example One**: This example adopts a highly abstracted design, making it simpler and more convenient to use. Click to view the example at [pkg/gin/middleware/auth](https://github.com/go-dev-frame/sponge/tree/main/pkg/gin/middleware/auth#example-of-use). Requires sponge version `v1.13.2+`.
|
||||
2. **Example Two**: This example offers greater flexibility and is suitable for scenarios requiring custom implementations. The example code is as follows:
|
||||
|
||||
func main() {
|
||||
r := gin.Default()
|
||||
|
||||
// Case 1: default jwt options, signKey, signMethod(HS256), expiry time(24 hour)
|
||||
{
|
||||
r.POST("/auth/login", LoginDefault)
|
||||
r.GET("/demo1/user/:id", middleware.Auth(), GetByID)
|
||||
r.GET("/demo2/user/:id", middleware.Auth(middleware.WithReturnErrReason()), GetByID)
|
||||
r.GET("/demo3/user/:id", middleware.Auth(middleware.WithExtraVerify(extraVerifyFn)), GetByID)
|
||||
}
|
||||
|
||||
// Case 2: custom jwt options, signKey, signMethod(HS512), expiry time(12 hour), fields, claims
|
||||
{
|
||||
signKey := []byte("custom-sign-key")
|
||||
jwtAuth1 := middleware.Auth(middleware.WithSignKey(signKey))
|
||||
jwtAuth2 := middleware.Auth(middleware.WithSignKey(signKey), middleware.WithReturnErrReason())
|
||||
jwtAuth3 := middleware.Auth(middleware.WithSignKey(signKey), middleware.WithExtraVerify(extraVerifyFn))
|
||||
|
||||
r.POST("/auth/login", LoginCustom)
|
||||
r.GET("/demo4/user/:id", jwtAuth1, GetByID)
|
||||
r.GET("/demo5/user/:id", jwtAuth2, GetByID)
|
||||
r.GET("/demo6/user/:id", jwtAuth3, GetByID)
|
||||
}
|
||||
|
||||
r.Run(":8080")
|
||||
}
|
||||
|
||||
func LoginDefault(c *gin.Context) {
|
||||
// ......
|
||||
|
||||
_, token, err := jwt.GenerateToken("100")
|
||||
|
||||
response.Success(c, token)
|
||||
}
|
||||
|
||||
func LoginCustom(c *gin.Context) {
|
||||
// ......
|
||||
|
||||
uid := "100"
|
||||
fields := map[string]interface{}{
|
||||
"name": "bob",
|
||||
"age": 10,
|
||||
"is_vip": true,
|
||||
}
|
||||
|
||||
_, token, err := jwt.GenerateToken(
|
||||
uid,
|
||||
jwt.WithGenerateTokenSignKey([]byte("custom-sign-key")),
|
||||
jwt.WithGenerateTokenSignMethod(jwt.HS512),
|
||||
jwt.WithGenerateTokenFields(fields),
|
||||
jwt.WithGenerateTokenClaims([]jwt.RegisteredClaimsOption{
|
||||
jwt.WithExpires(time.Hour * 12),
|
||||
//jwt.WithIssuedAt(now),
|
||||
// jwt.WithSubject("123"),
|
||||
// jwt.WithIssuer("https://auth.example.com"),
|
||||
// jwt.WithAudience("https://api.example.com"),
|
||||
// jwt.WithNotBefore(now),
|
||||
// jwt.WithJwtID("abc1234xxx"),
|
||||
}...),
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-dev-frame/sponge/pkg/gin/middleware"
|
||||
"github.com/go-dev-frame/sponge/pkg/gin/response"
|
||||
"github.com/go-dev-frame/sponge/pkg/jwt"
|
||||
)
|
||||
|
||||
response.Success(c, token)
|
||||
}
|
||||
|
||||
func GetByID(c *gin.Context) {
|
||||
uid := c.MustGet("id").(string)
|
||||
|
||||
claims,ok := middleware.GetClaims(c) // if necessary, claims can be got from gin context.
|
||||
|
||||
response.Success(c, gin.H{"id": uid})
|
||||
}
|
||||
|
||||
func extraVerifyFn(claims *jwt.Claims, c *gin.Context) error {
|
||||
// check if token is about to expire (less than 10 minutes remaining)
|
||||
if time.Now().Unix()-claims.ExpiresAt.Unix() < int64(time.Minute*10) {
|
||||
token, err := claims.NewToken(time.Hour*24, jwt.HS256, jwtSignKey) // same signature as jwt.GenerateToken
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
func main() {
|
||||
r := gin.Default()
|
||||
|
||||
g := r.Group("/api/v1")
|
||||
|
||||
// Case 1: default jwt options, signKey, signMethod(HS256), expiry time(24 hour)
|
||||
{
|
||||
r.POST("/auth/login", LoginDefault)
|
||||
g.Use(middleware.Auth())
|
||||
//g.Use(middleware.Auth(middleware.WithExtraVerify(extraVerifyFn))) // add extra verify function
|
||||
}
|
||||
c.Header("X-Renewed-Token", token)
|
||||
|
||||
// Case 2: custom jwt options, signKey, signMethod(HS512), expiry time(48 hour), fields, claims
|
||||
{
|
||||
r.POST("/auth/login", LoginCustom)
|
||||
signKey := []byte("your-sign-key")
|
||||
g.Use(middleware.Auth(middleware.WithSignKey(signKey)))
|
||||
//g.Use(middleware.Auth(middleware.WithSignKey(signKey), middleware.WithExtraVerify(extraVerifyFn))) // add extra verify function
|
||||
}
|
||||
|
||||
g.GET("/user/:id", GetByID)
|
||||
//g.PUT("/user/:id", Create)
|
||||
//g.DELETE("/user/:id", DeleteByID)
|
||||
|
||||
r.Run(":8080")
|
||||
}
|
||||
|
||||
// judge whether the user is disabled, query whether jwt id exists from the blacklist
|
||||
//if CheckBlackList(uid, claims.ID) {
|
||||
// return errors.New("user is disabled")
|
||||
//}
|
||||
|
||||
// get fields from claims
|
||||
//uid := claims.UID
|
||||
//name, _ := claims.GetString("name")
|
||||
//age, _ := claims.GetInt("age")
|
||||
//isVip, _ := claims.GetBool("is_vip")
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
func customGenerateToken(uid string, fields map[string]interface{}) (string, error) {
|
||||
_, token, err := jwt.GenerateToken(
|
||||
uid,
|
||||
jwt.WithGenerateTokenSignKey([]byte("custom-sign-key")),
|
||||
jwt.WithGenerateTokenSignMethod(jwt.HS512),
|
||||
jwt.WithGenerateTokenFields(fields),
|
||||
jwt.WithGenerateTokenClaims([]jwt.RegisteredClaimsOption{
|
||||
jwt.WithExpires(time.Hour * 48),
|
||||
//jwt.WithIssuedAt(now),
|
||||
// jwt.WithSubject("123"),
|
||||
// jwt.WithIssuer("https://middleware.example.com"),
|
||||
// jwt.WithAudience("https://api.example.com"),
|
||||
// jwt.WithNotBefore(now),
|
||||
// jwt.WithJwtID("abc1234xxx"),
|
||||
}...),
|
||||
)
|
||||
|
||||
return token, err
|
||||
}
|
||||
|
||||
func LoginDefault(c *gin.Context) {
|
||||
// ......
|
||||
|
||||
_, token, err := jwt.GenerateToken("100")
|
||||
|
||||
response.Success(c, token)
|
||||
}
|
||||
|
||||
func LoginCustom(c *gin.Context) {
|
||||
// ......
|
||||
|
||||
uid := "100"
|
||||
fields := map[string]interface{}{
|
||||
"name": "bob",
|
||||
"age": 10,
|
||||
"is_vip": true,
|
||||
}
|
||||
|
||||
token, err := customGenerateToken(uid, fields)
|
||||
|
||||
response.Success(c, token)
|
||||
}
|
||||
|
||||
func GetByID(c *gin.Context) {
|
||||
uid := c.Param("id")
|
||||
|
||||
// if necessary, claims can be got from gin context.
|
||||
claims, ok := middleware.GetClaims(c)
|
||||
//uid := claims.UID
|
||||
//name, _ := claims.GetString("name")
|
||||
//age, _ := claims.GetInt("age")
|
||||
//isVip, _ := claims.GetBool("is_vip")
|
||||
|
||||
response.Success(c, gin.H{"id": uid})
|
||||
}
|
||||
|
||||
func extraVerifyFn(claims *jwt.Claims, c *gin.Context) error {
|
||||
// check if token is about to expire (less than 10 minutes remaining)
|
||||
if time.Now().Unix()-claims.ExpiresAt.Unix() < int64(time.Minute*10) {
|
||||
token, err := claims.NewToken(time.Hour*24, jwt.HS512, []byte("your-sign-key")) // same signature as jwt.GenerateToken
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Header("X-Renewed-Token", token)
|
||||
}
|
||||
|
||||
// judge whether the user is disabled, query whether jwt id exists from the blacklist
|
||||
//if CheckBlackList(uid, claims.ID) {
|
||||
// return errors.New("user is disabled")
|
||||
//}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -311,7 +331,7 @@ func NewRouter() *gin.Engine {
|
|||
|
||||
<br>
|
||||
|
||||
### Request id
|
||||
### Request id middleware
|
||||
|
||||
```go
|
||||
import (
|
||||
|
@ -345,7 +365,7 @@ func NewRouter() *gin.Engine {
|
|||
|
||||
<br>
|
||||
|
||||
### Timeout
|
||||
### Timeout middleware
|
||||
|
||||
```go
|
||||
import (
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
## auth
|
||||
|
||||
`auth` middleware for gin framework.
|
||||
|
||||
### Example of use
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-dev-frame/sponge/pkg/gin/middleware/auth"
|
||||
"github.com/go-dev-frame/sponge/pkg/gin/response"
|
||||
)
|
||||
|
||||
func main() {
|
||||
r := gin.Default()
|
||||
|
||||
// initialize jwt first
|
||||
auth.InitAuth([]byte("your-sign-key"), time.Hour*24) // default signing method is HS256
|
||||
// auth.InitAuth([]byte("your-sign-key"), time.Minute*24, WithInitAuthSigningMethod(HS512), WithInitAuthIssuer("foobar.com"))
|
||||
|
||||
r.POST("/auth/login", Login)
|
||||
|
||||
g := r.Group("/api/v1")
|
||||
g.Use(auth.Auth())
|
||||
//g.Use(auth.Auth(auth.WithExtraVerify(extraVerifyFn))) // add extra verify function
|
||||
|
||||
g.GET("/user/:id", GetByID)
|
||||
//g.PUT("/user/:id", Create)
|
||||
//g.DELETE("/user/:id", DeleteByID)
|
||||
|
||||
r.Run(":8080")
|
||||
}
|
||||
|
||||
func Login(c *gin.Context) {
|
||||
// ......
|
||||
|
||||
// Case 1: only uid for token
|
||||
{
|
||||
token, err := auth.GenerateToken("100")
|
||||
}
|
||||
|
||||
// Case 2: uid and custom fields for token
|
||||
{
|
||||
uid := "100"
|
||||
fields := map[string]interface{}{
|
||||
"name": "bob",
|
||||
"age": 10,
|
||||
"is_vip": true,
|
||||
}
|
||||
token, err := auth.GenerateToken(uid, auth.WithGenerateTokenFields(fields))
|
||||
}
|
||||
|
||||
response.Success(c, token)
|
||||
}
|
||||
|
||||
func GetByID(c *gin.Context) {
|
||||
uid := c.Param("id")
|
||||
|
||||
// if necessary, claims can be got from gin context
|
||||
claims, ok := auth.GetClaims(c)
|
||||
//uid := claims.UID
|
||||
//name, _ := claims.GetString("name")
|
||||
//age, _ := claims.GetInt("age")
|
||||
//isVip, _ := claims.GetBool("is_vip")
|
||||
|
||||
response.Success(c, gin.H{"id": uid})
|
||||
}
|
||||
|
||||
func extraVerifyFn(claims *auth.Claims, c *gin.Context) error {
|
||||
// check if token is about to expire (less than 10 minutes remaining)
|
||||
if time.Now().Unix()-claims.ExpiresAt.Unix() < int64(time.Minute*10) {
|
||||
token, err := auth.RefreshToken(claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Header("X-Renewed-Token", token)
|
||||
}
|
||||
|
||||
// judge whether the user is disabled, query whether jwt id exists from the blacklist
|
||||
//if CheckBlackList(uid, claims.ID) {
|
||||
// return errors.New("user is disabled")
|
||||
//}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
<br>
|
|
@ -0,0 +1,229 @@
|
|||
// Package auth provides JWT authentication middleware for gin.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/go-dev-frame/sponge/pkg/errcode"
|
||||
"github.com/go-dev-frame/sponge/pkg/gin/response"
|
||||
"github.com/go-dev-frame/sponge/pkg/jwt"
|
||||
)
|
||||
|
||||
type SigningMethodHMAC = jwt.SigningMethodHMAC
|
||||
type Claims = jwt.Claims
|
||||
|
||||
var (
|
||||
HS256 = jwt.HS256
|
||||
HS384 = jwt.HS384
|
||||
HS512 = jwt.HS512
|
||||
)
|
||||
|
||||
var (
|
||||
customSigningKey []byte
|
||||
customSigningMethod *jwt.SigningMethodHMAC
|
||||
customExpire time.Duration
|
||||
customIssuer string
|
||||
|
||||
errOption = errors.New("jwt option is nil, please initialize first, call middleware.InitAuth()")
|
||||
)
|
||||
|
||||
type initAuthOptions struct {
|
||||
issuer string
|
||||
signingMethod *SigningMethodHMAC
|
||||
}
|
||||
|
||||
func defaultInirAuthOptions() *initAuthOptions {
|
||||
return &initAuthOptions{
|
||||
signingMethod: HS256,
|
||||
}
|
||||
}
|
||||
|
||||
// InitAuthOption set the jwt initAuthOptions.
|
||||
type InitAuthOption func(*initAuthOptions)
|
||||
|
||||
func (o *initAuthOptions) apply(opts ...InitAuthOption) {
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
}
|
||||
|
||||
// WithInitAuthSigningMethod set signing method value
|
||||
func WithInitAuthSigningMethod(sm *jwt.SigningMethodHMAC) InitAuthOption {
|
||||
return func(o *initAuthOptions) {
|
||||
o.signingMethod = sm
|
||||
}
|
||||
}
|
||||
|
||||
// WithInitAuthIssuer set issuer value
|
||||
func WithInitAuthIssuer(issuer string) InitAuthOption {
|
||||
return func(o *initAuthOptions) {
|
||||
o.issuer = issuer
|
||||
}
|
||||
}
|
||||
|
||||
// InitAuth initializes jwt options.
|
||||
func InitAuth(signingKey []byte, expire time.Duration, opts ...InitAuthOption) {
|
||||
o := defaultInirAuthOptions()
|
||||
o.apply(opts...)
|
||||
|
||||
customSigningKey = signingKey
|
||||
customExpire = expire
|
||||
customSigningMethod = o.signingMethod
|
||||
customIssuer = o.issuer
|
||||
}
|
||||
|
||||
// GenerateTokenOption set the jwt options.
|
||||
type GenerateTokenOption func(*generateTokenOptions)
|
||||
|
||||
type generateTokenOptions struct {
|
||||
fields map[string]interface{}
|
||||
}
|
||||
|
||||
func (o *generateTokenOptions) apply(opts ...GenerateTokenOption) {
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
}
|
||||
|
||||
// WithGenerateTokenFields set custom fields value
|
||||
func WithGenerateTokenFields(fields map[string]interface{}) GenerateTokenOption {
|
||||
return func(o *generateTokenOptions) {
|
||||
o.fields = fields
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken generates a jwt token with the given uid and options.
|
||||
func GenerateToken(uid string, opts ...GenerateTokenOption) (string, error) {
|
||||
if customSigningMethod == nil || len(customSigningKey) == 0 {
|
||||
panic(errOption)
|
||||
}
|
||||
|
||||
genOpts := []jwt.GenerateTokenOption{
|
||||
jwt.WithGenerateTokenSignKey(customSigningKey),
|
||||
jwt.WithGenerateTokenSignMethod(customSigningMethod),
|
||||
}
|
||||
o := &generateTokenOptions{}
|
||||
o.apply(opts...)
|
||||
if len(o.fields) > 0 {
|
||||
genOpts = append(genOpts, jwt.WithGenerateTokenFields(o.fields))
|
||||
}
|
||||
|
||||
claimsOpts := []jwt.RegisteredClaimsOption{
|
||||
jwt.WithExpires(customExpire),
|
||||
}
|
||||
if customIssuer != "" {
|
||||
claimsOpts = append(claimsOpts, jwt.WithIssuer(customIssuer))
|
||||
}
|
||||
genOpts = append(genOpts, jwt.WithGenerateTokenClaims(claimsOpts...))
|
||||
|
||||
_, token, err := jwt.GenerateToken(uid, genOpts...)
|
||||
return token, err
|
||||
}
|
||||
|
||||
// ParseToken parses the given token and returns the claims.
|
||||
func ParseToken(token string) (*jwt.Claims, error) {
|
||||
if customSigningMethod == nil {
|
||||
panic(errOption)
|
||||
}
|
||||
|
||||
return jwt.ValidateToken(token, jwt.WithValidateTokenSignKey(customSigningKey))
|
||||
}
|
||||
|
||||
// RefreshToken create a new token with the given claims.
|
||||
func RefreshToken(claims *jwt.Claims) (string, error) {
|
||||
return claims.NewToken(customExpire, customSigningMethod, customSigningKey)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------
|
||||
|
||||
// HeaderAuthorizationKey http header authorization key, value is "Bearer token"
|
||||
const HeaderAuthorizationKey = "Authorization"
|
||||
|
||||
// ExtraVerifyFn extra verify function
|
||||
type ExtraVerifyFn = func(claims *jwt.Claims, c *gin.Context) error
|
||||
|
||||
// AuthOption set the auth options.
|
||||
type AuthOption func(*authOptions)
|
||||
|
||||
type authOptions struct {
|
||||
isReturnErrReason bool
|
||||
extraVerifyFn ExtraVerifyFn
|
||||
}
|
||||
|
||||
func defaultAuthOptions() *authOptions {
|
||||
return &authOptions{}
|
||||
}
|
||||
|
||||
func (o *authOptions) apply(opts ...AuthOption) {
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
}
|
||||
|
||||
// WithReturnErrReason set return error reason
|
||||
func WithReturnErrReason() AuthOption {
|
||||
return func(o *authOptions) {
|
||||
o.isReturnErrReason = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithExtraVerify set extra verify function
|
||||
func WithExtraVerify(fn ExtraVerifyFn) AuthOption {
|
||||
return func(o *authOptions) {
|
||||
o.extraVerifyFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
func responseUnauthorized(isReturnErrReason bool, errMsg string) *errcode.Error {
|
||||
if isReturnErrReason {
|
||||
return errcode.Unauthorized.RewriteMsg("Unauthorized, " + errMsg)
|
||||
}
|
||||
return errcode.Unauthorized
|
||||
}
|
||||
|
||||
// Auth authorization middleware, support custom extra verify.
|
||||
func Auth(opts ...AuthOption) gin.HandlerFunc {
|
||||
o := defaultAuthOptions()
|
||||
o.apply(opts...)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
authorization := c.GetHeader(HeaderAuthorizationKey)
|
||||
if len(authorization) < 100 {
|
||||
response.Out(c, responseUnauthorized(o.isReturnErrReason, "token is illegal"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := authorization[7:] // remove Bearer prefix
|
||||
|
||||
claims, err := ParseToken(tokenString)
|
||||
if err != nil {
|
||||
response.Out(c, responseUnauthorized(o.isReturnErrReason, err.Error()))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// extra verify function
|
||||
if o.extraVerifyFn != nil {
|
||||
if err = o.extraVerifyFn(claims, c); err != nil {
|
||||
response.Out(c, responseUnauthorized(o.isReturnErrReason, err.Error()))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("claims", claims) // set claims to context
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetClaims get jwt claims from gin context.
|
||||
func GetClaims(c *gin.Context) (*jwt.Claims, bool) {
|
||||
claims, exists := c.Get("claims")
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
jwtClaims, ok := claims.(*jwt.Claims)
|
||||
return jwtClaims, ok
|
||||
}
|
|
@ -0,0 +1,207 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/go-dev-frame/sponge/pkg/gin/response"
|
||||
"github.com/go-dev-frame/sponge/pkg/httpcli"
|
||||
"github.com/go-dev-frame/sponge/pkg/jwt"
|
||||
"github.com/go-dev-frame/sponge/pkg/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
uid = "100"
|
||||
fields = map[string]interface{}{
|
||||
"name": "bob",
|
||||
"age": 10,
|
||||
"is_vip": true,
|
||||
}
|
||||
jwtSignKey = []byte("your-secret-key")
|
||||
|
||||
errMsg = http.StatusText(http.StatusUnauthorized)
|
||||
compareMsgFn = func(em string) bool {
|
||||
return strings.Contains(em, errMsg)
|
||||
}
|
||||
)
|
||||
|
||||
func extraVerifyFn(claims *jwt.Claims, c *gin.Context) error {
|
||||
// check if token is about to expire (less than 10 minutes remaining)
|
||||
if time.Now().Unix()-claims.ExpiresAt.Unix() < int64(time.Minute*10) {
|
||||
token, err := RefreshToken(claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Header("X-Renewed-Token", token)
|
||||
}
|
||||
|
||||
// judge whether the user is disabled, query whether jwt id exists from the blacklist
|
||||
//if CheckBlackList(uid, claims.ID) {
|
||||
// return errors.New("user is disabled")
|
||||
//}
|
||||
|
||||
// check fields
|
||||
if claims.UID != uid {
|
||||
return fmt.Errorf("uid not match, expect %s, got %s", uid, claims.UID)
|
||||
}
|
||||
if name, _ := claims.GetString("name"); name != fields["name"] {
|
||||
return fmt.Errorf("name not match, expect %s, got %s", fields["name"], name)
|
||||
}
|
||||
if age, _ := claims.GetInt("age"); age != fields["age"] {
|
||||
return fmt.Errorf("age not match, expect %d, got %d", fields["age"], age)
|
||||
}
|
||||
if isVip, _ := claims.GetBool("is_vip"); isVip != fields["is_vip"] {
|
||||
return fmt.Errorf("is_vip not match, expect %v, got %v", fields["is_vip"], isVip)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runAuthHTTPServer() string {
|
||||
serverAddr, requestAddr := utils.GetLocalHTTPAddrPairs()
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.Default()
|
||||
|
||||
loginHandler := func(c *gin.Context) {
|
||||
token, _ := GenerateToken(uid)
|
||||
fmt.Println("token1 =", token)
|
||||
response.Success(c, token)
|
||||
}
|
||||
|
||||
loginCustomFieldsHandler := func(c *gin.Context) {
|
||||
token, _ := GenerateToken(uid, WithGenerateTokenFields(fields))
|
||||
fmt.Println("token2 =", token)
|
||||
response.Success(c, token)
|
||||
}
|
||||
|
||||
getUserByIDHandler := func(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
claims, ok := GetClaims(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"msg": "unauthorized"})
|
||||
return
|
||||
}
|
||||
fmt.Println("claims =", claims)
|
||||
response.Success(c, id)
|
||||
}
|
||||
|
||||
r.GET("/auth/login", loginHandler)
|
||||
r.GET("/auth/loginCustomFields", loginCustomFieldsHandler)
|
||||
r.GET("/user/:id", Auth(), getUserByIDHandler)
|
||||
r.GET("/user/log/:id", Auth(WithReturnErrReason()), getUserByIDHandler)
|
||||
r.GET("/user/extra_verify/:id", Auth(WithExtraVerify(extraVerifyFn), WithReturnErrReason()), getUserByIDHandler)
|
||||
|
||||
go func() {
|
||||
err := r.Run(serverAddr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
return requestAddr
|
||||
}
|
||||
|
||||
func getUser(url string, authorization string) (gin.H, error) {
|
||||
var result = gin.H{}
|
||||
|
||||
client := &http.Client{}
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
request.Header.Add("Authorization", authorization)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
resp, _ := client.Do(request)
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, &result)
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
requestAddr := runAuthHTTPServer()
|
||||
InitAuth(jwtSignKey, time.Minute*10)
|
||||
//InitAuth(jwtSignKey, time.Minute*10, WithInitAuthSigningMethod(HS512), WithInitAuthIssuer("foobar.com"))
|
||||
|
||||
t.Run("only uid for generate token", func(t *testing.T) {
|
||||
// get token
|
||||
result := &httpcli.StdResult{}
|
||||
err := httpcli.Get(result, requestAddr+"/auth/login")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := result.Data.(string)
|
||||
authorization := fmt.Sprintf("Bearer %s", token)
|
||||
|
||||
// success
|
||||
val, err := getUser(requestAddr+"/user/"+uid, authorization)
|
||||
assert.Equal(t, val["data"], uid)
|
||||
|
||||
// success
|
||||
val, err = getUser(requestAddr+"/user/log/"+uid, authorization)
|
||||
assert.Equal(t, val["data"], uid)
|
||||
|
||||
// return 401, the reason is token have no extra field
|
||||
val, err = getUser(requestAddr+"/user/extra_verify/"+uid, authorization)
|
||||
assert.Equal(t, true, compareMsgFn(val["msg"].(string)))
|
||||
|
||||
// return 401, the reason is token value is invalid
|
||||
val, err = getUser(requestAddr+"/user/"+uid, "error-authorization")
|
||||
assert.Equal(t, val["msg"], errMsg)
|
||||
})
|
||||
|
||||
t.Run("uid and fields for generate token", func(t *testing.T) {
|
||||
// get token
|
||||
result := &httpcli.StdResult{}
|
||||
err := httpcli.Get(result, requestAddr+"/auth/loginCustomFields")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := result.Data.(string)
|
||||
authorization := fmt.Sprintf("Bearer %s", token)
|
||||
|
||||
// success
|
||||
val, err := getUser(requestAddr+"/user/"+uid, authorization)
|
||||
assert.Equal(t, val["data"], uid)
|
||||
|
||||
// success
|
||||
val, err = getUser(requestAddr+"/user/log/"+uid, authorization)
|
||||
assert.Equal(t, val["data"], uid)
|
||||
|
||||
// return 401, the reason is token expired
|
||||
token = "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.eyJ1aWQiOiIxMDAiLCJmaWVsZHMiOnsiYWdlIjoxMCwiaXNfdmlwIjp0cnVlLCJuYW1lIjoiYm9iIn0sImV4cCI6MTc0NjY0MTY0MCwiaWF0IjoxNzQ2NjQxMDQwLCJqdGkiOiIxODNkNTBjNWIxZTdmMTEwIn0.P11q5VPo-88Sbw4JKLtp2_Aiz8Pc1oL-jrdEAX0NwJJoxnR_Iu8W6eI7CsUCzVGW"
|
||||
authorization = fmt.Sprintf("Bearer %s", token)
|
||||
val, err = getUser(requestAddr+"/user/extra_verify/"+uid, authorization)
|
||||
assert.Equal(t, true, compareMsgFn(val["msg"].(string)))
|
||||
|
||||
// return 401, the reason is token value is invalid
|
||||
val, err = getUser(requestAddr+"/user/"+uid, "error-authorization")
|
||||
assert.Equal(t, val["msg"], errMsg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
t.Run("GenerateToken error", func(t *testing.T) {
|
||||
defer func() { recover() }()
|
||||
GenerateToken("100")
|
||||
})
|
||||
t.Run("ParseToken error", func(t *testing.T) {
|
||||
defer func() { recover() }()
|
||||
ParseToken("xxx")
|
||||
})
|
||||
}
|
|
@ -8,16 +8,16 @@ import (
|
|||
valid "github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// Init request body file valid
|
||||
// Init validator instance, used to gin request parameter check
|
||||
func Init() *CustomValidator {
|
||||
validator := NewCustomValidator()
|
||||
validator.Engine()
|
||||
return validator
|
||||
v := NewCustomValidator()
|
||||
v.Engine()
|
||||
return v
|
||||
}
|
||||
|
||||
// CustomValidator Custom valid objects
|
||||
type CustomValidator struct {
|
||||
Once sync.Once
|
||||
once sync.Once
|
||||
Validate *valid.Validate
|
||||
}
|
||||
|
||||
|
@ -26,37 +26,52 @@ func NewCustomValidator() *CustomValidator {
|
|||
return &CustomValidator{}
|
||||
}
|
||||
|
||||
// ValidateStruct Instantiate struct valid
|
||||
// ValidateStruct validates a struct or slice/array
|
||||
func (v *CustomValidator) ValidateStruct(obj interface{}) error {
|
||||
if kindOfData(obj) == reflect.Struct {
|
||||
v.lazyinit()
|
||||
if obj == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(obj)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
switch val.Kind() {
|
||||
case reflect.Struct:
|
||||
if err := v.Validate.Struct(obj); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// pointer type: if nil, no validation required; otherwise recursive validation after dereference
|
||||
if val.IsNil() {
|
||||
return nil
|
||||
}
|
||||
return v.ValidateStruct(val.Elem().Interface())
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
// slice or array type: iterates over each element, recursively validating one by one
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
elem := val.Index(i)
|
||||
if err := v.ValidateStruct(elem.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Engine Instantiate valid
|
||||
// Engine set tag name "binding", which is implementing the validator interface of the gin framework
|
||||
func (v *CustomValidator) Engine() interface{} {
|
||||
v.lazyinit()
|
||||
v.lazyInit()
|
||||
return v.Validate
|
||||
}
|
||||
|
||||
func (v *CustomValidator) lazyinit() {
|
||||
v.Once.Do(func() {
|
||||
func (v *CustomValidator) lazyInit() {
|
||||
v.once.Do(func() {
|
||||
v.Validate = valid.New()
|
||||
v.Validate.SetTagName("binding")
|
||||
})
|
||||
}
|
||||
|
||||
func kindOfData(data interface{}) reflect.Kind {
|
||||
value := reflect.ValueOf(data)
|
||||
valueType := value.Kind()
|
||||
|
||||
if valueType == reflect.Ptr {
|
||||
valueType = value.Elem().Kind()
|
||||
}
|
||||
return valueType
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -420,39 +419,122 @@ func do(method string, url string, body interface{}) ([]byte, error) {
|
|||
|
||||
// ------------------------------------------------------------------------------------------
|
||||
|
||||
type st struct {
|
||||
Name string
|
||||
}
|
||||
func Test_CustomValidator_ValidateStruct(t *testing.T) {
|
||||
type User struct {
|
||||
Name string `binding:"required"`
|
||||
Age int `binding:"gte=18"`
|
||||
}
|
||||
|
||||
func TestCustomValidator_Engine(t *testing.T) {
|
||||
validator := NewCustomValidator()
|
||||
v := validator.Engine()
|
||||
assert.NotNil(t, v)
|
||||
}
|
||||
type UserList1 struct {
|
||||
Users []User `binding:"required,dive"`
|
||||
}
|
||||
|
||||
func TestCustomValidator_ValidateStruct(t *testing.T) {
|
||||
validator := NewCustomValidator()
|
||||
err := validator.ValidateStruct(new(st))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
type UserList2 struct {
|
||||
Users []*User `binding:"required,dive"`
|
||||
}
|
||||
|
||||
func TestCustomValidator_lazyinit(t *testing.T) {
|
||||
validator := NewCustomValidator()
|
||||
validator.lazyinit()
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
validator := Init()
|
||||
assert.NotNil(t, validator)
|
||||
|
||||
user := &User{Name: "John", Age: 10}
|
||||
if err := validator.ValidateStruct(user); err != nil {
|
||||
assert.NotNil(t, err)
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
var u = &User{Name: "John", Age: 11}
|
||||
if err := validator.ValidateStruct(&u); err != nil {
|
||||
assert.NotNil(t, err)
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
users := []User{{Name: "Alice", Age: 25}, {Name: "Bob", Age: 17}}
|
||||
if err := validator.ValidateStruct(users); err != nil {
|
||||
assert.NotNil(t, err)
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
userList := UserList1{}
|
||||
if err := validator.ValidateStruct(&userList); err != nil {
|
||||
assert.NotNil(t, err)
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
userList1 := UserList1{
|
||||
Users: []User{{Name: "Charlie", Age: 10}, {Name: "", Age: 30}},
|
||||
}
|
||||
if err := validator.ValidateStruct(&userList1); err != nil {
|
||||
assert.NotNil(t, err)
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
userList2 := UserList2{
|
||||
Users: []*User{{Name: "Charlie", Age: 30}, {Name: "", Age: 40}},
|
||||
}
|
||||
if err := validator.ValidateStruct(&userList2); err != nil {
|
||||
assert.NotNil(t, err)
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCustomValidator(t *testing.T) {
|
||||
validator := NewCustomValidator()
|
||||
assert.NotNil(t, validator)
|
||||
}
|
||||
func Benchmark_CustomValidator_ValidateStruct(b *testing.B) {
|
||||
type User struct {
|
||||
Name string `binding:"required"`
|
||||
Age int `binding:"gte=18"`
|
||||
}
|
||||
|
||||
func Test_kindOfData(t *testing.T) {
|
||||
type UserList1 struct {
|
||||
Users []User `binding:"required,dive"` // 验证指针切片
|
||||
}
|
||||
|
||||
kind := kindOfData(new(st))
|
||||
assert.Equal(t, reflect.Struct, kind)
|
||||
type UserList2 struct {
|
||||
Users []*User `binding:"required,dive"` // 验证指针切片
|
||||
}
|
||||
|
||||
validator := Init()
|
||||
|
||||
b.Run("User struct", func(b *testing.B) {
|
||||
user := User{Name: "John", Age: 10}
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateStruct(user)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("User struct pointer", func(b *testing.B) {
|
||||
user := &User{Name: "John", Age: 10}
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateStruct(user)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("User struct pointer pointer", func(b *testing.B) {
|
||||
var u = &User{Name: "John", Age: 11}
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateStruct(&u)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("User slice", func(b *testing.B) {
|
||||
users := []User{{Name: "Alice", Age: 25}, {Name: "Bob", Age: 17}}
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateStruct(users)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("UserList slice struct", func(b *testing.B) {
|
||||
userList1 := UserList1{
|
||||
Users: []User{{Name: "Charlie", Age: 10}, {Name: "", Age: 30}},
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateStruct(&userList1)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("UserList slice struct pointer", func(b *testing.B) {
|
||||
userList2 := UserList2{
|
||||
Users: []*User{{Name: "Charlie", Age: 30}, {Name: "", Age: 40}},
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateStruct(&userList2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,16 +2,16 @@
|
|||
|
||||
Common interceptors for gRPC server and client side, including:
|
||||
|
||||
- Logging
|
||||
- Recovery
|
||||
- Retry
|
||||
- Rate limiter
|
||||
- Circuit breaker
|
||||
- Timeout
|
||||
- Tracing
|
||||
- Request id
|
||||
- Metrics
|
||||
- JWT authentication
|
||||
- [Logging](README.md#logging-interceptor)
|
||||
- [Recovery](README.md#recovery-interceptor)
|
||||
- [Retry](README.md#retry-interceptor)
|
||||
- [Rate limiter](README.md#rate-limiter-interceptor)
|
||||
- [Circuit breaker](README.md#circuit-breaker-interceptor)
|
||||
- [Timeout](README.md#timeout-interceptor)
|
||||
- [Tracing](README.md#tracing-interceptor)
|
||||
- [Request id](README.md#request-id-interceptor)
|
||||
- [Metrics](README.md#metrics-interceptor)
|
||||
- [JWT authentication](README.md#jwt-authentication-interceptor)
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -492,6 +492,15 @@ func extraVerifyFn(ctx context.Context, claims *jwt.Claims) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID ...
|
||||
func (s *user) GetByID(ctx context.Context, req *userV1.GetByIDRequest) (*userV1.GetByIDReply, error) {
|
||||
// ......
|
||||
|
||||
claims,ok := interceptor.GetJwtClaims(ctx) // if necessary, claims can be got from gin context.
|
||||
|
||||
// ......
|
||||
}
|
||||
```
|
||||
|
||||
**gRPC client side**
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"github.com/go-dev-frame/sponge/pkg/krand"
|
||||
)
|
||||
|
||||
type SigningMethodHMAC = jwt.SigningMethodHMAC
|
||||
|
||||
var (
|
||||
HS256 = jwt.SigningMethodHS256
|
||||
HS384 = jwt.SigningMethodHS384
|
||||
|
@ -24,8 +26,8 @@ var (
|
|||
var (
|
||||
ErrTokenExpired = jwt.ErrTokenExpired
|
||||
//errInvalid = errors.New("token is invalid")
|
||||
errClaims = errors.New("claims is not match")
|
||||
errNotMatch = errors.New(" access token and refresh token is not match")
|
||||
errClaims = errors.New("claims is not match")
|
||||
errNotMatch = errors.New(" access token and refresh token is not match")
|
||||
)
|
||||
|
||||
// ------------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in New Issue