fix: whitelist names of embed models

This commit is contained in:
zhuyasen 2025-05-18 23:20:56 +08:00
parent d62a16ceb7
commit 7c90a8da10
2 changed files with 65 additions and 0 deletions

View File

@ -780,9 +780,60 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
structCode = strings.ReplaceAll(structCode, `bson:"id" json:"id"`, `bson:"_id" json:"id"`)
}
tableColumnsCode, err := getTableColumnsCode(data, isEmbed)
if err != nil {
return "", nil, err
}
structCode += string(tableColumnsCode)
return structCode, newImportPaths, nil
}
func getTableColumnsCode(data tmplData, isEmbed bool) ([]byte, error) {
if data.DBDriver == DBDriverMongodb {
for _, field := range data.Fields {
if field.Name == "ID" {
field.ColName = "_id"
data.Fields = append(data.Fields, field)
break
}
}
}
if isEmbed {
var fields = []tmplField{
{
ColName: "id",
},
{
ColName: "created_at",
},
{
ColName: "updated_at",
},
{
ColName: "deleted_at",
},
}
for _, field := range data.Fields {
if field.Name == __mysqlModel__ {
continue
}
fields = append(fields, field)
}
data.Fields = fields
}
builder := strings.Builder{}
err := tableColumnsTmpl.Execute(&builder, data)
if err != nil {
return nil, fmt.Errorf("tableColumnsTmpl.Execute error: %v", err)
}
code, err := format.Source([]byte(builder.String()))
if err != nil {
return nil, fmt.Errorf("tableColumnsTmpl format.Source error: %v", err)
}
return code, err
}
func getModelCode(data modelCodes) (string, error) {
builder := strings.Builder{}
err := modelTmpl.Execute(&builder, data)

View File

@ -24,6 +24,16 @@ func (m *{{.TableName}}) TableName() string {
return "{{.RawTableName}}"
}
{{end}}
`
tableColumnsTmpl *template.Template
tableColumnsTmplRaw = `
// {{.TableName}}ColumnNames Whitelist for custom query fields to prevent sql injection attacks
var {{.TableName}}ColumnNames = map[string]bool{
{{- range .Fields}}
"{{.ColName}}": true,
{{- end}}
}
`
modelTmpl *template.Template
@ -730,6 +740,10 @@ func initTemplate() {
if err != nil {
errSum = errors.Wrap(err, "modelStructTmplRaw")
}
tableColumnsTmpl, err = template.New("tableColumns").Parse(tableColumnsTmplRaw)
if err != nil {
errSum = errors.Wrap(err, "tableColumnsTmplRaw")
}
modelTmpl, err = template.New("goFile").Parse(modelTmplRaw)
if err != nil {
errSum = errors.Wrap(errSum, "modelTmplRaw:"+err.Error())