sponge/pkg/sql2code/parser/parser.go

1289 lines
37 KiB
Go

// Package parser parses SQL CREATE TABLE statements into Go structures
// and generates code from them using templates.
package parser
import (
"bufio"
"bytes"
"errors"
"fmt"
"go/format"
"sort"
"strings"
"text/template"
"github.com/jinzhu/inflection"
"github.com/zhufuyi/sqlparser/ast"
"github.com/zhufuyi/sqlparser/dependency/mysql"
"github.com/zhufuyi/sqlparser/dependency/types"
"github.com/zhufuyi/sqlparser/parser"
)
const (
	// TableName is the ParseSQL result key holding the comma-separated,
	// camel-case names of all parsed tables.
	TableName = "__table_name__"
	// CodeTypeModel is the ParseSQL result key for the model code.
	CodeTypeModel = "model"
	// CodeTypeJSON is the ParseSQL result key for the model json code.
	CodeTypeJSON = "json"
	// CodeTypeDAO is the ParseSQL result key for the update fields (dao) code.
	CodeTypeDAO = "dao"
	// CodeTypeHandler is the ParseSQL result key for the handler request and response code.
	CodeTypeHandler = "handler"
	// CodeTypeProto is the ParseSQL result key for the proto file code.
	CodeTypeProto = "proto"
	// CodeTypeService is the ParseSQL result key for the grpc service code.
	CodeTypeService = "service"
	// CodeTypeCrudInfo is the ParseSQL result key for the crud info json data
	// of all tables, joined with " |||| ".
	CodeTypeCrudInfo = "crud_info"
	// CodeTypeTableInfo is the ParseSQL result key for the table info json data
	// of all tables, joined with " |||| ".
	CodeTypeTableInfo = "table_info"
	// DBDriverMysql mysql driver
	DBDriverMysql = "mysql"
	// DBDriverPostgresql postgresql driver
	DBDriverPostgresql = "postgresql"
	// DBDriverTidb tidb driver
	DBDriverTidb = "tidb"
	// DBDriverSqlite sqlite driver
	DBDriverSqlite = "sqlite"
	// DBDriverMongodb mongodb driver
	DBDriverMongodb = "mongodb"

	// Rewrite types (see rewriterField) and the import paths they need.
	// jsonTypeName replaces the Go type of JSON columns.
	jsonTypeName = "datatypes.JSON"
	jsonPkgPath  = "gorm.io/datatypes"
	// boolTypeName replaces bit(1) columns, boolTypeTinyName tinyint(1) columns.
	boolTypeName     = "sgorm.Bool"
	boolTypeTinyName = "sgorm.TinyBool"
	boolPkgPath      = "github.com/go-dev-frame/sponge/pkg/sgorm"
	// decimalTypeName replaces DECIMAL columns.
	decimalTypeName = "decimal.Decimal"
	decimalPkgPath  = "github.com/shopspring/decimal"
	// unknownCustomType is the Go type name mysqlToGoType returns for column
	// types it does not support.
	unknownCustomType = "UnknownCustomType"
)
// Codes groups generated code snippets by kind.
// NOTE(review): not referenced in the visible code — confirm it is still used.
type Codes struct {
	Model         []string // model code
	UpdateFields  []string // update fields code
	ModelJSON     []string // model json code
	HandlerStruct []string // handler request and response code
}
// modelCodes is the data rendered by modelTmpl in getModelCode to build the
// combined model file for all tables.
type modelCodes struct {
	Package    string   // package name of the generated file
	ImportPath []string // import paths, de-duplicated and sorted by ParseSQL
	StructCode []string // model struct code of each table, in statement order
}
// ParseSQL parses sql and generates code for every CREATE TABLE statement in
// it; other statements are ignored.
//
// The returned map is keyed by the CodeType* constants and TableName. Each
// value combines the output of all tables in statement order: code fragments
// are separated by a blank line, crud/table info by " |||| " and table names
// by ", ". The model code is a single formatted file whose imports are the
// sorted union of the imports needed by every table.
func ParseSQL(sql string, options ...Option) (map[string]string, error) {
	initTemplate()
	initCommonTemplate()
	opt := parseOption(options)

	stmts, err := parser.New().Parse(sql, opt.Charset, opt.Collation)
	if err != nil {
		return nil, err
	}

	// Generate the code of each table first; joining happens afterwards.
	texts := make([]*codeText, 0, len(stmts))
	tableNames := make([]string, 0, len(stmts))
	for _, stmt := range stmts {
		createStmt, isCreate := stmt.(*ast.CreateTableStmt)
		if !isCreate {
			continue
		}
		text, makeErr := makeCode(createStmt, opt)
		if makeErr != nil {
			return nil, makeErr
		}
		texts = append(texts, text)
		tableNames = append(tableNames, toCamel(createStmt.Table.Name.String()))
	}

	// collect projects one fragment out of every table's codeText.
	collect := func(pick func(c *codeText) string) []string {
		picked := make([]string, 0, len(texts))
		for _, c := range texts {
			picked = append(picked, pick(c))
		}
		return picked
	}

	// Union of import paths, sorted for deterministic output.
	pathSet := make(map[string]struct{})
	for _, c := range texts {
		for _, p := range c.importPaths {
			pathSet[p] = struct{}{}
		}
	}
	sortedPaths := make([]string, 0, len(pathSet))
	for p := range pathSet {
		sortedPaths = append(sortedPaths, p)
	}
	sort.Strings(sortedPaths)

	modelCode, err := getModelCode(modelCodes{
		Package:    opt.Package,
		ImportPath: sortedPaths,
		StructCode: collect(func(c *codeText) string { return c.modelStruct }),
	})
	if err != nil {
		return nil, err
	}

	const codeSep, infoSep = "\n\n", " |||| "
	return map[string]string{
		CodeTypeModel:     modelCode,
		CodeTypeJSON:      strings.Join(collect(func(c *codeText) string { return c.modelJSON }), codeSep),
		CodeTypeDAO:       strings.Join(collect(func(c *codeText) string { return c.updateFields }), codeSep),
		CodeTypeHandler:   strings.Join(collect(func(c *codeText) string { return c.handlerStruct }), codeSep),
		CodeTypeProto:     strings.Join(collect(func(c *codeText) string { return c.protoFile }), codeSep),
		CodeTypeService:   strings.Join(collect(func(c *codeText) string { return c.serviceStruct }), codeSep),
		TableName:         strings.Join(tableNames, ", "),
		CodeTypeCrudInfo:  strings.Join(collect(func(c *codeText) string { return c.crudInfo }), infoSep),
		CodeTypeTableInfo: strings.Join(collect(func(c *codeText) string { return string(c.tableInfo) }), infoSep),
	}, nil
}
// tmplData is the per-table data passed to the code templates.
type tmplData struct {
	TableNamePrefix string     // configured table name prefix (opt.TablePrefix)
	RawTableName    string     // raw table name, example: foo_bar
	TableName       string     // table name in camel case, example: FooBar
	TName           string     // table name first letter in lower case, example: fooBar
	NameFunc        bool       // set by makeCode; presumably toggles a TableName() method in the template — confirm
	Fields          []tmplField
	Comment         string     // table COMMENT option
	SubStructs      string     // sub structs for model
	ProtoSubStructs string     // sub structs for protobuf
	DBDriver        string     // one of the DBDriver* constants
	CrudInfo        *CrudInfo
}
// tmplField describes one table column as seen by the code templates.
type tmplField struct {
	IsPrimaryKey  bool   // is primary key
	ColName       string // table column name
	Name          string // convert to camel case
	GoType        string // convert to go type
	Tag           string // struct tag, e.g. gorm/bson and json tags
	Comment       string // column COMMENT option
	JSONName      string // json field name (snake or camel case)
	DBDriver      string // one of the DBDriver* constants
	rewriterField *rewriterField // optional replacement Go type, nil if none
}
// rewriterField describes an alternative Go type (and its import path) that
// replaces a column's plain Go type in some generated code, e.g.
// datatypes.JSON, decimal.Decimal or the sgorm bool types.
type rewriterField struct {
	goType string // rewritten Go type name
	path   string // import path that provides goType
}
// isCommonStyle reports whether the "common" code style is used for this
// table: never for mongodb or embedded models, otherwise only when the
// primary key is not the conventional id column (per CrudInfo).
func (d tmplData) isCommonStyle(isEmbed bool) bool {
	if d.DBDriver == DBDriverMongodb || isEmbed {
		return false
	}
	return !d.CrudInfo.isIDPrimaryKey()
}
// ConditionZero returns the Go condition suffix that is true when the field
// holds a non-zero value, used in dao template code (for example ` != 0`,
// ` != ""` or `.IsZero() == false`).
//
// Rewritten SQL types (sgorm bools, datatypes.JSON, decimal.Decimal) take
// precedence, then the plain Go type, then mongodb-specific types. An empty
// GoType yields a placeholder comparison.
func (t tmplField) ConditionZero() string {
	switch t.DBDriver {
	case DBDriverMysql, DBDriverPostgresql, DBDriverTidb:
		if rf := t.rewriterField; rf != nil {
			switch rf.goType {
			case boolTypeName, boolTypeTinyName:
				return ` != nil` //nolint
			case jsonTypeName:
				return `.String() != ""`
			case decimalTypeName:
				return `.IsZero() == false`
			}
		}
	}

	switch t.GoType {
	case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32", //nolint
		"sql.NullInt32", "sql.NullInt64", "sql.NullFloat64": //nolint
		return ` != 0`
	case "string", "sql.NullString": //nolint
		return ` != ""`
	case "time.Time", "*time.Time", "sql.NullTime": //nolint
		return `.IsZero() == false`
	case "[]byte", "[]string", "[]int", "interface{}": //nolint
		return ` != nil` //nolint
	case "bool": //nolint
		return ` != false`
	}

	if t.DBDriver == DBDriverMongodb {
		switch {
		case t.GoType == goTypeOID:
			return ` != primitive.NilObjectID`
		case t.GoType == "*"+t.Name, strings.Contains(t.GoType, "[]"):
			return ` != nil` //nolint
		}
	}

	if t.GoType == "" {
		return ` != "unknown_zero_value"`
	}
	return ` != ` + t.GoType
}
// GoZero returns an assignment suffix holding a placeholder value for the
// field, used in model to json template code (for example `= 0` or
// `= "string"`). getModelJSONCode later turns the " =" of each rendered line
// into ":" to obtain JSON-like key/value pairs, so every branch must produce
// an assignment.
func (t tmplField) GoZero() string {
	if t.DBDriver == DBDriverMysql || t.DBDriver == DBDriverPostgresql || t.DBDriver == DBDriverTidb {
		if t.rewriterField != nil {
			switch t.rewriterField.goType {
			case jsonTypeName, decimalTypeName:
				return ` = "string"`
			case boolTypeName, boolTypeTinyName:
				return `= false`
			}
		}
	}
	switch t.GoType {
	case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32",
		"sql.NullInt32", "sql.NullInt64", "sql.NullFloat64":
		return `= 0`
	case "string", "sql.NullString":
		return `= "string"`
	case "time.Time", "*time.Time", "sql.NullTime":
		return `= "0000-01-00T00:00:00.000+08:00"`
	case "[]byte", "[]string", "[]int", "interface{}": //nolint
		return `= nil` //nolint
	case "bool": //nolint
		return `= false`
	}
	if t.DBDriver == DBDriverMongodb {
		if t.GoType == goTypeOID {
			return `= primitive.NilObjectID`
		}
		if t.GoType == "*"+t.Name {
			return `= nil`
		}
		if strings.Contains(t.GoType, "[]") {
			return `= nil`
		}
	}
	if t.GoType == "" {
		// Previously `!=` (copied from ConditionZero): that is not rewritten
		// to ":" by getModelJSONCode and leaked a comparison into the JSON.
		return `= "unknown_zero_value"`
	}
	return `= ` + t.GoType
}
// GoTypeZero returns the zero-value literal of the field's type as exposed
// through protobuf, used in service template code: numbers give `0`,
// strings, times and rewritten JSON/decimal give `""`, booleans give `false`
// and slices give `nil`. An empty GoType yields a placeholder literal, and any
// other type is returned verbatim.
func (t tmplField) GoTypeZero() string {
	switch t.DBDriver {
	case DBDriverMysql, DBDriverPostgresql, DBDriverTidb:
		if rf := t.rewriterField; rf != nil {
			switch rf.goType {
			case jsonTypeName, decimalTypeName:
				return `""` //nolint
			case boolTypeName, boolTypeTinyName:
				return `false`
			}
		}
	}

	switch t.GoType {
	case "int8", "int16", "int32", "int64", "int", "uint8", "uint16", "uint32", "uint64", "uint", "float64", "float32",
		"sql.NullInt32", "sql.NullInt64", "sql.NullFloat64":
		return `0`
	case "string", "sql.NullString", jsonTypeName,
		"time.Time", "*time.Time", "sql.NullTime":
		return `""`
	case "[]byte", "[]string", "[]int", "interface{}": //nolint
		return `nil` //nolint
	case "bool": //nolint
		return `false`
	}

	if t.DBDriver == DBDriverMongodb {
		switch {
		case t.GoType == goTypeOID:
			return `primitive.NilObjectID`
		case t.GoType == "*"+t.Name, strings.Contains(t.GoType, "[]"):
			return `nil` //nolint
		}
	}

	if t.GoType == "" {
		return `"unknown_zero_value"`
	}
	return t.GoType
}
// AddOne returns i plus one, e.g. to turn a 0-based template range index
// into a 1-based counter.
func (t tmplField) AddOne(i int) int {
	next := i
	next++
	return next
}
// AddOneWithTag returns the 1-based field number for index i. For the id
// column it also appends protobuf validation and uri tagger options: a
// string min_len rule for mongodb, otherwise a gt = 0 rule on t.GoType.
func (t tmplField) AddOneWithTag(i int) string {
	num := i + 1
	if t.ColName != "id" {
		return fmt.Sprintf("%d", num)
	}
	if t.DBDriver == DBDriverMongodb {
		return fmt.Sprintf(`%d [(validate.rules).string.min_len = 6, (tagger.tags) = "uri:\"id\""]`, num)
	}
	return fmt.Sprintf(`%d [(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"id\""]`, num, t.GoType)
}
// AddOneWithTag2 returns the 1-based field number for index i. For a primary
// key or id column it also appends protobuf validation and uri tagger
// options keyed by t.JSONName: a min_len rule for string keys, otherwise a
// gt = 0 rule on t.GoType.
func (t tmplField) AddOneWithTag2(i int) string {
	num := i + 1
	if !t.IsPrimaryKey && t.ColName != "id" {
		return fmt.Sprintf("%d", num)
	}
	if t.GoType == "string" {
		return fmt.Sprintf(`%d [(validate.rules).string.min_len = 1, (tagger.tags) = "uri:\"%s\""]`, num, t.JSONName)
	}
	return fmt.Sprintf(`%d [(validate.rules).%s.gt = 0, (tagger.tags) = "uri:\"%s\""]`, num, t.GoType, t.JSONName)
}
// getProtoFieldName returns the JSON name of the first field that is a
// primary key or is named id, or "" if there is none.
func getProtoFieldName(fields []tmplField) string {
	for i := range fields {
		if f := &fields[i]; f.IsPrimaryKey || f.ColName == "id" {
			return f.JSONName
		}
	}
	return ""
}
// Placeholders used while rendering embedded models: getModelStructCode emits
// them as a field name and type, then swaps them for replaceFields' values
// after gofmt formatting.
const (
	__mysqlModel__ = "__mysqlModel__" //nolint
	__type__       = "__type__"       //nolint
)

