100 lines
2.2 KiB
Go
100 lines
2.2 KiB
Go
package db
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"reflect"
|
||
|
||
"gorm.io/gorm/schema"
|
||
)
|
||
|
||
// StringDBValuer is implemented by types that StringSerializer persists as a
// single string value.
//
// It must be implemented on the struct type itself (value receivers), not on
// a pointer to it. FromString is used like a static constructor: it is called
// on a zero value of the type and returns the decoded value.
type StringDBValuer interface {
	// FromString decodes str, as read from the database, into a new value.
	FromString(str string) (any, error)
	// ToString encodes the receiver into the string written to the database.
	ToString() (string, error)
}
|
||
|
||
// StringSerializer is a GORM serializer that stores StringDBValuer values as
// strings. It carries no state, so its zero value is ready to use.
type StringSerializer struct{}
|
||
|
||
func (StringSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) error {
|
||
if dbValue == nil {
|
||
fieldValue := reflect.New(field.FieldType)
|
||
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||
return nil
|
||
}
|
||
|
||
str := ""
|
||
switch v := dbValue.(type) {
|
||
case []byte:
|
||
str = string(v)
|
||
case string:
|
||
str = v
|
||
default:
|
||
return fmt.Errorf("expected []byte or string, got: %T", dbValue)
|
||
}
|
||
|
||
if field.FieldType.Kind() == reflect.Struct {
|
||
val := reflect.Zero(field.FieldType)
|
||
|
||
sv, ok := val.Interface().(StringDBValuer)
|
||
if !ok {
|
||
return fmt.Errorf("ref of field type %v is not StringDBValuer", field.FieldType)
|
||
}
|
||
|
||
v2, err := sv.FromString(str)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
field.ReflectValueOf(ctx, dst).Set(reflect.ValueOf(v2))
|
||
return nil
|
||
}
|
||
|
||
if field.FieldType.Kind() == reflect.Ptr {
|
||
val := reflect.Zero(field.FieldType.Elem())
|
||
|
||
sv, ok := val.Interface().(StringDBValuer)
|
||
if !ok {
|
||
return fmt.Errorf("field type %v is not StringDBValuer", field.FieldType)
|
||
}
|
||
|
||
v2, err := sv.FromString(str)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
field.ReflectValueOf(ctx, dst).Set(reflect.ValueOf(v2))
|
||
return nil
|
||
}
|
||
|
||
return fmt.Errorf("unsupported field type: %v", field.FieldType)
|
||
}
|
||
|
||
func (StringSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||
val := reflect.ValueOf(fieldValue)
|
||
if val.Kind() == reflect.Struct {
|
||
sv, ok := val.Interface().(StringDBValuer)
|
||
if !ok {
|
||
return nil, fmt.Errorf("ref of field type %v is not StringDBValuer", field.FieldType)
|
||
}
|
||
|
||
return sv.ToString()
|
||
}
|
||
|
||
if val.Kind() == reflect.Ptr {
|
||
sv, ok := val.Elem().Interface().(StringDBValuer)
|
||
if !ok {
|
||
return nil, fmt.Errorf("field type %v is not StringDBValuer", field.FieldType)
|
||
}
|
||
|
||
return sv.ToString()
|
||
}
|
||
|
||
return nil, fmt.Errorf("unsupported field type: %v", field.FieldType)
|
||
}
|
||
|
||
func init() {
|
||
schema.RegisterSerializer("string", StringSerializer{})
|
||
}
|