Skip to content

Commit

Permalink
Update xsql support for PB non-scalar fields
Browse files Browse the repository at this point in the history
  • Loading branch information
onanying committed Jun 15, 2024
1 parent 0a25d6d commit f10917a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 9 deletions.
52 changes: 49 additions & 3 deletions src/xsql/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ type Test struct {
Enum Enum `xsql:"enum" json:"-"`
}

type TestJsonStruct struct {
Test
Json JsonItem `xsql:"json"`
}

type TestJsonStructPtr struct {
Test
Json *JsonItem `xsql:"json"`
}

type TestJsonSlice struct {
Test
Json []int `xsql:"json"`
}

type JsonItem struct {
Foo string `xsql:"foo"`
}

func (t Test) TableName() string {
return "xsql"
}
Expand Down Expand Up @@ -79,10 +98,12 @@ CREATE TABLE #xsql# (
#bar# datetime DEFAULT NULL,
#bool# int NOT NULL DEFAULT '0',
#enum# int NOT NULL DEFAULT '0',
#json# json DEFAULT NULL,
PRIMARY KEY (#id#)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#) VALUES (1, 'v', '2022-04-14 23:49:48', 1, 1);
INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#) VALUES (2, 'v1', '2022-04-14 23:50:00', 1, 1);
INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#, #json#) VALUES (1, 'v', '2022-04-12 23:50:00', 1, 1, '{"foo":"bar"}');
INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#, #json#) VALUES (2, 'v1', '2022-04-13 23:50:00', 1, 1, '[1,2]');
INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#, #json#) VALUES (3, 'v2', '2022-04-14 23:50:00', 1, 1, null);
`
DB := newDB()
_, err := DB.Exec(strings.ReplaceAll(q, "#", "`"))
Expand Down Expand Up @@ -483,7 +504,6 @@ func TestTxRollback(t *testing.T) {

func TestPbTimestamp(t *testing.T) {
a := assert.New(t)

DB := newDB()

// Insert
Expand All @@ -508,3 +528,29 @@ func TestPbTimestamp(t *testing.T) {
a.IsType(&timestamppb.Timestamp{}, test2.Bar)
a.Equal(test2.Bar.Seconds, now.Seconds)
}

func TestFetchPbJson(t *testing.T) {
a := assert.New(t)
DB := newDB()

var test1 TestJsonStruct
err := DB.First(&test1, "SELECT * FROM xsql WHERE id = 1")
if err != nil {
log.Fatal(err)
}
a.NotEmpty(test1.Json)

var test2 TestJsonStructPtr
err = DB.First(&test2, "SELECT * FROM xsql WHERE id = 1")
if err != nil {
log.Fatal(err)
}
a.NotEmpty(test2.Json)

var test3 TestJsonSlice
err = DB.First(&test3, "SELECT * FROM xsql WHERE id = 2")
if err != nil {
log.Fatal(err)
}
a.NotEmpty(test3.Json)
}
27 changes: 21 additions & 6 deletions src/xsql/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package xsql

import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/sijms/go-ora/v2"
Expand Down Expand Up @@ -351,8 +352,7 @@ func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect.
default:
if !res.Empty() {
vTyp := reflect.ValueOf(v).Type().String()
// 如果结构体是time.Time类型,执行转换
if typ.String() == "time.Time" {
if typ.String() == "time.Time" { // 如果结构体是time.Time类型,执行转换
if vTyp == "time.Time" {
// parseTime=true
v = res.Value()
Expand All @@ -364,20 +364,35 @@ func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect.
return fmt.Errorf("time parse fail for field %s: %v", tag, e)
}
}
}
// 如果结构体是*timestamppb.Timestamp类型,执行转换
if typ.String() == "*timestamppb.Timestamp" {
} else if typ.String() == "*timestamppb.Timestamp" { // 如果结构体是*timestamppb.Timestamp类型,执行转换
if vTyp != "*timestamppb.Timestamp" {
if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil {
v = timestamppb.New(t)
} else {
return fmt.Errorf("time parse fail for field %s: %v", tag, e)
}
}
} else if typ.Kind() == reflect.Ptr || typ.Kind() == reflect.Struct || typ.Kind() == reflect.Slice || typ.Kind() == reflect.Array { // 非标量用JSON反序列化处理
jsonString := res.String()
var newInstance reflect.Value
if typ.Kind() == reflect.Ptr {
newInstance = reflect.New(typ.Elem()) // 创建的都是指针
} else {
newInstance = reflect.New(typ) // 创建的都是指针
}
if e := json.Unmarshal([]byte(jsonString), newInstance.Interface()); e != nil {
return fmt.Errorf("json unmarshal error for field %s: %v", tag, e)
}
if typ.Kind() == reflect.Ptr {
v = newInstance.Interface()
} else {
v = newInstance.Elem().Interface() // 获取的是非指针
}
}
}
}
// 追加异常信息

// 设置值
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("type mismatch for field %s: %v", tag, e)
Expand Down

0 comments on commit f10917a

Please sign in to comment.