From 0cdae809fda1d32672888fab9e864ef39f8fdbbe Mon Sep 17 00:00:00 2001 From: Aoang Date: Wed, 4 Dec 2024 12:06:50 +0800 Subject: [PATCH] chore: use reflect.TypeFor for known types --- dialect/pgdialect/append.go | 22 +++++++++++----------- dialect/pgdialect/append_test.go | 2 +- dialect/pgdialect/sqltype.go | 8 ++++---- internal/dbtest/orm_test.go | 2 +- internal/dbtest/pg_test.go | 2 +- internal/map_key.go | 2 +- model.go | 4 ++-- schema/hook.go | 6 +++--- schema/reflect.go | 22 +++++++++++----------- schema/scan.go | 2 +- schema/sqltype.go | 12 ++++++------ schema/table.go | 2 +- schema/table_test.go | 26 +++++++++++++------------- schema/zerochecker.go | 2 +- 14 files changed, 57 insertions(+), 57 deletions(-) diff --git a/dialect/pgdialect/append.go b/dialect/pgdialect/append.go index aa2b6850d..18a1f9baf 100644 --- a/dialect/pgdialect/append.go +++ b/dialect/pgdialect/append.go @@ -11,22 +11,22 @@ import ( ) var ( - driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + driverValuerType = reflect.TypeFor[driver.Valuer]() - stringType = reflect.TypeOf((*string)(nil)).Elem() - sliceStringType = reflect.TypeOf([]string(nil)) + stringType = reflect.TypeFor[string]() + sliceStringType = reflect.TypeFor[[]string]() - intType = reflect.TypeOf((*int)(nil)).Elem() - sliceIntType = reflect.TypeOf([]int(nil)) + intType = reflect.TypeFor[int]() + sliceIntType = reflect.TypeFor[[]int]() - int64Type = reflect.TypeOf((*int64)(nil)).Elem() - sliceInt64Type = reflect.TypeOf([]int64(nil)) + int64Type = reflect.TypeFor[int64]() + sliceInt64Type = reflect.TypeFor[[]int64]() - float64Type = reflect.TypeOf((*float64)(nil)).Elem() - sliceFloat64Type = reflect.TypeOf([]float64(nil)) + float64Type = reflect.TypeFor[float64]() + sliceFloat64Type = reflect.TypeFor[[]float64]() - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - sliceTimeType = reflect.TypeOf([]time.Time(nil)) + timeType = reflect.TypeFor[time.Time]() + sliceTimeType = reflect.TypeFor[[]time.Time]() ) func appendTime(buf []byte, tm time.Time) []byte { diff --git a/dialect/pgdialect/append_test.go b/dialect/pgdialect/append_test.go index eeadb5e24..3ee71778c 100644 --- a/dialect/pgdialect/append_test.go +++ b/dialect/pgdialect/append_test.go @@ -28,7 +28,7 @@ func TestHStoreAppender(t *testing.T) { {map[string]string{"{1}": "{2}", "{3}": "{4}"}, []string{`'"{1}"=>"{2}","{3}"=>"{4}"'`, `'"{3}"=>"{4}","{1}"=>"{2}"'`}}, } - appendFunc := pgDialect.hstoreAppender(reflect.TypeOf(map[string]string{})) + appendFunc := pgDialect.hstoreAppender(reflect.TypeFor[map[string]string]()) for i, test := range tests { t.Run(fmt.Sprint(i), func(t *testing.T) { diff --git a/dialect/pgdialect/sqltype.go b/dialect/pgdialect/sqltype.go index bacc00e86..99075cbc1 100644 --- a/dialect/pgdialect/sqltype.go +++ b/dialect/pgdialect/sqltype.go @@ -44,10 +44,10 @@ const ( ) var ( - ipType = reflect.TypeOf((*net.IP)(nil)).Elem() - ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() - jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() - nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() + ipType = reflect.TypeFor[net.IP]() + ipNetType = reflect.TypeFor[net.IPNet]() + jsonRawMessageType = reflect.TypeFor[json.RawMessage]() + nullStringType = reflect.TypeFor[sql.NullString]() ) func (d *Dialect) DefaultVarcharLen() int { diff --git a/internal/dbtest/orm_test.go b/internal/dbtest/orm_test.go index 1830e999d..5d3be7a5d 100644 --- a/internal/dbtest/orm_test.go +++ b/internal/dbtest/orm_test.go @@ -635,7 +635,7 @@ type Employee struct { } func createTestSchema(t *testing.T, db *bun.DB) { - _ = db.Table(reflect.TypeOf((*BookGenre)(nil)).Elem()) + _ = db.Table(reflect.TypeFor[BookGenre]()) models := []interface{}{ (*Image)(nil), diff --git a/internal/dbtest/pg_test.go b/internal/dbtest/pg_test.go index 21e80b852..a1a3cffb2 100644 --- a/internal/dbtest/pg_test.go +++ b/internal/dbtest/pg_test.go @@ -140,7 +140,7 @@ func TestPostgresMultiTenant(t *testing.T) { db := pg(t) db = db.WithNamedArg("tenant", bun.Safe("public")) - _ = db.Table(reflect.TypeOf((*IngredientRecipe)(nil)).Elem()) + _ = db.Table(reflect.TypeFor[IngredientRecipe]()) models := []interface{}{ (*Recipe)(nil), diff --git a/internal/map_key.go b/internal/map_key.go index bb5fcca8c..d7e4de2b9 100644 --- a/internal/map_key.go +++ b/internal/map_key.go @@ -2,7 +2,7 @@ package internal import "reflect" -var ifaceType = reflect.TypeOf((*interface{})(nil)).Elem() +var ifaceType = reflect.TypeFor[interface{}]() type MapKey struct { iface interface{} diff --git a/model.go b/model.go index 046bfdfea..6254fc3ed 100644 --- a/model.go +++ b/model.go @@ -14,8 +14,8 @@ import ( var errNilModel = errors.New("bun: Model(nil)") var ( - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - bytesType = reflect.TypeOf((*[]byte)(nil)).Elem() + timeType = reflect.TypeFor[time.Time]() + bytesType = reflect.TypeFor[[]byte]() ) type Model = schema.Model diff --git a/schema/hook.go b/schema/hook.go index b83106d80..f8c32f689 100644 --- a/schema/hook.go +++ b/schema/hook.go @@ -24,7 +24,7 @@ type BeforeAppendModelHook interface { BeforeAppendModel(ctx context.Context, query Query) error } -var beforeAppendModelHookType = reflect.TypeOf((*BeforeAppendModelHook)(nil)).Elem() +var beforeAppendModelHookType = reflect.TypeFor[BeforeAppendModelHook]() //------------------------------------------------------------------------------ @@ -32,7 +32,7 @@ type BeforeScanRowHook interface { BeforeScanRow(context.Context) error } -var beforeScanRowHookType = reflect.TypeOf((*BeforeScanRowHook)(nil)).Elem() +var beforeScanRowHookType = reflect.TypeFor[BeforeScanRowHook]() //------------------------------------------------------------------------------ @@ -40,4 +40,4 @@ type AfterScanRowHook interface { AfterScanRow(context.Context) error } -var afterScanRowHookType = reflect.TypeOf((*AfterScanRowHook)(nil)).Elem() +var afterScanRowHookType = reflect.TypeFor[AfterScanRowHook]() diff --git a/schema/reflect.go b/schema/reflect.go index 75980b102..3435fa1c8 100644 --- a/schema/reflect.go +++ b/schema/reflect.go @@ -10,18 +10,18 @@ import ( ) var ( - bytesType = reflect.TypeOf((*[]byte)(nil)).Elem() - timePtrType = reflect.TypeOf((*time.Time)(nil)) - timeType = timePtrType.Elem() - ipType = reflect.TypeOf((*net.IP)(nil)).Elem() - ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() - netipPrefixType = reflect.TypeOf((*netip.Prefix)(nil)).Elem() - netipAddrType = reflect.TypeOf((*netip.Addr)(nil)).Elem() - jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() + bytesType = reflect.TypeFor[[]byte]() + timePtrType = reflect.TypeFor[*time.Time]() + timeType = reflect.TypeFor[time.Time]() + ipType = reflect.TypeFor[net.IP]() + ipNetType = reflect.TypeFor[net.IPNet]() + netipPrefixType = reflect.TypeFor[netip.Prefix]() + netipAddrType = reflect.TypeFor[netip.Addr]() + jsonRawMessageType = reflect.TypeFor[json.RawMessage]() - driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() - jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + driverValuerType = reflect.TypeFor[driver.Valuer]() + queryAppenderType = reflect.TypeFor[QueryAppender]() + jsonMarshalerType = reflect.TypeFor[json.Marshaler]() ) func indirectType(t reflect.Type) reflect.Type { diff --git a/schema/scan.go b/schema/scan.go index 4da160daf..1b7f05b7d 100644 --- a/schema/scan.go +++ b/schema/scan.go @@ -18,7 +18,7 @@ import ( "github.com/uptrace/bun/internal" ) -var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +var scannerType = reflect.TypeFor[sql.Scanner]() type ScannerFunc func(dest reflect.Value, src interface{}) error diff --git a/schema/sqltype.go b/schema/sqltype.go index 233ba641b..e96174065 100644 --- a/schema/sqltype.go +++ b/schema/sqltype.go @@ -13,12 +13,12 @@ import ( ) var ( - bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem() - nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() - nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() - nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() - nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() - nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() + bunNullTimeType = reflect.TypeFor[NullTime]() + nullTimeType = reflect.TypeFor[sql.NullTime]() + nullBoolType = reflect.TypeFor[sql.NullBool]() + nullFloatType = reflect.TypeFor[sql.NullFloat64]() + nullIntType = reflect.TypeFor[sql.NullInt64]() + nullStringType = reflect.TypeFor[sql.NullString]() ) var sqlTypes = []string{ diff --git a/schema/table.go b/schema/table.go index 7e6c6b73c..bb0e13bb1 100644 --- a/schema/table.go +++ b/schema/table.go @@ -22,7 +22,7 @@ const ( ) var ( - baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem() + baseModelType = reflect.TypeFor[BaseModel]() tableNameInflector = inflection.Plural ) diff --git a/schema/table_test.go b/schema/table_test.go index 71b29f4a8..db1767c62 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -18,7 +18,7 @@ func TestTable(t *testing.T) { Bar string } - table := tables.Get(reflect.TypeOf((*Model)(nil))) + table := tables.Get(reflect.TypeFor[*Model]()) require.Len(t, table.allFields, 3) require.Len(t, table.Fields, 3) @@ -37,7 +37,7 @@ func TestTable(t *testing.T) { Foo string } - table := tables.Get(reflect.TypeOf((*Model1)(nil))) + table := tables.Get(reflect.TypeFor[*Model1]()) foo, ok := table.FieldMap["foo"] require.True(t, ok) @@ -54,7 +54,7 @@ func TestTable(t *testing.T) { Model } - table := tables.Get(reflect.TypeOf((*Model2)(nil))) + table := tables.Get(reflect.TypeFor[*Model2]()) foo, ok := table.FieldMap["foo"] require.True(t, ok) @@ -70,7 +70,7 @@ func TestTable(t *testing.T) { BaseModel `bun:"custom_name,alias:custom_alias"` } - table := tables.Get(reflect.TypeOf((*Model)(nil))) + table := tables.Get(reflect.TypeFor[*Model]()) require.Equal(t, "custom_name", table.Name) require.Equal(t, "custom_alias", table.Alias) }) @@ -83,7 +83,7 @@ func TestTable(t *testing.T) { Model1 `bun:",extend"` } - table := tables.Get(reflect.TypeOf((*Model2)(nil))) + table := tables.Get(reflect.TypeFor[*Model2]()) require.Equal(t, "custom_name", table.Name) require.Equal(t, "custom_alias", table.Alias) }) @@ -99,7 +99,7 @@ func TestTable(t *testing.T) { Bar Perms `bun:"embed:bar_"` } - table := tables.Get(reflect.TypeOf((*Role)(nil))) + table := tables.Get(reflect.TypeFor[*Role]()) require.Nil(t, table.StructMap["foo"]) require.Nil(t, table.StructMap["bar"]) @@ -125,7 +125,7 @@ func TestTable(t *testing.T) { Perms } - table := tables.Get(reflect.TypeOf((*Role)(nil))) + table := tables.Get(reflect.TypeFor[*Role]()) require.Nil(t, table.StructMap["foo"]) require.Nil(t, table.StructMap["bar"]) @@ -157,7 +157,7 @@ func TestTable(t *testing.T) { Model1 } - table := tables.Get(reflect.TypeOf((*Model2)(nil))) + table := tables.Get(reflect.TypeFor[*Model2]()) require.Len(t, table.FieldMap, 2) foo, ok := table.FieldMap["foo"] @@ -179,7 +179,7 @@ func TestTable(t *testing.T) { Baz Model1 `bun:"embed:baz_"` } - table := tables.Get(reflect.TypeOf((*Model2)(nil))) + table := tables.Get(reflect.TypeFor[*Model2]()) require.Len(t, table.FieldMap, 2) foo, ok := table.FieldMap["baz_foo"] @@ -202,7 +202,7 @@ func TestTable(t *testing.T) { Baz string `bun:",scanonly"` } - table := tables.Get(reflect.TypeOf((*Model2)(nil))) + table := tables.Get(reflect.TypeFor[*Model2]()) require.Len(t, table.StructMap, 1) require.NotNil(t, table.StructMap["xxx"]) @@ -229,7 +229,7 @@ func TestTable(t *testing.T) { Bar string } - table := tables.Get(reflect.TypeOf((*Model)(nil))) + table := tables.Get(reflect.TypeFor[*Model]()) foo, ok := table.FieldMap["foo"] require.True(t, ok) @@ -247,7 +247,7 @@ func TestTable(t *testing.T) { Item *Item `bun:"rel:belongs-to,join:item_id=id"` } - table := tables.Get(reflect.TypeOf((*Item)(nil))) + table := tables.Get(reflect.TypeFor[*Item]()) rel, ok := table.Relations["Item"] require.True(t, ok) @@ -268,7 +268,7 @@ func TestTable(t *testing.T) { Foo string `bun:"alt:alt_name"` } - table := tables.Get(reflect.TypeOf((*ModelTest)(nil))) + table := tables.Get(reflect.TypeFor[*ModelTest]()) foo, ok := table.FieldMap["foo"] require.True(t, ok) diff --git a/schema/zerochecker.go b/schema/zerochecker.go index 7c1f088c1..7c8418eaf 100644 --- a/schema/zerochecker.go +++ b/schema/zerochecker.go @@ -5,7 +5,7 @@ import ( "reflect" ) -var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem() +var isZeroerType = reflect.TypeFor[isZeroer]() type isZeroer interface { IsZero() bool