// replaceFields maps each placeholder to its final text. getModelStructCode
// appends "2" to the model name for snake-case JSON (sgorm.Model2).
var replaceFields = map[string]string{
	__mysqlModel__: "sgorm.Model",
	__type__:       "",
}
// Column names that receive special handling in the generators.
const (
	columnID         = "id"
	_columnID        = "_id" // mongodb-style id column
	columnCreatedAt  = "created_at"
	columnUpdatedAt  = "updated_at"
	columnDeletedAt  = "deleted_at"
	columnMysqlModel = __mysqlModel__
)

// ignoreColumns lists columns that isIgnoreFields filters out unless a caller
// explicitly keeps them.
var ignoreColumns = map[string]struct{}{
	columnID:         {},
	columnCreatedAt:  {},
	columnUpdatedAt:  {},
	columnDeletedAt:  {},
	columnMysqlModel: {},
}
// isIgnoreFields reports whether colName is one of ignoreColumns. Columns
// listed in falseColumn are always kept (reported as not ignored).
func isIgnoreFields(colName string, falseColumn ...string) bool {
	kept := false
	for _, reserved := range falseColumn {
		if reserved == colName {
			kept = true
			break
		}
	}
	if kept {
		return false
	}
	_, ignored := ignoreColumns[colName]
	return ignored
}
// codeText holds every code fragment makeCode generates for a single table.
type codeText struct {
	importPaths   []string // import paths needed by modelStruct
	modelStruct   string   // model struct code
	modelJSON     string   // model json code
	updateFields  string   // update fields (dao) code
	handlerStruct string   // handler request and response structs
	protoFile     string   // proto file code
	serviceStruct string   // service struct code
	crudInfo      string   // output of CrudInfo.getCode
	tableInfo     []byte   // table info; the only field set for custom templates
}
// nolint
// makeCode converts one CREATE TABLE statement into all generated code
// fragments for that table.
//
// It derives the Go struct name (stripping opt.TablePrefix), turns each
// column into a tmplField carrying its struct tag and Go type, and renders
// the model, update fields, model JSON, handler, service and proto code.
// When opt.IsCustomTemplate is set only the table info is produced.
//
// It returns an error when the table has no columns or any template fails.
func makeCode(stmt *ast.CreateTableStmt, opt options) (*codeText, error) {
	importPath := make([]string, 0, 1)
	data := tmplData{
		TableNamePrefix: opt.TablePrefix,
		RawTableName:    stmt.Table.Name.String(),
		DBDriver:        opt.DBDriver,
	}

	// Strip the configured prefix from the Go name; NameFunc records that the
	// struct name no longer maps directly to the raw table name.
	tablePrefix := data.TableNamePrefix
	if tablePrefix != "" && strings.HasPrefix(data.RawTableName, tablePrefix) {
		data.NameFunc = true
		data.TableName = toCamel(data.RawTableName[len(tablePrefix):])
	} else {
		data.TableName = toCamel(data.RawTableName)
	}
	data.TName = firstLetterToLower(data.TableName)
	// NOTE(review): presumably mirrors gorm's pluralized default table naming —
	// a non-plural raw name (or ForceTableName) needs NameFunc as well.
	if opt.ForceTableName || data.RawTableName != inflection.Plural(data.RawTableName) {
		data.NameFunc = true
	}

	// mongodb: switch the package-level JSON tag naming style to match the option.
	switch opt.DBDriver {
	case DBDriverMongodb:
		if opt.JSONNamedType != 0 {
			SetJSONTagCamelCase()
		} else {
			SetJSONTagSnakeCase()
		}
	}

	// find table comment
	for _, o := range stmt.Options {
		if o.Tp == ast.TableOptionComment {
			data.Comment = o.StrValue
			break
		}
	}

	// Columns named by a table-level PRIMARY KEY constraint. Only the first
	// column of a composite key is recorded.
	isPrimaryKey := make(map[string]bool)
	for _, con := range stmt.Constraints {
		if con.Tp == ast.ConstraintPrimaryKey && len(con.Keys) > 0 {
			isPrimaryKey[con.Keys[0].Column.String()] = true
		}
		// TODO: foreign key support (ast.ConstraintForeignKey is ignored for now).
	}

	columnPrefix := opt.ColumnPrefix
	for _, col := range stmt.Cols {
		colName := col.Name.Name.String()
		goFieldName := colName
		if columnPrefix != "" && strings.HasPrefix(goFieldName, columnPrefix) {
			goFieldName = goFieldName[len(columnPrefix):]
		}
		jsonName := colName
		if opt.JSONNamedType == 0 { // snake case
			jsonName = customToSnake(jsonName)
		} else {
			jsonName = customToCamel(jsonName) // camel case
		}
		field := tmplField{
			Name:     toCamel(goFieldName),
			ColName:  colName,
			JSONName: jsonName,
		}
		tags := make([]string, 0, 4)

		// make GORM's tag (also reused as the bson tag for mongodb)
		gormTag := strings.Builder{}
		gormTag.WriteString("column:")
		gormTag.WriteString(colName)
		if opt.GormType {
			gormTag.WriteString(";type:")
			switch opt.DBDriver {
			case DBDriverMysql, DBDriverTidb, DBDriverSqlite:
				gormTag.WriteString(col.Tp.InfoSchemaStr())
			case DBDriverPostgresql:
				gormTag.WriteString(opt.FieldTypes[colName])
			}
		}
		if isPrimaryKey[colName] {
			field.IsPrimaryKey = true
			gormTag.WriteString(";primary_key")
		}

		isNotNull := false
		canNull := false
		for _, o := range col.Options {
			switch o.Tp {
			case ast.ColumnOptionPrimaryKey:
				// An inline "PRIMARY KEY" must flag the field exactly like a
				// table-level constraint; otherwise later consumers of
				// IsPrimaryKey (crud info, proto tags) miss this key.
				field.IsPrimaryKey = true
				if !isPrimaryKey[colName] {
					gormTag.WriteString(";primary_key")
					isPrimaryKey[colName] = true
				}
			case ast.ColumnOptionNotNull:
				isNotNull = true
			case ast.ColumnOptionAutoIncrement:
				gormTag.WriteString(";AUTO_INCREMENT")
			case ast.ColumnOptionDefaultValue:
				if value := getDefaultValue(o.Expr); value != "" {
					gormTag.WriteString(";default:")
					gormTag.WriteString(value)
				}
			case ast.ColumnOptionUniqKey:
				gormTag.WriteString(";unique")
			case ast.ColumnOptionNull:
				//gormTag.WriteString(";NULL")
				canNull = true
			case ast.ColumnOptionOnUpdate: // For Timestamp and Datetime only.
			case ast.ColumnOptionFulltext:
			case ast.ColumnOptionComment:
				field.Comment = o.Expr.GetDatum().GetString()
			default:
				//return "", nil, errors.Errorf(" unsupport option %d\n", o.Tp)
			}
		}

		field.DBDriver = opt.DBDriver
		switch opt.DBDriver {
		case DBDriverMongodb: // mongodb
			tags = append(tags, "bson", gormTag.String())
			if opt.JSONTag {
				if strings.ToLower(jsonName) == "_id" {
					jsonName = "id"
				}
				field.JSONName = jsonName
				tags = append(tags, "json", jsonName)
			}
			field.Tag = makeTagStr(tags)
			field.GoType = opt.FieldTypes[colName]
			if field.GoType == "time.Time" {
				importPath = append(importPath, "time")
			}
		default: // gorm
			if !isPrimaryKey[colName] && isNotNull {
				gormTag.WriteString(";not null")
			}
			tags = append(tags, "gorm", gormTag.String())
			if opt.JSONTag {
				tags = append(tags, "json", jsonName)
			}
			field.Tag = makeTagStr(tags)

			// get type in golang; only columns declared NULL use the null style
			nullStyle := opt.NullStyle
			if !canNull {
				nullStyle = NullDisable
			}
			goType, pkg, rrField := mysqlToGoType(col.Tp, nullStyle)
			if pkg != "" {
				importPath = append(importPath, pkg)
			}
			field.GoType = goType
			field.rewriterField = rrField
			if opt.DBDriver == DBDriverPostgresql {
				if opt.FieldTypes[colName] == "bool" {
					field.GoType = "bool" // rewritten type
				}
			}
		}
		data.Fields = append(data.Fields, field)
	}

	if v, ok := opt.FieldTypes[SubStructKey]; ok {
		data.SubStructs = v
	}
	if v, ok := opt.FieldTypes[ProtoSubStructKey]; ok {
		data.ProtoSubStructs = v
	}
	if len(data.Fields) == 0 {
		return nil, errors.New("no columns found in table " + data.TableName)
	}

	data.CrudInfo = newCrudInfo(data)
	data.CrudInfo.IsCommonType = data.isCommonStyle(opt.IsEmbed)

	if opt.IsCustomTemplate {
		tableInfo := newTableInfo(data)
		return &codeText{tableInfo: tableInfo.getCode()}, nil
	}

	modelStructCode, importPaths, err := getModelStructCode(data, importPath, opt.IsEmbed, opt.JSONNamedType)
	if err != nil {
		return nil, err
	}
	updateFieldsCode, err := getUpdateFieldsCode(data, opt.IsEmbed)
	if err != nil {
		return nil, err
	}
	modelJSONCode, err := getModelJSONCode(data)
	if err != nil {
		return nil, err
	}

	handlerStructCode := ""
	serviceStructCode := ""
	protoFileCode := ""
	if data.isCommonStyle(opt.IsEmbed) {
		handlerStructCode, err = getCommonHandlerStructCodes(data, opt.JSONNamedType)
		if err != nil {
			return nil, err
		}
		serviceStructCode, err = getCommonServiceStructCode(data)
		if err != nil {
			return nil, err
		}
		protoFileCode, err = getCommonProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
		if err != nil {
			return nil, err
		}
	} else {
		handlerStructCode, err = getHandlerStructCodes(data, opt.JSONNamedType)
		if err != nil {
			return nil, err
		}
		serviceStructCode, err = getServiceStructCode(data)
		if err != nil {
			return nil, err
		}
		protoFileCode, err = getProtoFileCode(data, opt.JSONNamedType, opt.IsWebProto, opt.IsExtendedAPI)
		if err != nil {
			return nil, err
		}
	}

	return &codeText{
		importPaths:   importPaths,
		modelStruct:   modelStructCode,
		modelJSON:     modelJSONCode,
		updateFields:  updateFieldsCode,
		handlerStruct: handlerStructCode,
		protoFile:     protoFileCode,
		serviceStruct: serviceStructCode,
		crudInfo:      data.CrudInfo.getCode(),
	}, nil
}
// nolint
// getModelStructCode renders the model struct of one table and returns the
// gofmt-formatted code together with the import paths it needs.
//
// With isEmbed, columns matched by isIgnoreFields (id and the timestamp
// columns) are dropped in favor of a placeholder embedded field. After
// formatting, that placeholder is rewritten to sgorm.Model (sgorm.Model2 for
// snake-case JSON). The "time" import is kept only if a remaining field still
// uses time.Time.
//
// Without isEmbed, types are adjusted per driver:
//   - mongodb ID becomes goTypeOID;
//   - other drivers turn time fields into *time.Time and force ID to uint64
//     (or CrudInfo.GoType in common style).
//
// For mysql/tidb/postgresql, rewritten JSON/decimal/bool columns become
// pointers to their rewrite type in both modes. SubStructs is appended, and
// for mongodb the gorm-style tag text is cleaned into bson tags.
func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonNamedType int) (string, []string, error) {
	// filter to ignore field fields
	var newFields = []tmplField{}
	var newImportPaths = []string{}
	if isEmbed {
		// placeholder for the embedded model, replaced after formatting below
		newFields = append(newFields, tmplField{
			Name:    __mysqlModel__,
			ColName: __mysqlModel__,
			GoType:  __type__,
			Tag:     `gorm:"embedded"`,
			Comment: "embed id and time\n",
		})
		isHaveTimeType := false
		for _, field := range data.Fields {
			// columns provided by the embedded model are skipped
			if isIgnoreFields(field.ColName) {
				continue
			}
			switch field.DBDriver {
			case DBDriverMysql, DBDriverTidb, DBDriverPostgresql:
				if field.rewriterField != nil {
					switch field.rewriterField.goType {
					//case jsonTypeName, decimalTypeName:
					//	field.GoType = field.rewriterField.goType
					//	importPaths = append(importPaths, field.rewriterField.path)
					case jsonTypeName, decimalTypeName, boolTypeName, boolTypeTinyName:
						field.GoType = "*" + field.rewriterField.goType
						importPaths = append(importPaths, field.rewriterField.path)
					}
				}
			}
			newFields = append(newFields, field)
			if strings.Contains(field.GoType, "time.Time") {
				isHaveTimeType = true
			}
		}
		data.Fields = newFields
		// filter time package name: drop the "time" import when no remaining
		// field uses time.Time
		if isHaveTimeType {
			newImportPaths = importPaths
		} else {
			for _, path := range importPaths {
				if path == "time" { //nolint
					continue
				}
				newImportPaths = append(newImportPaths, path)
			}
		}
		newImportPaths = append(newImportPaths, "github.com/go-dev-frame/sponge/pkg/sgorm")
	} else {
		for _, field := range data.Fields {
			switch field.DBDriver {
			case DBDriverMongodb:
				if field.Name == "ID" {
					field.GoType = goTypeOID
					importPaths = append(importPaths, "go.mongodb.org/mongo-driver/bson/primitive")
				}
			default:
				if strings.Contains(field.GoType, "time.Time") {
					field.GoType = "*time.Time"
				}
				// force conversion of ID field to uint64 type
				if field.Name == "ID" {
					field.GoType = "uint64"
					if data.isCommonStyle(isEmbed) {
						field.GoType = data.CrudInfo.GoType
					}
				}
				if field.DBDriver == DBDriverMysql || field.DBDriver == DBDriverPostgresql || field.DBDriver == DBDriverTidb {
					if field.rewriterField != nil {
						switch field.rewriterField.goType {
						//case jsonTypeName, decimalTypeName:
						//	field.GoType = field.rewriterField.goType
						//	importPaths = append(importPaths, field.rewriterField.path)
						case jsonTypeName, decimalTypeName, boolTypeName, boolTypeTinyName:
							field.GoType = "*" + field.rewriterField.goType
							importPaths = append(importPaths, field.rewriterField.path)
						}
					}
				}
			}
			newFields = append(newFields, field)
		}
		data.Fields = newFields
		newImportPaths = importPaths
	}

	builder := strings.Builder{}
	err := modelStructTmpl.Execute(&builder, data)
	if err != nil {
		return "", nil, fmt.Errorf("modelStructTmpl.Execute error: %v", err)
	}
	code, err := format.Source([]byte(builder.String()))
	if err != nil {
		return "", nil, fmt.Errorf("modelStructTmpl format.Source error: %v", err)
	}
	structCode := string(code)

	// restore the real embedded fields
	if isEmbed {
		gormEmbed := replaceFields[__mysqlModel__]
		if jsonNamedType == 0 { // snake case
			gormEmbed += "2" // sgorm.Model2
		}
		structCode = strings.ReplaceAll(structCode, __mysqlModel__, gormEmbed)
		structCode = strings.ReplaceAll(structCode, __type__, replaceFields[__type__])
	}
	if data.SubStructs != "" {
		structCode += data.SubStructs
	}
	// mongodb: turn the gorm-style tag text into plain bson tags and map id to _id
	if data.DBDriver == DBDriverMongodb {
		structCode = strings.ReplaceAll(structCode, `bson:"column:`, `bson:"`)
		structCode = strings.ReplaceAll(structCode, `;type:"`, `"`)
		structCode = strings.ReplaceAll(structCode, `;type:;primary_key`, ``)
		structCode = strings.ReplaceAll(structCode, `bson:"id" json:"id"`, `bson:"_id" json:"id"`)
	}
	return structCode, newImportPaths, nil
}
// getModelCode renders the combined model file (package, imports and all
// struct definitions) from data and returns it gofmt-formatted.
func getModelCode(data modelCodes) (string, error) {
	var rendered strings.Builder
	if err := modelTmpl.Execute(&rendered, data); err != nil {
		return "", err
	}
	formatted, err := format.Source([]byte(rendered.String()))
	if err != nil {
		return "", fmt.Errorf("getModelCode format.Source error: %v", err)
	}
	return string(formatted), nil
}
// getUpdateFieldsCode renders the update-fields (dao) code for one table.
//
// The id, _id and other ignored columns are skipped. For mysql/tidb/postgresql,
// rewritten JSON columns are treated as []byte. isEmbed is currently unused
// and kept for signature compatibility.
func getUpdateFieldsCode(data tmplData, isEmbed bool) (string, error) {
	_ = isEmbed

	kept := make([]tmplField, 0, len(data.Fields))
	for _, field := range data.Fields {
		if isIgnoreFields(field.ColName) || field.ColName == columnID || field.ColName == _columnID {
			continue
		}
		isSQLDriver := field.DBDriver == DBDriverMysql || field.DBDriver == DBDriverTidb || field.DBDriver == DBDriverPostgresql
		if isSQLDriver && field.rewriterField != nil && field.rewriterField.goType == jsonTypeName {
			field.GoType = "[]byte"
		}
		kept = append(kept, field)
	}
	data.Fields = kept

	var buf bytes.Buffer
	if err := updateFieldTmpl.Execute(&buf, data); err != nil {
		return "", err
	}
	return buf.String(), nil
}
// getHandlerStructCodes renders the handler create, update and detail
// structs for one table and returns them concatenated in that order.
//
// JSON names are recomputed from the column name (jsonNamedType 0 = snake
// case, otherwise camel case), Go types go through getHandlerGoType, and for
// mongodb the ID becomes a string while sub-struct types are qualified with
// the model package.
func getHandlerStructCodes(data tmplData, jsonNamedType int) (string, error) {
	converted := make([]tmplField, 0, len(data.Fields))
	for _, field := range data.Fields {
		if field.DBDriver == DBDriverMongodb {
			if field.Name == "ID" {
				field.GoType = "string"
			}
			if field.GoType == "*"+field.Name {
				field.GoType = "*model." + field.Name
			}
			if strings.Contains(field.GoType, "[]*") {
				field.GoType = "[]*model." + strings.ReplaceAll(field.GoType, "[]*", "")
			}
		}
		if jsonNamedType == 0 {
			field.JSONName = customToSnake(field.ColName) // snake case
		} else {
			field.JSONName = customToCamel(field.ColName) // camel case (default)
		}
		field.GoType = getHandlerGoType(&field)
		converted = append(converted, field)
	}
	data.Fields = converted

	// Each struct keeps a different set of otherwise-ignored columns.
	parts := []struct {
		name     string
		tmpl     *template.Template
		reserved []string
	}{
		{"handlerCreateStructTmpl", handlerCreateStructTmpl, nil},
		{"handlerUpdateStructTmpl", handlerUpdateStructTmpl, []string{columnID}},
		{"handlerDetailStructTmpl", handlerDetailStructTmpl, []string{columnID, columnCreatedAt, columnUpdatedAt}},
	}
	var out strings.Builder
	for _, part := range parts {
		code, err := tmplExecuteWithFilter(data, part.tmpl, part.reserved...)
		if err != nil {
			return "", fmt.Errorf("%s error: %v", part.name, err)
		}
		out.WriteString(code)
	}
	return out.String(), nil
}
// tmplExecuteWithFilter renders tmpl with a copy of data whose fields exclude
// the columns matched by isIgnoreFields, except those listed in
// reservedColumns. For mongodb, a field named id (any case) is rendered with
// Go type string.
func tmplExecuteWithFilter(data tmplData, tmpl *template.Template, reservedColumns ...string) (string, error) {
	filtered := make([]tmplField, 0, len(data.Fields))
	for _, field := range data.Fields {
		if isIgnoreFields(field.ColName, reservedColumns...) {
			continue
		}
		if field.DBDriver == DBDriverMongodb && strings.ToLower(field.Name) == "id" {
			field.GoType = "string"
		}
		filtered = append(filtered, field)
	}
	data.Fields = filtered

	var out strings.Builder
	if err := tmpl.Execute(&out, data); err != nil {
		return "", fmt.Errorf("tmpl.Execute error: %v", err)
	}
	return out.String(), nil
}
// getModelJSONCode renders the model JSON sample for one table. The template
// output is gofmt-formatted, each " =" is turned into ":" and commas are
// added between entries by addCommaToJSON.
func getModelJSONCode(data tmplData) (string, error) {
	var raw strings.Builder
	if err := modelJSONTmpl.Execute(&raw, data); err != nil {
		return "", err
	}
	formatted, err := format.Source([]byte(raw.String()))
	if err != nil {
		return "", fmt.Errorf("getModelJSONCode format.Source error: %v", err)
	}
	asPairs := strings.ReplaceAll(string(formatted), " =", ":")
	return addCommaToJSON(asPairs), nil
}
// getProtoFileCode renders the proto file for one table.
//
// The file template is chosen by isWebProto and isExtendedAPI. The create,
// update and detail message bodies are rendered separately and substituted
// for their placeholder comments, time types become int64, and the id
// message fields are adapted to the driver via adaptedDbType. For non-web
// protos the uri tagger option is stripped from the update message.
func getProtoFileCode(data tmplData, jsonNamedType int, isWebProto bool, isExtendedAPI bool) (string, error) {
	data.Fields = goTypeToProto(data.Fields, jsonNamedType, false)

	render := protoFileSimpleTmpl.Execute
	switch {
	case isWebProto && isExtendedAPI:
		render = protoFileForWebTmpl.Execute
	case isWebProto:
		render = protoFileForSimpleWebTmpl.Execute
	case isExtendedAPI:
		render = protoFileTmpl.Execute
	}
	var builder strings.Builder
	if err := render(&builder, data); err != nil {
		return "", err
	}
	code := builder.String()

	createCode, err := tmplExecuteWithFilter(data, protoMessageCreateTmpl)
	if err != nil {
		return "", fmt.Errorf("handle protoMessageCreateTmpl error: %v", err)
	}
	updateCode, err := tmplExecuteWithFilter(data, protoMessageUpdateTmpl, columnID)
	if err != nil {
		return "", fmt.Errorf("handle protoMessageUpdateTmpl error: %v", err)
	}
	if !isWebProto {
		updateCode = strings.ReplaceAll(updateCode, `, (tagger.tags) = "uri:\"id\""`, "")
	}
	detailCode, err := tmplExecuteWithFilter(data, protoMessageDetailTmpl, columnID, columnCreatedAt, columnUpdatedAt)
	if err != nil {
		return "", fmt.Errorf("handle protoMessageDetailTmpl error: %v", err)
	}

	// Applied one after another, in order ("*time.Time" before "time.Time").
	replacements := [][2]string{
		{"// protoMessageCreateCode", createCode},
		{"// protoMessageUpdateCode", updateCode},
		{"// protoMessageDetailCode", detailCode},
		{"*time.Time", "int64"},
		{"time.Time", "int64"},
	}
	for _, r := range replacements {
		code = strings.ReplaceAll(code, r[0], r[1])
	}
	return adaptedDbType(data, isWebProto, code), nil
}
// Placeholder comments in the rendered proto file; adaptedDbType replaces
// each one with the field declaration taken from one of the maps below.
const (
	createTableReplyFieldCodeMark        = "// createTableReplyFieldCode"
	deleteTableByIDRequestFieldCodeMark  = "// deleteTableByIDRequestFieldCode"
	deleteTableByIDsRequestFieldCodeMark = "// deleteTableByIDsRequestFieldCode"
	getTableByIDRequestFieldCodeMark     = "// getTableByIDRequestFieldCode"
	getTableByIDsRequestFieldCodeMark    = "// getTableByIDsRequestFieldCode"
	listTableByLastIDRequestFieldCodeMark = "// listTableByLastIDRequestFieldCode"
)

