mirror of https://github.com/zhufuyi/sponge
fix: whitelist names of embed models
This commit is contained in:
parent
d62a16ceb7
commit
7c90a8da10
|
@ -780,9 +780,60 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
|
|||
structCode = strings.ReplaceAll(structCode, `bson:"id" json:"id"`, `bson:"_id" json:"id"`)
|
||||
}
|
||||
|
||||
tableColumnsCode, err := getTableColumnsCode(data, isEmbed)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
structCode += string(tableColumnsCode)
|
||||
|
||||
return structCode, newImportPaths, nil
|
||||
}
|
||||
|
||||
func getTableColumnsCode(data tmplData, isEmbed bool) ([]byte, error) {
|
||||
if data.DBDriver == DBDriverMongodb {
|
||||
for _, field := range data.Fields {
|
||||
if field.Name == "ID" {
|
||||
field.ColName = "_id"
|
||||
data.Fields = append(data.Fields, field)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if isEmbed {
|
||||
var fields = []tmplField{
|
||||
{
|
||||
ColName: "id",
|
||||
},
|
||||
{
|
||||
ColName: "created_at",
|
||||
},
|
||||
{
|
||||
ColName: "updated_at",
|
||||
},
|
||||
{
|
||||
ColName: "deleted_at",
|
||||
},
|
||||
}
|
||||
for _, field := range data.Fields {
|
||||
if field.Name == __mysqlModel__ {
|
||||
continue
|
||||
}
|
||||
fields = append(fields, field)
|
||||
}
|
||||
data.Fields = fields
|
||||
}
|
||||
builder := strings.Builder{}
|
||||
err := tableColumnsTmpl.Execute(&builder, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tableColumnsTmpl.Execute error: %v", err)
|
||||
}
|
||||
code, err := format.Source([]byte(builder.String()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tableColumnsTmpl format.Source error: %v", err)
|
||||
}
|
||||
return code, err
|
||||
}
|
||||
|
||||
func getModelCode(data modelCodes) (string, error) {
|
||||
builder := strings.Builder{}
|
||||
err := modelTmpl.Execute(&builder, data)
|
||||
|
|
|
@ -24,6 +24,16 @@ func (m *{{.TableName}}) TableName() string {
|
|||
return "{{.RawTableName}}"
|
||||
}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
tableColumnsTmpl *template.Template
|
||||
tableColumnsTmplRaw = `
|
||||
// {{.TableName}}ColumnNames Whitelist for custom query fields to prevent sql injection attacks
|
||||
var {{.TableName}}ColumnNames = map[string]bool{
|
||||
{{- range .Fields}}
|
||||
"{{.ColName}}": true,
|
||||
{{- end}}
|
||||
}
|
||||
`
|
||||
|
||||
modelTmpl *template.Template
|
||||
|
@ -730,6 +740,10 @@ func initTemplate() {
|
|||
if err != nil {
|
||||
errSum = errors.Wrap(err, "modelStructTmplRaw")
|
||||
}
|
||||
tableColumnsTmpl, err = template.New("tableColumns").Parse(tableColumnsTmplRaw)
|
||||
if err != nil {
|
||||
errSum = errors.Wrap(err, "tableColumnsTmplRaw")
|
||||
}
|
||||
modelTmpl, err = template.New("goFile").Parse(modelTmplRaw)
|
||||
if err != nil {
|
||||
errSum = errors.Wrap(errSum, "modelTmplRaw:"+err.Error())
|
||||
|
|
Loading…
Reference in New Issue