mirror of https://github.com/zhufuyi/sponge
fix: whitelist names of embed models
This commit is contained in:
parent
d62a16ceb7
commit
7c90a8da10
|
@ -780,9 +780,60 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
|
||||||
structCode = strings.ReplaceAll(structCode, `bson:"id" json:"id"`, `bson:"_id" json:"id"`)
|
structCode = strings.ReplaceAll(structCode, `bson:"id" json:"id"`, `bson:"_id" json:"id"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tableColumnsCode, err := getTableColumnsCode(data, isEmbed)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
structCode += string(tableColumnsCode)
|
||||||
|
|
||||||
return structCode, newImportPaths, nil
|
return structCode, newImportPaths, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTableColumnsCode(data tmplData, isEmbed bool) ([]byte, error) {
|
||||||
|
if data.DBDriver == DBDriverMongodb {
|
||||||
|
for _, field := range data.Fields {
|
||||||
|
if field.Name == "ID" {
|
||||||
|
field.ColName = "_id"
|
||||||
|
data.Fields = append(data.Fields, field)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isEmbed {
|
||||||
|
var fields = []tmplField{
|
||||||
|
{
|
||||||
|
ColName: "id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ColName: "created_at",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ColName: "updated_at",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ColName: "deleted_at",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, field := range data.Fields {
|
||||||
|
if field.Name == __mysqlModel__ {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields = append(fields, field)
|
||||||
|
}
|
||||||
|
data.Fields = fields
|
||||||
|
}
|
||||||
|
builder := strings.Builder{}
|
||||||
|
err := tableColumnsTmpl.Execute(&builder, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tableColumnsTmpl.Execute error: %v", err)
|
||||||
|
}
|
||||||
|
code, err := format.Source([]byte(builder.String()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tableColumnsTmpl format.Source error: %v", err)
|
||||||
|
}
|
||||||
|
return code, err
|
||||||
|
}
|
||||||
|
|
||||||
func getModelCode(data modelCodes) (string, error) {
|
func getModelCode(data modelCodes) (string, error) {
|
||||||
builder := strings.Builder{}
|
builder := strings.Builder{}
|
||||||
err := modelTmpl.Execute(&builder, data)
|
err := modelTmpl.Execute(&builder, data)
|
||||||
|
|
|
@ -24,6 +24,16 @@ func (m *{{.TableName}}) TableName() string {
|
||||||
return "{{.RawTableName}}"
|
return "{{.RawTableName}}"
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
|
`
|
||||||
|
|
||||||
|
tableColumnsTmpl *template.Template
|
||||||
|
tableColumnsTmplRaw = `
|
||||||
|
// {{.TableName}}ColumnNames Whitelist for custom query fields to prevent sql injection attacks
|
||||||
|
var {{.TableName}}ColumnNames = map[string]bool{
|
||||||
|
{{- range .Fields}}
|
||||||
|
"{{.ColName}}": true,
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
modelTmpl *template.Template
|
modelTmpl *template.Template
|
||||||
|
@ -730,6 +740,10 @@ func initTemplate() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errSum = errors.Wrap(err, "modelStructTmplRaw")
|
errSum = errors.Wrap(err, "modelStructTmplRaw")
|
||||||
}
|
}
|
||||||
|
tableColumnsTmpl, err = template.New("tableColumns").Parse(tableColumnsTmplRaw)
|
||||||
|
if err != nil {
|
||||||
|
errSum = errors.Wrap(err, "tableColumnsTmplRaw")
|
||||||
|
}
|
||||||
modelTmpl, err = template.New("goFile").Parse(modelTmplRaw)
|
modelTmpl, err = template.New("goFile").Parse(modelTmplRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errSum = errors.Wrap(errSum, "modelTmplRaw:"+err.Error())
|
errSum = errors.Wrap(errSum, "modelTmplRaw:"+err.Error())
|
||||||
|
|
Loading…
Reference in New Issue