// grpcDefaultProtoMessageFieldCodes: uint64 ids for non-mongodb drivers, grpc protos.
var grpcDefaultProtoMessageFieldCodes = map[string]string{
	createTableReplyFieldCodeMark:        "uint64 id = 1;",
	deleteTableByIDRequestFieldCodeMark:  "uint64 id = 1 [(validate.rules).uint64.gt = 0];",
	deleteTableByIDsRequestFieldCodeMark: "repeated uint64 ids = 1 [(validate.rules).repeated.min_items = 1];",
	getTableByIDRequestFieldCodeMark:     "uint64 id = 1 [(validate.rules).uint64.gt = 0];",
	getTableByIDsRequestFieldCodeMark:    "repeated uint64 ids = 1 [(validate.rules).repeated.min_items = 1];",
	listTableByLastIDRequestFieldCodeMark: "uint64 lastID = 1; // last id",
}

// webDefaultProtoMessageFieldCodes: uint64 ids for non-mongodb drivers, web
// protos (adds uri/form tagger options).
var webDefaultProtoMessageFieldCodes = map[string]string{
	createTableReplyFieldCodeMark:        "uint64 id = 1;",
	deleteTableByIDRequestFieldCodeMark:  `uint64 id =1 [(validate.rules).uint64.gt = 0, (tagger.tags) = "uri:\"id\""];`,
	deleteTableByIDsRequestFieldCodeMark: "repeated uint64 ids = 1 [(validate.rules).repeated.min_items = 1];",
	getTableByIDRequestFieldCodeMark:     `uint64 id =1 [(validate.rules).uint64.gt = 0, (tagger.tags) = "uri:\"id\"" ];`,
	getTableByIDsRequestFieldCodeMark:    "repeated uint64 ids = 1 [(validate.rules).repeated.min_items = 1];",
	listTableByLastIDRequestFieldCodeMark: `uint64 lastID = 1 [(tagger.tags) = "form:\"lastID\""]; // last id`,
}

