feat: simplify the use of jwt middleware in gin

This commit is contained in:
zhuyasen 2025-05-11 22:44:34 +08:00
parent 07f23f7b14
commit 3ff6e6efae
8 changed files with 822 additions and 167 deletions

View File

@ -1,7 +1,17 @@
## middleware
Common gin middleware libraries.
Common gin middleware libraries, including:
- [Logging](README.md#logging-middleware)
- [Cors](README.md#allow-cross-domain-requests-middleware)
- [Rate limiter](README.md#rate-limiter-middleware)
- [Circuit breaker](README.md#circuit-breaker-middleware)
- [JWT authorization](README.md#jwt-authorization-middleware)
- [Tracing](README.md#tracing-middleware)
- [Metrics](README.md#metrics-middleware)
- [Request id](README.md#request-id-middleware)
- [Timeout](README.md#timeout-middleware)
<br>
## Example of use
@ -127,113 +137,123 @@ func NewRouter() *gin.Engine {
### JWT authorization middleware
```go
package main
There are two usage examples available:
import (
"time"
"github.com/gin-gonic/gin"
"github.com/go-dev-frame/sponge/pkg/gin/middleware"
"github.com/go-dev-frame/sponge/pkg/gin/response"
"github.com/go-dev-frame/sponge/pkg/jwt"
)
1. **Example One**: This example adopts a highly abstracted design, making it simpler and more convenient to use. Click to view the example at [pkg/gin/middleware/auth](https://github.com/go-dev-frame/sponge/tree/main/pkg/gin/middleware/auth#example-of-use). Requires sponge version `v1.13.2+`.
2. **Example Two**: This example offers greater flexibility and is suitable for scenarios requiring custom implementations. The example code is as follows:
func main() {
r := gin.Default()
// Case 1: default jwt options, signKey, signMethod(HS256), expiry time(24 hour)
{
r.POST("/auth/login", LoginDefault)
r.GET("/demo1/user/:id", middleware.Auth(), GetByID)
r.GET("/demo2/user/:id", middleware.Auth(middleware.WithReturnErrReason()), GetByID)
r.GET("/demo3/user/:id", middleware.Auth(middleware.WithExtraVerify(extraVerifyFn)), GetByID)
}
// Case 2: custom jwt options, signKey, signMethod(HS512), expiry time(12 hour), fields, claims
{
signKey := []byte("custom-sign-key")
jwtAuth1 := middleware.Auth(middleware.WithSignKey(signKey))
jwtAuth2 := middleware.Auth(middleware.WithSignKey(signKey), middleware.WithReturnErrReason())
jwtAuth3 := middleware.Auth(middleware.WithSignKey(signKey), middleware.WithExtraVerify(extraVerifyFn))
r.POST("/auth/login", LoginCustom)
r.GET("/demo4/user/:id", jwtAuth1, GetByID)
r.GET("/demo5/user/:id", jwtAuth2, GetByID)
r.GET("/demo6/user/:id", jwtAuth3, GetByID)
}
r.Run(":8080")
}
func LoginDefault(c *gin.Context) {
// ......
_, token, err := jwt.GenerateToken("100")
response.Success(c, token)
}
func LoginCustom(c *gin.Context) {
// ......
uid := "100"
fields := map[string]interface{}{
"name": "bob",
"age": 10,
"is_vip": true,
}
_, token, err := jwt.GenerateToken(
uid,
jwt.WithGenerateTokenSignKey([]byte("custom-sign-key")),
jwt.WithGenerateTokenSignMethod(jwt.HS512),
jwt.WithGenerateTokenFields(fields),
jwt.WithGenerateTokenClaims([]jwt.RegisteredClaimsOption{
jwt.WithExpires(time.Hour * 12),
//jwt.WithIssuedAt(now),
// jwt.WithSubject("123"),
// jwt.WithIssuer("https://auth.example.com"),
// jwt.WithAudience("https://api.example.com"),
// jwt.WithNotBefore(now),
// jwt.WithJwtID("abc1234xxx"),
}...),
```go
package main
import (
"time"
"github.com/gin-gonic/gin"
"github.com/go-dev-frame/sponge/pkg/gin/middleware"
"github.com/go-dev-frame/sponge/pkg/gin/response"
"github.com/go-dev-frame/sponge/pkg/jwt"
)
response.Success(c, token)
}
func GetByID(c *gin.Context) {
uid := c.MustGet("id").(string)
claims,ok := middleware.GetClaims(c) // if necessary, claims can be got from gin context.
response.Success(c, gin.H{"id": uid})
}
func extraVerifyFn(claims *jwt.Claims, c *gin.Context) error {
// check if token is about to expire (less than 10 minutes remaining)
if time.Now().Unix()-claims.ExpiresAt.Unix() < int64(time.Minute*10) {
token, err := claims.NewToken(time.Hour*24, jwt.HS256, jwtSignKey) // same signature as jwt.GenerateToken
if err != nil {
return err
func main() {
r := gin.Default()
g := r.Group("/api/v1")
// Case 1: default jwt options, signKey, signMethod(HS256), expiry time(24 hour)
{
r.POST("/auth/login", LoginDefault)
g.Use(middleware.Auth())
//g.Use(middleware.Auth(middleware.WithExtraVerify(extraVerifyFn))) // add extra verify function
}
c.Header("X-Renewed-Token", token)
// Case 2: custom jwt options, signKey, signMethod(HS512), expiry time(48 hour), fields, claims
{
r.POST("/auth/login", LoginCustom)
signKey := []byte("your-sign-key")
g.Use(middleware.Auth(middleware.WithSignKey(signKey)))
//g.Use(middleware.Auth(middleware.WithSignKey(signKey), middleware.WithExtraVerify(extraVerifyFn))) // add extra verify function
}
g.GET("/user/:id", GetByID)
//g.PUT("/user/:id", Create)
//g.DELETE("/user/:id", DeleteByID)
r.Run(":8080")
}
// judge whether the user is disabled, query whether jwt id exists from the blacklist
//if CheckBlackList(uid, claims.ID) {
// return errors.New("user is disabled")
//}
// get fields from claims
//uid := claims.UID
//name, _ := claims.GetString("name")
//age, _ := claims.GetInt("age")
//isVip, _ := claims.GetBool("is_vip")
return nil
}
```
func customGenerateToken(uid string, fields map[string]interface{}) (string, error) {
_, token, err := jwt.GenerateToken(
uid,
jwt.WithGenerateTokenSignKey([]byte("custom-sign-key")),
jwt.WithGenerateTokenSignMethod(jwt.HS512),
jwt.WithGenerateTokenFields(fields),
jwt.WithGenerateTokenClaims([]jwt.RegisteredClaimsOption{
jwt.WithExpires(time.Hour * 48),
//jwt.WithIssuedAt(now),
// jwt.WithSubject("123"),
// jwt.WithIssuer("https://middleware.example.com"),
// jwt.WithAudience("https://api.example.com"),
// jwt.WithNotBefore(now),
// jwt.WithJwtID("abc1234xxx"),
}...),
)
return token, err
}
func LoginDefault(c *gin.Context) {
// ......
_, token, err := jwt.GenerateToken("100")
response.Success(c, token)
}
func LoginCustom(c *gin.Context) {
// ......
uid := "100"
fields := map[string]interface{}{
"name": "bob",
"age": 10,
"is_vip": true,
}
token, err := customGenerateToken(uid, fields)
response.Success(c, token)
}
func GetByID(c *gin.Context) {
uid := c.Param("id")
// if necessary, claims can be got from gin context.
claims, ok := middleware.GetClaims(c)
//uid := claims.UID
//name, _ := claims.GetString("name")
//age, _ := claims.GetInt("age")
//isVip, _ := claims.GetBool("is_vip")
response.Success(c, gin.H{"id": uid})
}
func extraVerifyFn(claims *jwt.Claims, c *gin.Context) error {
// check if token is about to expire (less than 10 minutes remaining)
if time.Now().Unix()-claims.ExpiresAt.Unix() < int64(time.Minute*10) {
token, err := claims.NewToken(time.Hour*24, jwt.HS512, []byte("your-sign-key")) // same signature as jwt.GenerateToken
if err != nil {
return err
}
c.Header("X-Renewed-Token", token)
}
// judge whether the user is disabled, query whether jwt id exists from the blacklist
//if CheckBlackList(uid, claims.ID) {
// return errors.New("user is disabled")
//}
return nil
}
```
<br>
@ -311,7 +331,7 @@ func NewRouter() *gin.Engine {
<br>
### Request id
### Request id middleware
```go
import (
@ -345,7 +365,7 @@ func NewRouter() *gin.Engine {
<br>
### Timeout
### Timeout middleware
```go
import (

View File

@ -0,0 +1,91 @@
## auth
`auth` middleware for gin framework.
### Example of use
```go
package main
import (
"time"
"github.com/gin-gonic/gin"
"github.com/go-dev-frame/sponge/pkg/gin/middleware/auth"
"github.com/go-dev-frame/sponge/pkg/gin/response"
)
func main() {
r := gin.Default()
// initialize jwt first
auth.InitAuth([]byte("your-sign-key"), time.Hour*24) // default signing method is HS256
// auth.InitAuth([]byte("your-sign-key"), time.Minute*24, WithInitAuthSigningMethod(HS512), WithInitAuthIssuer("foobar.com"))
r.POST("/auth/login", Login)
g := r.Group("/api/v1")
g.Use(auth.Auth())
//g.Use(auth.Auth(auth.WithExtraVerify(extraVerifyFn))) // add extra verify function
g.GET("/user/:id", GetByID)
//g.PUT("/user/:id", Create)
//g.DELETE("/user/:id", DeleteByID)
r.Run(":8080")
}
func Login(c *gin.Context) {
// ......
// Case 1: only uid for token
{
token, err := auth.GenerateToken("100")
}
// Case 2: uid and custom fields for token
{
uid := "100"
fields := map[string]interface{}{
"name": "bob",
"age": 10,
"is_vip": true,
}
token, err := auth.GenerateToken(uid, auth.WithGenerateTokenFields(fields))
}
response.Success(c, token)
}
func GetByID(c *gin.Context) {
uid := c.Param("id")
// if necessary, claims can be got from gin context
claims, ok := auth.GetClaims(c)
//uid := claims.UID
//name, _ := claims.GetString("name")
//age, _ := claims.GetInt("age")
//isVip, _ := claims.GetBool("is_vip")
response.Success(c, gin.H{"id": uid})
}
func extraVerifyFn(claims *auth.Claims, c *gin.Context) error {
// check if token is about to expire (less than 10 minutes remaining)
if time.Now().Unix()-claims.ExpiresAt.Unix() < int64(time.Minute*10) {
token, err := auth.RefreshToken(claims)
if err != nil {
return err
}
c.Header("X-Renewed-Token", token)
}
// judge whether the user is disabled, query whether jwt id exists from the blacklist
//if CheckBlackList(uid, claims.ID) {
// return errors.New("user is disabled")
//}
return nil
}
```
<br>

View File

@ -0,0 +1,229 @@
// Package auth provides JWT authentication middleware for gin.
package auth
import (
"errors"
"time"
"github.com/gin-gonic/gin"
"github.com/go-dev-frame/sponge/pkg/errcode"
"github.com/go-dev-frame/sponge/pkg/gin/response"
"github.com/go-dev-frame/sponge/pkg/jwt"
)
type SigningMethodHMAC = jwt.SigningMethodHMAC
type Claims = jwt.Claims
var (
HS256 = jwt.HS256
HS384 = jwt.HS384
HS512 = jwt.HS512
)
var (
customSigningKey []byte
customSigningMethod *jwt.SigningMethodHMAC
customExpire time.Duration
customIssuer string
errOption = errors.New("jwt option is nil, please initialize first, call middleware.InitAuth()")
)
type initAuthOptions struct {
issuer string
signingMethod *SigningMethodHMAC
}
func defaultInirAuthOptions() *initAuthOptions {
return &initAuthOptions{
signingMethod: HS256,
}
}
// InitAuthOption set the jwt initAuthOptions.
type InitAuthOption func(*initAuthOptions)
func (o *initAuthOptions) apply(opts ...InitAuthOption) {
for _, opt := range opts {
opt(o)
}
}
// WithInitAuthSigningMethod set signing method value
func WithInitAuthSigningMethod(sm *jwt.SigningMethodHMAC) InitAuthOption {
return func(o *initAuthOptions) {
o.signingMethod = sm
}
}
// WithInitAuthIssuer set issuer value
func WithInitAuthIssuer(issuer string) InitAuthOption {
return func(o *initAuthOptions) {
o.issuer = issuer
}
}
// InitAuth initializes jwt options.
func InitAuth(signingKey []byte, expire time.Duration, opts ...InitAuthOption) {
o := defaultInirAuthOptions()
o.apply(opts...)
customSigningKey = signingKey
customExpire = expire
customSigningMethod = o.signingMethod
customIssuer = o.issuer
}
// GenerateTokenOption set the jwt options.
type GenerateTokenOption func(*generateTokenOptions)
type generateTokenOptions struct {
fields map[string]interface{}
}
func (o *generateTokenOptions) apply(opts ...GenerateTokenOption) {
for _, opt := range opts {
opt(o)
}
}
// WithGenerateTokenFields set custom fields value
func WithGenerateTokenFields(fields map[string]interface{}) GenerateTokenOption {
return func(o *generateTokenOptions) {
o.fields = fields
}
}
// GenerateToken generates a jwt token with the given uid and options.
func GenerateToken(uid string, opts ...GenerateTokenOption) (string, error) {
if customSigningMethod == nil || len(customSigningKey) == 0 {
panic(errOption)
}
genOpts := []jwt.GenerateTokenOption{
jwt.WithGenerateTokenSignKey(customSigningKey),
jwt.WithGenerateTokenSignMethod(customSigningMethod),
}
o := &generateTokenOptions{}
o.apply(opts...)
if len(o.fields) > 0 {
genOpts = append(genOpts, jwt.WithGenerateTokenFields(o.fields))
}
claimsOpts := []jwt.RegisteredClaimsOption{
jwt.WithExpires(customExpire),
}
if customIssuer != "" {
claimsOpts = append(claimsOpts, jwt.WithIssuer(customIssuer))
}
genOpts = append(genOpts, jwt.WithGenerateTokenClaims(claimsOpts...))
_, token, err := jwt.GenerateToken(uid, genOpts...)
return token, err
}
// ParseToken parses the given token and returns the claims.
func ParseToken(token string) (*jwt.Claims, error) {
if customSigningMethod == nil {
panic(errOption)
}
return jwt.ValidateToken(token, jwt.WithValidateTokenSignKey(customSigningKey))
}
// RefreshToken create a new token with the given claims.
func RefreshToken(claims *jwt.Claims) (string, error) {
return claims.NewToken(customExpire, customSigningMethod, customSigningKey)
}
// -------------------------------------------------------------------------------------------
// HeaderAuthorizationKey http header authorization key, value is "Bearer token"
const HeaderAuthorizationKey = "Authorization"
// ExtraVerifyFn extra verify function
type ExtraVerifyFn = func(claims *jwt.Claims, c *gin.Context) error
// AuthOption set the auth options.
type AuthOption func(*authOptions)
type authOptions struct {
isReturnErrReason bool
extraVerifyFn ExtraVerifyFn
}
func defaultAuthOptions() *authOptions {
return &authOptions{}
}
func (o *authOptions) apply(opts ...AuthOption) {
for _, opt := range opts {
opt(o)
}
}
// WithReturnErrReason set return error reason
func WithReturnErrReason() AuthOption {
return func(o *authOptions) {
o.isReturnErrReason = true
}
}
// WithExtraVerify set extra verify function
func WithExtraVerify(fn ExtraVerifyFn) AuthOption {
return func(o *authOptions) {
o.extraVerifyFn = fn
}
}
func responseUnauthorized(isReturnErrReason bool, errMsg string) *errcode.Error {
if isReturnErrReason {
return errcode.Unauthorized.RewriteMsg("Unauthorized, " + errMsg)
}
return errcode.Unauthorized
}
// Auth authorization middleware, support custom extra verify.
func Auth(opts ...AuthOption) gin.HandlerFunc {
o := defaultAuthOptions()
o.apply(opts...)
return func(c *gin.Context) {
authorization := c.GetHeader(HeaderAuthorizationKey)
if len(authorization) < 100 {
response.Out(c, responseUnauthorized(o.isReturnErrReason, "token is illegal"))
c.Abort()
return
}
tokenString := authorization[7:] // remove Bearer prefix
claims, err := ParseToken(tokenString)
if err != nil {
response.Out(c, responseUnauthorized(o.isReturnErrReason, err.Error()))
c.Abort()
return
}
// extra verify function
if o.extraVerifyFn != nil {
if err = o.extraVerifyFn(claims, c); err != nil {
response.Out(c, responseUnauthorized(o.isReturnErrReason, err.Error()))
c.Abort()
return
}
}
c.Set("claims", claims) // set claims to context
c.Next()
}
}
// GetClaims get jwt claims from gin context.
func GetClaims(c *gin.Context) (*jwt.Claims, bool) {
claims, exists := c.Get("claims")
if !exists {
return nil, false
}
jwtClaims, ok := claims.(*jwt.Claims)
return jwtClaims, ok
}

View File

@ -0,0 +1,207 @@
package auth
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/go-dev-frame/sponge/pkg/gin/response"
"github.com/go-dev-frame/sponge/pkg/httpcli"
"github.com/go-dev-frame/sponge/pkg/jwt"
"github.com/go-dev-frame/sponge/pkg/utils"
)
var (
uid = "100"
fields = map[string]interface{}{
"name": "bob",
"age": 10,
"is_vip": true,
}
jwtSignKey = []byte("your-secret-key")
errMsg = http.StatusText(http.StatusUnauthorized)
compareMsgFn = func(em string) bool {
return strings.Contains(em, errMsg)
}
)
func extraVerifyFn(claims *jwt.Claims, c *gin.Context) error {
// check if token is about to expire (less than 10 minutes remaining)
if time.Now().Unix()-claims.ExpiresAt.Unix() < int64(time.Minute*10) {
token, err := RefreshToken(claims)
if err != nil {
return err
}
c.Header("X-Renewed-Token", token)
}
// judge whether the user is disabled, query whether jwt id exists from the blacklist
//if CheckBlackList(uid, claims.ID) {
// return errors.New("user is disabled")
//}
// check fields
if claims.UID != uid {
return fmt.Errorf("uid not match, expect %s, got %s", uid, claims.UID)
}
if name, _ := claims.GetString("name"); name != fields["name"] {
return fmt.Errorf("name not match, expect %s, got %s", fields["name"], name)
}
if age, _ := claims.GetInt("age"); age != fields["age"] {
return fmt.Errorf("age not match, expect %d, got %d", fields["age"], age)
}
if isVip, _ := claims.GetBool("is_vip"); isVip != fields["is_vip"] {
return fmt.Errorf("is_vip not match, expect %v, got %v", fields["is_vip"], isVip)
}
return nil
}
func runAuthHTTPServer() string {
serverAddr, requestAddr := utils.GetLocalHTTPAddrPairs()
gin.SetMode(gin.ReleaseMode)
r := gin.Default()
loginHandler := func(c *gin.Context) {
token, _ := GenerateToken(uid)
fmt.Println("token1 =", token)
response.Success(c, token)
}
loginCustomFieldsHandler := func(c *gin.Context) {
token, _ := GenerateToken(uid, WithGenerateTokenFields(fields))
fmt.Println("token2 =", token)
response.Success(c, token)
}
getUserByIDHandler := func(c *gin.Context) {
id := c.Param("id")
claims, ok := GetClaims(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"msg": "unauthorized"})
return
}
fmt.Println("claims =", claims)
response.Success(c, id)
}
r.GET("/auth/login", loginHandler)
r.GET("/auth/loginCustomFields", loginCustomFieldsHandler)
r.GET("/user/:id", Auth(), getUserByIDHandler)
r.GET("/user/log/:id", Auth(WithReturnErrReason()), getUserByIDHandler)
r.GET("/user/extra_verify/:id", Auth(WithExtraVerify(extraVerifyFn), WithReturnErrReason()), getUserByIDHandler)
go func() {
err := r.Run(serverAddr)
if err != nil {
panic(err)
}
}()
time.Sleep(time.Millisecond * 200)
return requestAddr
}
func getUser(url string, authorization string) (gin.H, error) {
var result = gin.H{}
client := &http.Client{}
request, err := http.NewRequest("GET", url, nil)
request.Header.Add("Authorization", authorization)
if err != nil {
return result, err
}
resp, _ := client.Do(request)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return result, err
}
err = json.Unmarshal(data, &result)
return result, err
}
func TestAuth(t *testing.T) {
requestAddr := runAuthHTTPServer()
InitAuth(jwtSignKey, time.Minute*10)
//InitAuth(jwtSignKey, time.Minute*10, WithInitAuthSigningMethod(HS512), WithInitAuthIssuer("foobar.com"))
t.Run("only uid for generate token", func(t *testing.T) {
// get token
result := &httpcli.StdResult{}
err := httpcli.Get(result, requestAddr+"/auth/login")
if err != nil {
t.Fatal(err)
}
token := result.Data.(string)
authorization := fmt.Sprintf("Bearer %s", token)
// success
val, err := getUser(requestAddr+"/user/"+uid, authorization)
assert.Equal(t, val["data"], uid)
// success
val, err = getUser(requestAddr+"/user/log/"+uid, authorization)
assert.Equal(t, val["data"], uid)
// return 401, the reason is token have no extra field
val, err = getUser(requestAddr+"/user/extra_verify/"+uid, authorization)
assert.Equal(t, true, compareMsgFn(val["msg"].(string)))
// return 401, the reason is token value is invalid
val, err = getUser(requestAddr+"/user/"+uid, "error-authorization")
assert.Equal(t, val["msg"], errMsg)
})
t.Run("uid and fields for generate token", func(t *testing.T) {
// get token
result := &httpcli.StdResult{}
err := httpcli.Get(result, requestAddr+"/auth/loginCustomFields")
if err != nil {
t.Fatal(err)
}
token := result.Data.(string)
authorization := fmt.Sprintf("Bearer %s", token)
// success
val, err := getUser(requestAddr+"/user/"+uid, authorization)
assert.Equal(t, val["data"], uid)
// success
val, err = getUser(requestAddr+"/user/log/"+uid, authorization)
assert.Equal(t, val["data"], uid)
// return 401, the reason is token expired
token = "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.eyJ1aWQiOiIxMDAiLCJmaWVsZHMiOnsiYWdlIjoxMCwiaXNfdmlwIjp0cnVlLCJuYW1lIjoiYm9iIn0sImV4cCI6MTc0NjY0MTY0MCwiaWF0IjoxNzQ2NjQxMDQwLCJqdGkiOiIxODNkNTBjNWIxZTdmMTEwIn0.P11q5VPo-88Sbw4JKLtp2_Aiz8Pc1oL-jrdEAX0NwJJoxnR_Iu8W6eI7CsUCzVGW"
authorization = fmt.Sprintf("Bearer %s", token)
val, err = getUser(requestAddr+"/user/extra_verify/"+uid, authorization)
assert.Equal(t, true, compareMsgFn(val["msg"].(string)))
// return 401, the reason is token value is invalid
val, err = getUser(requestAddr+"/user/"+uid, "error-authorization")
assert.Equal(t, val["msg"], errMsg)
})
}
func TestError(t *testing.T) {
t.Run("GenerateToken error", func(t *testing.T) {
defer func() { recover() }()
GenerateToken("100")
})
t.Run("ParseToken error", func(t *testing.T) {
defer func() { recover() }()
ParseToken("xxx")
})
}

View File

@ -8,16 +8,16 @@ import (
valid "github.com/go-playground/validator/v10"
)
// Init request body file valid
// Init validator instance, used to gin request parameter check
func Init() *CustomValidator {
validator := NewCustomValidator()
validator.Engine()
return validator
v := NewCustomValidator()
v.Engine()
return v
}
// CustomValidator Custom valid objects
type CustomValidator struct {
Once sync.Once
once sync.Once
Validate *valid.Validate
}
@ -26,37 +26,52 @@ func NewCustomValidator() *CustomValidator {
return &CustomValidator{}
}
// ValidateStruct Instantiate struct valid
// ValidateStruct validates a struct or slice/array
func (v *CustomValidator) ValidateStruct(obj interface{}) error {
if kindOfData(obj) == reflect.Struct {
v.lazyinit()
if obj == nil {
return nil
}
val := reflect.ValueOf(obj)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
switch val.Kind() {
case reflect.Struct:
if err := v.Validate.Struct(obj); err != nil {
return err
}
case reflect.Ptr:
// pointer type: if nil, no validation required; otherwise recursive validation after dereference
if val.IsNil() {
return nil
}
return v.ValidateStruct(val.Elem().Interface())
case reflect.Slice, reflect.Array:
// slice or array type: iterates over each element, recursively validating one by one
for i := 0; i < val.Len(); i++ {
elem := val.Index(i)
if err := v.ValidateStruct(elem.Interface()); err != nil {
return err
}
}
}
return nil
}
// Engine Instantiate valid
// Engine set tag name "binding", which is implementing the validator interface of the gin framework
func (v *CustomValidator) Engine() interface{} {
v.lazyinit()
v.lazyInit()
return v.Validate
}
func (v *CustomValidator) lazyinit() {
v.Once.Do(func() {
func (v *CustomValidator) lazyInit() {
v.once.Do(func() {
v.Validate = valid.New()
v.Validate.SetTagName("binding")
})
}
func kindOfData(data interface{}) reflect.Kind {
value := reflect.ValueOf(data)
valueType := value.Kind()
if valueType == reflect.Ptr {
valueType = value.Elem().Kind()
}
return valueType
}

View File

@ -7,7 +7,6 @@ import (
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"strings"
"testing"
@ -420,39 +419,122 @@ func do(method string, url string, body interface{}) ([]byte, error) {
// ------------------------------------------------------------------------------------------
type st struct {
Name string
}
func Test_CustomValidator_ValidateStruct(t *testing.T) {
type User struct {
Name string `binding:"required"`
Age int `binding:"gte=18"`
}
func TestCustomValidator_Engine(t *testing.T) {
validator := NewCustomValidator()
v := validator.Engine()
assert.NotNil(t, v)
}
type UserList1 struct {
Users []User `binding:"required,dive"`
}
func TestCustomValidator_ValidateStruct(t *testing.T) {
validator := NewCustomValidator()
err := validator.ValidateStruct(new(st))
assert.NoError(t, err)
}
type UserList2 struct {
Users []*User `binding:"required,dive"`
}
func TestCustomValidator_lazyinit(t *testing.T) {
validator := NewCustomValidator()
validator.lazyinit()
}
func TestInit(t *testing.T) {
validator := Init()
assert.NotNil(t, validator)
user := &User{Name: "John", Age: 10}
if err := validator.ValidateStruct(user); err != nil {
assert.NotNil(t, err)
t.Log(err)
}
var u = &User{Name: "John", Age: 11}
if err := validator.ValidateStruct(&u); err != nil {
assert.NotNil(t, err)
t.Log(err)
}
users := []User{{Name: "Alice", Age: 25}, {Name: "Bob", Age: 17}}
if err := validator.ValidateStruct(users); err != nil {
assert.NotNil(t, err)
t.Log(err)
}
userList := UserList1{}
if err := validator.ValidateStruct(&userList); err != nil {
assert.NotNil(t, err)
t.Log(err)
}
userList1 := UserList1{
Users: []User{{Name: "Charlie", Age: 10}, {Name: "", Age: 30}},
}
if err := validator.ValidateStruct(&userList1); err != nil {
assert.NotNil(t, err)
t.Log(err)
}
userList2 := UserList2{
Users: []*User{{Name: "Charlie", Age: 30}, {Name: "", Age: 40}},
}
if err := validator.ValidateStruct(&userList2); err != nil {
assert.NotNil(t, err)
t.Log(err)
}
}
func TestNewCustomValidator(t *testing.T) {
validator := NewCustomValidator()
assert.NotNil(t, validator)
}
func Benchmark_CustomValidator_ValidateStruct(b *testing.B) {
type User struct {
Name string `binding:"required"`
Age int `binding:"gte=18"`
}
func Test_kindOfData(t *testing.T) {
type UserList1 struct {
Users []User `binding:"required,dive"` // 验证指针切片
}
kind := kindOfData(new(st))
assert.Equal(t, reflect.Struct, kind)
type UserList2 struct {
Users []*User `binding:"required,dive"` // 验证指针切片
}
validator := Init()
b.Run("User struct", func(b *testing.B) {
user := User{Name: "John", Age: 10}
for i := 0; i < b.N; i++ {
_ = validator.ValidateStruct(user)
}
})
b.Run("User struct pointer", func(b *testing.B) {
user := &User{Name: "John", Age: 10}
for i := 0; i < b.N; i++ {
_ = validator.ValidateStruct(user)
}
})
b.Run("User struct pointer pointer", func(b *testing.B) {
var u = &User{Name: "John", Age: 11}
for i := 0; i < b.N; i++ {
_ = validator.ValidateStruct(&u)
}
})
b.Run("User slice", func(b *testing.B) {
users := []User{{Name: "Alice", Age: 25}, {Name: "Bob", Age: 17}}
for i := 0; i < b.N; i++ {
_ = validator.ValidateStruct(users)
}
})
b.Run("UserList slice struct", func(b *testing.B) {
userList1 := UserList1{
Users: []User{{Name: "Charlie", Age: 10}, {Name: "", Age: 30}},
}
for i := 0; i < b.N; i++ {
_ = validator.ValidateStruct(&userList1)
}
})
b.Run("UserList slice struct pointer", func(b *testing.B) {
userList2 := UserList2{
Users: []*User{{Name: "Charlie", Age: 30}, {Name: "", Age: 40}},
}
for i := 0; i < b.N; i++ {
_ = validator.ValidateStruct(&userList2)
}
})
}

View File

@ -2,16 +2,16 @@
Common interceptors for gRPC server and client side, including:
- Logging
- Recovery
- Retry
- Rate limiter
- Circuit breaker
- Timeout
- Tracing
- Request id
- Metrics
- JWT authentication
- [Logging](README.md#logging-interceptor)
- [Recovery](README.md#recovery-interceptor)
- [Retry](README.md#retry-interceptor)
- [Rate limiter](README.md#rate-limiter-interceptor)
- [Circuit breaker](README.md#circuit-breaker-interceptor)
- [Timeout](README.md#timeout-interceptor)
- [Tracing](README.md#tracing-interceptor)
- [Request id](README.md#request-id-interceptor)
- [Metrics](README.md#metrics-interceptor)
- [JWT authentication](README.md#jwt-authentication-interceptor)
<br>
@ -492,6 +492,15 @@ func extraVerifyFn(ctx context.Context, claims *jwt.Claims) error {
return nil
}
// GetByID ...
func (s *user) GetByID(ctx context.Context, req *userV1.GetByIDRequest) (*userV1.GetByIDReply, error) {
// ......
claims,ok := interceptor.GetJwtClaims(ctx) // if necessary, claims can be got from gin context.
// ......
}
```
**gRPC client side**

View File

@ -9,6 +9,8 @@ import (
"github.com/go-dev-frame/sponge/pkg/krand"
)
type SigningMethodHMAC = jwt.SigningMethodHMAC
var (
HS256 = jwt.SigningMethodHS256
HS384 = jwt.SigningMethodHS384
@ -24,8 +26,8 @@ var (
var (
ErrTokenExpired = jwt.ErrTokenExpired
//errInvalid = errors.New("token is invalid")
errClaims = errors.New("claims is not match")
errNotMatch = errors.New(" access token and refresh token is not match")
errClaims = errors.New("claims is not match")
errNotMatch = errors.New(" access token and refresh token is not match")
)
// ------------------------------------------------------------------------------------------