feat: for custom query, add whitelist judgment

This commit is contained in:
zhuyasen 2025-05-18 23:23:08 +08:00
parent 7c90a8da10
commit fed822c5b7
12 changed files with 306 additions and 215 deletions

View File

@ -175,41 +175,10 @@ func (d *userExampleDao) GetByID(ctx context.Context, id uint64) (*model.UserExa
return nil, err
}
// GetByColumns get paging records by column information,
// Note: query performance degrades when table rows are very large because of the use of offset.
//
// params includes paging parameters and query parameters
// paging parameters (required):
//
// page: page number, starting from 0
// limit: lines per page
// sort: sort fields, default is id backwards, you can add - sign before the field to indicate reverse order, no - sign to indicate ascending order, multiple fields separated by comma
//
// query parameters (not required):
//
// name: column name
// exp: expressions, which default is "=", support =, !=, >, >=, <, <=, like, in, notin, isnull, isnotnull
// value: column value, if exp=in, multiple values are separated by commas
// logic: logical type, default value is "and", support &, and, ||, or
//
// example: search for a male over 20 years of age
//
// params = &query.Params{
// Page: 0,
// Limit: 20,
// Columns: []query.Column{
// {
// Name: "age",
// Exp: ">",
// Value: 20,
// },
// {
// Name: "gender",
// Value: "male",
// },
// }
// GetByColumns get paging records by column information.
// For more details, please refer to https://go-sponge.com/component/custom-page-query.html
func (d *userExampleDao) GetByColumns(ctx context.Context, params *query.Params) ([]*model.UserExample, int64, error) {
queryStr, args, err := params.ConvertToGormConditions()
queryStr, args, err := params.ConvertToGormConditions(query.WithWhitelistNames(model.UserExampleColumnNames))
if err != nil {
return nil, 0, errors.New("query params error: " + err.Error())
}

View File

@ -181,41 +181,10 @@ func (d *userExampleDao) GetByID(ctx context.Context, id uint64) (*model.UserExa
return nil, err
}
// GetByColumns get paging records by column information,
// Note: query performance degrades when table rows are very large because of the use of offset.
//
// params includes paging parameters and query parameters
// paging parameters (required):
//
// page: page number, starting from 0
// limit: lines per page
// sort: sort fields, default is id backwards, you can add - sign before the field to indicate reverse order, no - sign to indicate ascending order, multiple fields separated by comma
//
// query parameters (not required):
//
// name: column name
// exp: expressions, which default is "=", support =, !=, >, >=, <, <=, like, in, notin, isnull, isnotnull
// value: column value, if exp=in, multiple values are separated by commas
// logic: logical type, default value is "and", support &, and, ||, or
//
// example: search for a male over 20 years of age
//
// params = &query.Params{
// Page: 0,
// Limit: 20,
// Columns: []query.Column{
// {
// Name: "age",
// Exp: ">",
// Value: 20,
// },
// {
// Name: "gender",
// Value: "male",
// },
// }
// GetByColumns get paging records by column information.
// For more details, please refer to https://go-sponge.com/component/custom-page-query.html
func (d *userExampleDao) GetByColumns(ctx context.Context, params *query.Params) ([]*model.UserExample, int64, error) {
queryStr, args, err := params.ConvertToGormConditions()
queryStr, args, err := params.ConvertToGormConditions(query.WithWhitelistNames(model.UserExampleColumnNames))
if err != nil {
return nil, 0, errors.New("query params error: " + err.Error())
}
@ -256,29 +225,10 @@ func (d *userExampleDao) DeleteByIDs(ctx context.Context, ids []uint64) error {
return nil
}
// GetByCondition get a record by condition
// query conditions:
//
// name: column name
// exp: expressions, which default is "=", support =, !=, >, >=, <, <=, like, in, notin, isnull, isnotnull
// value: column value, if exp=in, multiple values are separated by commas
// logic: logical type, default value is "and", support &, and, ||, or
//
// example: find a male aged 20
//
// condition = &query.Conditions{
// Columns: []query.Column{
// {
// Name: "age",
// Value: 20,
// },
// {
// Name: "gender",
// Value: "male",
// },
// }
// GetByCondition get a record by condition.
// For more details, please refer to https://go-sponge.com/component/custom-page-query.html#_2-condition-parameters-optional
func (d *userExampleDao) GetByCondition(ctx context.Context, c *query.Conditions) (*model.UserExample, error) {
queryStr, args, err := c.ConvertToGorm()
queryStr, args, err := c.ConvertToGorm(query.WithWhitelistNames(model.UserExampleColumnNames))
if err != nil {
return nil, err
}

View File

@ -187,44 +187,13 @@ func (d *{{.TableNameCamelFCL}}Dao) GetBy{{.ColumnNameCamel}}(ctx context.Contex
return nil, err
}
// GetByColumns get paging records by column information,
// Note: query performance degrades when table rows are very large because of the use of offset.
//
// params includes paging parameters and query parameters
// paging parameters (required):
//
// page: page number, starting from 0
// limit: lines per page
// sort: sort fields, default is {{.ColumnNameCamelFCL}} backwards, you can add - sign before the field to indicate reverse order, no - sign to indicate ascending order, multiple fields separated by comma
//
// query parameters (not required):
//
// name: column name
// exp: expressions, which default is "=", support =, !=, >, >=, <, <=, like, in, notin, isnull, isnotnull
// value: column value, if exp=in, multiple values are separated by commas
// logic: logical type, default value is "and", support &, and, ||, or
//
// example: search for a male over 20 years of age
//
// params = &query.Params{
// Page: 0,
// Limit: 20,
// Columns: []query.Column{
// {
// Name: "age",
// Exp: ">",
// Value: 20,
// },
// {
// Name: "gender",
// Value: "male",
// },
// }
// GetByColumns get paging records by column information.
// For more details, please refer to https://go-sponge.com/component/custom-page-query.html
func (d *{{.TableNameCamelFCL}}Dao) GetByColumns(ctx context.Context, params *query.Params) ([]*model.{{.TableNameCamel}}, int64, error) {
if params.Sort == "" {
params.Sort = "-{{.ColumnName}}"
}
queryStr, args, err := params.ConvertToGormConditions()
queryStr, args, err := params.ConvertToGormConditions(query.WithWhitelistNames(model.{{.TableNameCamel}}ColumnNames))
if err != nil {
return nil, 0, errors.New("query params error: " + err.Error())
}
@ -266,28 +235,9 @@ func (d *{{.TableNameCamelFCL}}Dao) DeleteBy{{.ColumnNamePluralCamel}}(ctx conte
}
// GetByCondition get a record by condition
// query conditions:
//
// name: column name
// exp: expressions, which default is "=", support =, !=, >, >=, <, <=, like, in, notin, isnull, isnotnull
// value: column value, if exp=in, multiple values are separated by commas
// logic: logical type, default value is "and", support &, and, ||, or
//
// example: find a male aged 20
//
// condition = &query.Conditions{
// Columns: []query.Column{
// {
// Name: "age",
// Value: 20,
// },
// {
// Name: "gender",
// Value: "male",
// },
// }
// For more details, please refer to https://go-sponge.com/component/custom-page-query.html#_2-condition-parameters-optional
func (d *{{.TableNameCamelFCL}}Dao) GetByCondition(ctx context.Context, c *query.Conditions) (*model.{{.TableNameCamel}}, error) {
queryStr, args, err := c.ConvertToGorm()
queryStr, args, err := c.ConvertToGorm(query.WithWhitelistNames(model.{{.TableNameCamel}}ColumnNames))
if err != nil {
return nil, err
}

View File

@ -225,9 +225,13 @@ func (d *userExampleDao) GetByID(ctx context.Context, id string) (*model.UserExa
// Name: "gender",
// Value: "male",
// },
// {
// Name: "post_id:oid", // suffix :oid is required for objectId type
// Value: "65ce48483f11aff697e30d6d",
// },
// }
func (d *userExampleDao) GetByColumns(ctx context.Context, params *query.Params) ([]*model.UserExample, int64, error) {
filter, err := params.ConvertToMongoFilter()
filter, err := params.ConvertToMongoFilter(query.WithWhitelistNames(model.UserExampleColumnNames))
if err != nil {
return nil, 0, errors.New("query params error: " + err.Error())
}

View File

@ -230,9 +230,13 @@ func (d *userExampleDao) GetByID(ctx context.Context, id string) (*model.UserExa
// Name: "gender",
// Value: "male",
// },
// {
// Name: "post_id:oid", // suffix :oid is required for objectId type
// Value: "65ce48483f11aff697e30d6d",
// },
// }
func (d *userExampleDao) GetByColumns(ctx context.Context, params *query.Params) ([]*model.UserExample, int64, error) {
filter, err := params.ConvertToMongoFilter()
filter, err := params.ConvertToMongoFilter(query.WithWhitelistNames(model.UserExampleColumnNames))
if err != nil {
return nil, 0, errors.New("query params error: " + err.Error())
}
@ -297,12 +301,12 @@ func (d *userExampleDao) DeleteByIDs(ctx context.Context, ids []string) error {
// Value: "James",
// },
// {
// Name: "post_id:oid",
// Name: "post_id:oid", // suffix :oid is required for objectId type
// Value: "65ce48483f11aff697e30d6d",
// },
// }
func (d *userExampleDao) GetByCondition(ctx context.Context, c *query.Conditions) (*model.UserExample, error) {
filter, err := c.ConvertToMongo()
filter, err := c.ConvertToMongo(query.WithWhitelistNames(model.UserExampleColumnNames))
if err != nil {
return nil, err
}

View File

@ -181,44 +181,13 @@ func (d *{{.TableNameCamelFCL}}Dao) GetBy{{.ColumnNameCamel}}(ctx context.Contex
return nil, err
}
// GetByColumns get paging records by column information,
// Note: query performance degrades when table rows are very large because of the use of offset.
//
// params includes paging parameters and query parameters
// paging parameters (required):
//
// page: page number, starting from 0
// limit: lines per page
// sort: sort fields, default is {{.ColumnNameCamelFCL}} backwards, you can add - sign before the field to indicate reverse order, no - sign to indicate ascending order, multiple fields separated by comma
//
// query parameters (not required):
//
// name: column name
// exp: expressions, which default is "=", support =, !=, >, >=, <, <=, like, in, notin, isnull, isnotnull
// value: column value, if exp=in, multiple values are separated by commas
// logic: logical type, default value is "and", support &, and, ||, or
//
// example: search for a male over 20 years of age
//
// params = &query.Params{
// Page: 0,
// Limit: 20,
// Columns: []query.Column{
// {
// Name: "age",
// Exp: ">",
// Value: 20,
// },
// {
// Name: "gender",
// Value: "male",
// },
// }
// GetByColumns get paging records by column information.
// For more details, please refer to https://go-sponge.com/component/custom-page-query.html
func (d *{{.TableNameCamelFCL}}Dao) GetByColumns(ctx context.Context, params *query.Params) ([]*model.{{.TableNameCamel}}, int64, error) {
if params.Sort == "" {
params.Sort = "-{{.ColumnName}}"
}
queryStr, args, err := params.ConvertToGormConditions()
queryStr, args, err := params.ConvertToGormConditions(query.WithWhitelistNames(model.{{.TableNameCamel}}ColumnNames))
if err != nil {
return nil, 0, errors.New("query params error: " + err.Error())
}

View File

@ -27,4 +27,21 @@ func (table *UserExample) TableName() string {
return "user_example"
}
// UserExampleColumnNames Whitelist for custom query fields to prevent sql injection attacks
var UserExampleColumnNames = map[string]bool{
"id": true,
"created_at": true,
"updated_at": true,
"deleted_at": true,
"name": true,
"password": true,
"email": true,
"phone": true,
"avatar": true,
"age": true,
"gender": true,
"status": true,
"login_at": true,
}
// delete the templates code end

View File

@ -35,7 +35,7 @@ type initAuthOptions struct {
signingMethod *SigningMethodHMAC
}
func defaultInirAuthOptions() *initAuthOptions {
func defaultInitAuthOptions() *initAuthOptions {
return &initAuthOptions{
signingMethod: HS256,
}
@ -66,7 +66,7 @@ func WithInitAuthIssuer(issuer string) InitAuthOption {
// InitAuth initializes jwt options.
func InitAuth(signingKey []byte, expire time.Duration, opts ...InitAuthOption) {
o := defaultInirAuthOptions()
o := defaultInitAuthOptions()
o.apply(opts...)
customSigningKey = signingKey

View File

@ -78,6 +78,38 @@ var logicMap = map[string]string{
orSymbol2: orSymbol1,
}
// ---------------------------------------------------------------------------
type rulerOptions struct {
whitelistNames map[string]bool
validateFn func(columns []Column) error
}
// RulerOption set the parameters of ruler options
type RulerOption func(*rulerOptions)
func (o *rulerOptions) apply(opts ...RulerOption) {
for _, opt := range opts {
opt(o)
}
}
// WithWhitelistNames set white list names of columns
func WithWhitelistNames(whitelistNames map[string]bool) RulerOption {
return func(o *rulerOptions) {
o.whitelistNames = whitelistNames
}
}
// WithValidateFn set validate function of columns
func WithValidateFn(fn func(columns []Column) error) RulerOption {
return func(o *rulerOptions) {
o.validateFn = fn
}
}
// -----------------------------------------------------------------------------
// Params query parameters
type Params struct {
Page int `json:"page" form:"page" binding:"gte=0"`
@ -98,6 +130,13 @@ type Column struct {
Logic string `json:"logic" form:"logic"` // logical type, defaults to and when the value is null, with &(and), ||(or)
}
func (c *Column) checkName(whitelists map[string]bool) error {
if c.Name == "" || (whitelists != nil && !whitelists[c.Name]) {
return fmt.Errorf("field name '%s' is not allowed", c.Name)
}
return nil
}
func (c *Column) checkValid() error {
if c.Name == "" {
return fmt.Errorf("field 'name' cannot be empty")
@ -187,7 +226,16 @@ func (p *Params) ConvertToPage() (sort bson.D, limit int, skip int) { //nolint
// ConvertToMongoFilter conversion to mongo-compliant parameters based on the Columns parameter
// ignore the logical type of the last column, whether it is a one-column or multi-column query
func (p *Params) ConvertToMongoFilter() (bson.M, error) {
func (p *Params) ConvertToMongoFilter(opts ...RulerOption) (bson.M, error) {
o := rulerOptions{}
o.apply(opts...)
if o.validateFn != nil {
err := o.validateFn(p.Columns)
if err != nil {
return nil, err
}
}
filter := bson.M{}
l := len(p.Columns)
switch l {
@ -195,7 +243,11 @@ func (p *Params) ConvertToMongoFilter() (bson.M, error) {
return bson.M{}, nil
case 1: // l == 1
err := p.Columns[0].convert()
err := p.Columns[0].checkName(o.whitelistNames)
if err != nil {
return nil, err
}
err = p.Columns[0].convert()
if err != nil {
return nil, err
}
@ -203,7 +255,15 @@ func (p *Params) ConvertToMongoFilter() (bson.M, error) {
return filter, nil
case 2: // l == 2
err := p.Columns[0].convert()
err := p.Columns[0].checkName(o.whitelistNames)
if err != nil {
return nil, err
}
err = p.Columns[1].checkName(o.whitelistNames)
if err != nil {
return nil, err
}
err = p.Columns[0].convert()
if err != nil {
return nil, err
}
@ -223,11 +283,11 @@ func (p *Params) ConvertToMongoFilter() (bson.M, error) {
return filter, nil
default: // l >=3
return p.convertMultiColumns()
return p.convertMultiColumns(o.whitelistNames)
}
}
func (p *Params) convertMultiColumns() (bson.M, error) {
func (p *Params) convertMultiColumns(whitelistNames map[string]bool) (bson.M, error) {
filter := bson.M{}
logicType, groupIndexes, err := checkSameLogic(p.Columns)
if err != nil {
@ -235,7 +295,12 @@ func (p *Params) convertMultiColumns() (bson.M, error) {
}
if logicType == allLogicAnd {
for _, column := range p.Columns {
err := column.convert()
err = column.checkName(whitelistNames)
if err != nil {
return nil, err
}
err = column.convert()
if err != nil {
return nil, err
}
@ -251,7 +316,7 @@ func (p *Params) convertMultiColumns() (bson.M, error) {
return filter, nil
} else if logicType == allLogicOr {
for _, column := range p.Columns {
err := column.convert()
err = column.convert()
if err != nil {
return nil, err
}
@ -376,7 +441,7 @@ func (c *Conditions) CheckValid() error {
// ConvertToMongo conversion to mongo-compliant parameters based on the Columns parameter
// ignore the logical type of the last column, whether it is a one-column or multi-column query
func (c *Conditions) ConvertToMongo() (bson.M, error) {
func (c *Conditions) ConvertToMongo(opts ...RulerOption) (bson.M, error) {
p := &Params{Columns: c.Columns}
return p.ConvertToMongoFilter()
return p.ConvertToMongoFilter(opts...)
}

View File

@ -1,6 +1,7 @@
package query
import (
"errors"
"reflect"
"testing"
@ -538,6 +539,38 @@ func TestParams_ConvertToMongoFilter(t *testing.T) {
}
}
func TestParams_ConvertToMongoFilter_Error(t *testing.T) {
p := &Params{
Limit: 10,
Columns: []Column{
{
Name: "age",
Value: 10,
},
{
Name: "email",
Value: "foo@bar.com",
},
}}
whitelists := map[string]bool{"name": true, "age": true}
_, err := p.ConvertToMongoFilter(WithWhitelistNames(whitelists))
t.Log(err)
assert.Error(t, err)
fn := func(columns []Column) error {
for _, col := range columns {
if col.Value == "foo@bar.com" {
return errors.New("'foo@bar.com' is not allowed")
}
}
return nil
}
_, err = p.ConvertToMongoFilter(WithValidateFn(fn))
t.Log(err)
assert.Error(t, err)
}
func TestConditions_ConvertToMongo(t *testing.T) {
c := Conditions{
Columns: []Column{
@ -667,3 +700,33 @@ func Test_getSort(t *testing.T) {
t.Log(d)
}
}
func TestConditions_ConvertToMongo_Error(t *testing.T) {
c := Conditions{Columns: []Column{
{
Name: "age",
Value: 10,
},
{
Name: "email",
Value: "foo@bar.com",
},
}}
whitelists := map[string]bool{"name": true, "age": true}
_, err := c.ConvertToMongo(WithWhitelistNames(whitelists))
t.Log(err)
assert.Error(t, err)
fn := func(columns []Column) error {
for _, col := range columns {
if col.Value == "foo@bar.com" {
return errors.New("'foo@bar.com' is not allowed")
}
}
return nil
}
_, err = c.ConvertToMongo(WithValidateFn(fn))
t.Log(err)
assert.Error(t, err)
}

View File

@ -77,6 +77,38 @@ var logicMap = map[string]string{
"or:)": " OR ",
}
// ---------------------------------------------------------------------------
type rulerOptions struct {
whitelistNames map[string]bool
validateFn func(columns []Column) error
}
// RulerOption set the parameters of ruler options
type RulerOption func(*rulerOptions)
func (o *rulerOptions) apply(opts ...RulerOption) {
for _, opt := range opts {
opt(o)
}
}
// WithWhitelistNames set white list names of columns
func WithWhitelistNames(whitelistNames map[string]bool) RulerOption {
return func(o *rulerOptions) {
o.whitelistNames = whitelistNames
}
}
// WithValidateFn set validate function of columns
func WithValidateFn(fn func(columns []Column) error) RulerOption {
return func(o *rulerOptions) {
o.validateFn = fn
}
}
// -----------------------------------------------------------------------------
// Params query parameters
type Params struct {
Page int `json:"page" form:"page" binding:"gte=0"`
@ -97,22 +129,8 @@ type Column struct {
Logic string `json:"logic" form:"logic"` // logical type, defaults to and when the value is null, with &(and), ||(or)
}
func (c *Column) checkValid() error {
if c.Name == "" {
return fmt.Errorf("field 'name' cannot be empty")
}
if c.Value == nil {
v := expMap[strings.ToLower(c.Exp)]
if v == " IS NULL " || v == " IS NOT NULL " {
return nil
}
return fmt.Errorf("field 'value' cannot be nil")
}
return nil
}
// converting ExpType to sql expressions and LogicType to sql using characters
func (c *Column) convert() (string, error) {
func (c *Column) checkExp() (string, error) {
symbol := "?"
if c.Exp == "" {
c.Exp = Eq
@ -185,7 +203,7 @@ func (p *Params) ConvertToPage() (order string, limit int, offset int) { //nolin
// ConvertToGormConditions conversion to gorm-compliant parameters based on the Columns parameter
// ignore the logical type of the last column, whether it is a one-column or multi-column query
func (p *Params) ConvertToGormConditions() (string, []interface{}, error) {
func (p *Params) ConvertToGormConditions(opts ...RulerOption) (string, []interface{}, error) { //nolint
str := ""
args := []interface{}{}
l := len(p.Columns)
@ -199,12 +217,31 @@ func (p *Params) ConvertToGormConditions() (string, []interface{}, error) {
}
field := p.Columns[0].Name
for i, column := range p.Columns {
if err := column.checkValid(); err != nil {
o := rulerOptions{}
o.apply(opts...)
if o.validateFn != nil {
err := o.validateFn(p.Columns)
if err != nil {
return "", nil, err
}
}
symbol, err := column.convert()
for i, column := range p.Columns {
// check name
if column.Name == "" || (o.whitelistNames != nil && !o.whitelistNames[column.Name]) {
return "", nil, fmt.Errorf("field name '%s' is not allowed", column.Name)
}
// check value
if column.Value == nil {
v := expMap[strings.ToLower(column.Exp)]
if v != " IS NULL " && v != " IS NOT NULL " {
return "", nil, fmt.Errorf("field 'value' cannot be nil")
}
}
// check exp
symbol, err := column.checkExp()
if err != nil {
return "", nil, err
}
@ -256,9 +293,9 @@ type Conditions struct {
// ConvertToGorm conversion to gorm-compliant parameters based on the Columns parameter
// ignore the logical type of the last column, whether it is a one-column or multi-column query
func (c *Conditions) ConvertToGorm() (string, []interface{}, error) {
func (c *Conditions) ConvertToGorm(opts ...RulerOption) (string, []interface{}, error) {
p := &Params{Columns: c.Columns}
return p.ConvertToGormConditions()
return p.ConvertToGormConditions(opts...)
}
// CheckValid check valid

View File

@ -1,6 +1,7 @@
package query
import (
"errors"
"reflect"
"strings"
"testing"
@ -617,6 +618,38 @@ func TestParams_ConvertToGormConditions(t *testing.T) {
}
}
func TestConditions_ConvertToGormConditions_Error(t *testing.T) {
p := &Params{
Limit: 10,
Columns: []Column{
{
Name: "age",
Value: 10,
},
{
Name: "email",
Value: "foo@bar.com",
},
}}
whitelists := map[string]bool{"name": true, "age": true}
_, _, err := p.ConvertToGormConditions(WithWhitelistNames(whitelists))
t.Log(err)
assert.Error(t, err)
fn := func(columns []Column) error {
for _, col := range columns {
if col.Value == "foo@bar.com" {
return errors.New("'foo@bar.com' is not allowed")
}
}
return nil
}
_, _, err = p.ConvertToGormConditions(WithValidateFn(fn))
t.Log(err)
assert.Error(t, err)
}
func TestConditions_ConvertToGorm(t *testing.T) {
c := Conditions{
Columns: []Column{
@ -636,3 +669,33 @@ func TestConditions_ConvertToGorm(t *testing.T) {
assert.Equal(t, "name = ? AND gender = ?", str)
assert.Equal(t, len(values), 2)
}
func TestConditions_ConvertToGorm_Error(t *testing.T) {
c := Conditions{Columns: []Column{
{
Name: "age",
Value: 10,
},
{
Name: "email",
Value: "foo@bar.com",
},
}}
whitelists := map[string]bool{"name": true, "age": true}
_, _, err := c.ConvertToGorm(WithWhitelistNames(whitelists))
t.Log(err)
assert.Error(t, err)
fn := func(columns []Column) error {
for _, col := range columns {
if col.Value == "foo@bar.com" {
return errors.New("'foo@bar.com' is not allowed")
}
}
return nil
}
_, _, err = c.ConvertToGorm(WithValidateFn(fn))
t.Log(err)
assert.Error(t, err)
}