// grpcProtoMessageFieldCodes: string ids for mongodb, grpc protos.
var grpcProtoMessageFieldCodes = map[string]string{
	createTableReplyFieldCodeMark:        "string id = 1;",
	deleteTableByIDRequestFieldCodeMark:  "string id = 1 [(validate.rules).string.min_len = 6];",
	deleteTableByIDsRequestFieldCodeMark: "repeated string ids = 1 [(validate.rules).repeated.min_items = 1];",
	getTableByIDRequestFieldCodeMark:     "string id = 1 [(validate.rules).string.min_len = 6];",
	getTableByIDsRequestFieldCodeMark:    "repeated string ids = 1 [(validate.rules).repeated.min_items = 1];",
	listTableByLastIDRequestFieldCodeMark: "string lastID = 1; // last id",
}

// webProtoMessageFieldCodes: string ids for mongodb, web protos.
var webProtoMessageFieldCodes = map[string]string{
	createTableReplyFieldCodeMark:        "string id = 1;",
	deleteTableByIDRequestFieldCodeMark:  `string id =1 [(validate.rules).string.min_len = 6, (tagger.tags) = "uri:\"id\""];`,
	deleteTableByIDsRequestFieldCodeMark: "repeated string ids = 1 [(validate.rules).repeated.min_items = 1];",
	getTableByIDRequestFieldCodeMark:     `string id =1 [(validate.rules).string.min_len = 6, (tagger.tags) = "uri:\"id\"" ];`,
	getTableByIDsRequestFieldCodeMark:    "repeated string ids = 1 [(validate.rules).repeated.min_items = 1];",
	listTableByLastIDRequestFieldCodeMark: `string lastID = 1 [(tagger.tags) = "form:\"lastID\""]; // last id`,
}
// adaptedDbType fills the id-related placeholder fields of a proto file:
// string ids for mongodb, uint64 ids otherwise, with the web variants adding
// tagger options. ProtoSubStructs, if any, is appended after a newline.
func adaptedDbType(data tmplData, isWebProto bool, code string) string {
	var fieldCodes map[string]string
	switch isMongo := data.DBDriver == DBDriverMongodb; {
	case isMongo && isWebProto:
		fieldCodes = webProtoMessageFieldCodes
	case isMongo:
		fieldCodes = grpcProtoMessageFieldCodes
	case isWebProto:
		fieldCodes = webDefaultProtoMessageFieldCodes
	default:
		fieldCodes = grpcDefaultProtoMessageFieldCodes
	}
	code = replaceProtoMessageFieldCode(code, fieldCodes)

	if data.ProtoSubStructs != "" {
		code += "\n" + data.ProtoSubStructs
	}
	return code
}
// replaceProtoMessageFieldCode replaces every occurrence of each placeholder
// key of messageFields in code with its field declaration.
func replaceProtoMessageFieldCode(code string, messageFields map[string]string) string {
	result := code
	for mark, fieldCode := range messageFields {
		result = strings.ReplaceAll(result, mark, fieldCode)
	}
	return result
}
// getServiceStructCode renders the service code for one table and fills in
// the create and update struct bodies. Their "ID:" keys are renamed to "Id:",
// presumably to match protobuf-generated Go field names — confirm.
func getServiceStructCode(data tmplData) (string, error) {
	var builder strings.Builder
	if err := serviceStructTmpl.Execute(&builder, data); err != nil {
		return "", err
	}
	code := builder.String()

	createCode, err := tmplExecuteWithFilter(data, serviceCreateStructTmpl)
	if err != nil {
		return "", fmt.Errorf("handle serviceCreateStructTmpl error: %v", err)
	}
	updateCode, err := tmplExecuteWithFilter(data, serviceUpdateStructTmpl, columnID)
	if err != nil {
		return "", fmt.Errorf("handle serviceUpdateStructTmpl error: %v", err)
	}

	code = strings.ReplaceAll(code, "// serviceCreateStructCode", strings.ReplaceAll(createCode, "ID:", "Id:"))
	code = strings.ReplaceAll(code, "// serviceUpdateStructCode", strings.ReplaceAll(updateCode, "ID:", "Id:"))
	return code, nil
}
// addCommaToJSON appends a comma to every entry line of modelJSONCode except
// the last one, turning newline-separated key/value lines into valid JSON
// object members.
//
// Lines shorter than 5 bytes that contain '{' or '}' are treated as object
// delimiters: they are copied unchanged and are not counted as entries. The
// same predicate is used for counting and for emitting, so the last entry
// never gets a trailing comma. A final line without a trailing newline is
// kept rather than dropped.
func addCommaToJSON(modelJSONCode string) string {
	isBraceLine := func(line string) bool {
		return len(line) < 5 && strings.ContainsAny(line, "{}")
	}

	reader := bufio.NewReader(strings.NewReader(modelJSONCode))
	var lines []string
	entries := 0
	for {
		line, err := reader.ReadString('\n')
		// ReadString returns the unterminated tail together with io.EOF.
		if line != "" {
			lines = append(lines, line)
			if !isBraceLine(line) {
				entries++
			}
		}
		if err != nil {
			break
		}
	}

	var out strings.Builder
	out.Grow(len(modelJSONCode) + entries)
	for _, line := range lines {
		if isBraceLine(line) {
			out.WriteString(line)
			continue
		}
		entries--
		if entries == 0 { // last entry: no trailing comma
			out.WriteString(line)
			continue
		}
		content := strings.TrimSuffix(line, "\n")
		out.WriteString(content)
		out.WriteString(",")
		out.WriteString(line[len(content):])
	}
	return out.String()
}
// nolint
// mysqlToGoType maps a MySQL column type to the Go type used in generated
// code. It returns:
//   - name: the Go type name, or unknownCustomType for unsupported types;
//   - path: the import path the type needs ("" if none);
//   - rrField: an optional rewrite type (sgorm bool, datatypes.JSON or
//     decimal.Decimal) with its import path, nil if none.
//
// With style == NullInSql, database/sql Null* wrapper types are returned and
// rrField is always nil. With NullInPointer, the plain Go type is prefixed
// with "*". Unsigned integer columns map to unsigned Go types, except in
// NullInSql mode.
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string, rrField *rewriterField) {
	if style == NullInSql {
		path = "database/sql"
		switch colTp.Tp {
		case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeYear:
			// database/sql has no NullInt8: TINYINT shares NullInt32 with the
			// other small integer types (emitting sql.NullInt8 did not compile).
			name = "sql.NullInt32"
		case mysql.TypeLonglong, mysql.TypeDuration:
			name = "sql.NullInt64"
		case mysql.TypeFloat, mysql.TypeDouble:
			name = "sql.NullFloat64"
		case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString,
			mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
			name = "sql.NullString"
		case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate, mysql.TypeNewDate:
			name = "sql.NullTime"
		case mysql.TypeDecimal, mysql.TypeNewDecimal:
			name = "sql.NullString"
		case mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry:
			name = "sql.NullString"
		case mysql.TypeBit:
			name = "sql.NullBool"
		default:
			return unknownCustomType, "", nil
		}
		return name, path, rrField
	}

	switch colTp.Tp {
	case mysql.TypeTiny:
		if strings.ToLower(colTp.String()) == "tinyint(1)" {
			// tinyint(1) is conventionally a boolean
			name = "bool"
			rrField = &rewriterField{
				goType: boolTypeTinyName,
				path:   boolPkgPath,
			}
		} else if mysql.HasUnsignedFlag(colTp.Flag) {
			name = "uint"
		} else {
			name = "int"
		}
	case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeYear:
		if mysql.HasUnsignedFlag(colTp.Flag) {
			name = "uint"
		} else {
			name = "int"
		}
	case mysql.TypeLonglong, mysql.TypeDuration:
		if mysql.HasUnsignedFlag(colTp.Flag) {
			name = "uint64"
		} else {
			name = "int64"
		}
	case mysql.TypeFloat, mysql.TypeDouble:
		name = "float64"
	case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString,
		mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
		name = "string"
	case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate, mysql.TypeNewDate:
		path = "time" //nolint
		name = "time.Time"
	case mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry:
		name = "string"
	case mysql.TypeJSON:
		name = "string"
		rrField = &rewriterField{
			goType: jsonTypeName,
			path:   jsonPkgPath,
		}
	case mysql.TypeBit:
		if strings.ToLower(colTp.String()) == "bit(1)" {
			name = "bool"
			rrField = &rewriterField{
				goType: boolTypeName,
				path:   boolPkgPath,
			}
		} else {
			name = "[]byte"
		}
	case mysql.TypeDecimal, mysql.TypeNewDecimal:
		name = "string"
		rrField = &rewriterField{
			goType: decimalTypeName,
			path:   decimalPkgPath,
		}
	default:
		return unknownCustomType, "", nil
	}
	if style == NullInPointer {
		name = "*" + name
	}
	return name, path, rrField
}
// nolint
// goTypeToProto returns a copy of fields whose GoType is translated to the
// matching protobuf field type (e.g. int -> int32, uint -> uint32,
// time.Time -> string, float64 -> double, []string -> repeated string).
// JSONName is recomputed from ColName (jsonNameType 0 = snake case,
// otherwise camel case).
//
// For mongodb, a leading '*' is dropped, []*T becomes "repeated T" and
// []time.Time becomes "repeated string". For other drivers, a field named id
// (any case) becomes uint64 unless isCommonStyle is set. Finally, rewritten
// JSON/decimal columns map to string and rewritten bool columns to bool,
// overriding the earlier conversions.
func goTypeToProto(fields []tmplField, jsonNameType int, isCommonStyle bool) []tmplField {
	var newFields []tmplField
	for _, field := range fields {
		switch field.GoType {
		case "int":
			field.GoType = "int32"
		case "uint":
			field.GoType = "uint32"
		case "time.Time", "*time.Time":
			field.GoType = "string"
		case "float32":
			field.GoType = "float"
		case "float64":
			field.GoType = "double"
		case goTypeInts, "[]int64":
			field.GoType = "repeated int64"
		case "[]int32":
			field.GoType = "repeated int32"
		case "[]byte":
			field.GoType = "string"
		case goTypeStrings:
			field.GoType = "repeated string"
		case jsonTypeName:
			field.GoType = "string"
		}
		if field.DBDriver == DBDriverMongodb && field.GoType != "" {
			// pointer sub-struct -> message type; []*T -> repeated T
			if field.GoType[0] == '*' {
				field.GoType = field.GoType[1:]
			} else if strings.Contains(field.GoType, "[]*") {
				field.GoType = "repeated " + strings.ReplaceAll(field.GoType, "[]*", "")
			}
			if field.GoType == "[]time.Time" {
				field.GoType = "repeated string"
			}
		} else {
			if strings.ToLower(field.Name) == "id" && !isCommonStyle {
				field.GoType = "uint64"
			}
		}
		if jsonNameType == 0 { // snake case
			field.JSONName = customToSnake(field.ColName)
		} else {
			field.JSONName = customToCamel(field.ColName) // camel case (default)
		}
		// rewrite types take precedence over the conversions above
		if field.rewriterField != nil {
			switch field.rewriterField.goType {
			case jsonTypeName, decimalTypeName:
				field.GoType = "string"
			case boolTypeName, boolTypeTinyName:
				field.GoType = "bool"
			}
		}
		newFields = append(newFields, field)
	}
	return newFields
}
// makeTagStr builds a struct tag from alternating key/value entries, e.g.
// ["gorm", "column:id", "json", "id"] -> `gorm:"column:id" json:"id"`.
// A trailing key without a value is ignored; no entries yield "".
func makeTagStr(tags []string) string {
	pairs := make([]string, 0, len(tags)/2)
	for i := 0; i+1 < len(tags); i += 2 {
		pairs = append(pairs, tags[i]+`:"`+tags[i+1]+`"`)
	}
	return strings.Join(pairs, " ")
}
// getDefaultValue returns the textual DEFAULT value of a column option.
// A literal default is printed with %v. Otherwise, for a non-constant
// expression flagged as a function call, the function name is returned
// (e.g. CURRENT_TIMESTAMP). Any other expression yields "".
func getDefaultValue(expr ast.ExprNode) (value string) {
	datum := expr.GetDatum()
	if datum.Kind() != types.KindNull {
		return fmt.Sprintf("%v", datum.GetValue())
	}
	flag := expr.GetFlag()
	if flag == ast.FlagConstant || flag != ast.FlagHasFunc {
		return ""
	}
	if fn, isFunc := expr.(*ast.FuncCallExpr); isFunc {
		return fn.FnName.O
	}
	return ""
}