Skip to content

Commit

Permalink
Merge pull request #1083 from Aoang/chore/reflect-type-for
Browse files Browse the repository at this point in the history
chore: use reflect.TypeFor for known types
  • Loading branch information
j2gg0s authored Dec 4, 2024
2 parents 7703a2e + 0cdae80 commit 1635c14
Show file tree
Hide file tree
Showing 14 changed files with 57 additions and 57 deletions.
22 changes: 11 additions & 11 deletions dialect/pgdialect/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@ import (
)

var (
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
driverValuerType = reflect.TypeFor[driver.Valuer]()

stringType = reflect.TypeOf((*string)(nil)).Elem()
sliceStringType = reflect.TypeOf([]string(nil))
stringType = reflect.TypeFor[string]()
sliceStringType = reflect.TypeFor[[]string]()

intType = reflect.TypeOf((*int)(nil)).Elem()
sliceIntType = reflect.TypeOf([]int(nil))
intType = reflect.TypeFor[int]()
sliceIntType = reflect.TypeFor[[]int]()

int64Type = reflect.TypeOf((*int64)(nil)).Elem()
sliceInt64Type = reflect.TypeOf([]int64(nil))
int64Type = reflect.TypeFor[int64]()
sliceInt64Type = reflect.TypeFor[[]int64]()

float64Type = reflect.TypeOf((*float64)(nil)).Elem()
sliceFloat64Type = reflect.TypeOf([]float64(nil))
float64Type = reflect.TypeFor[float64]()
sliceFloat64Type = reflect.TypeFor[[]float64]()

timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
sliceTimeType = reflect.TypeOf([]time.Time(nil))
timeType = reflect.TypeFor[time.Time]()
sliceTimeType = reflect.TypeFor[[]time.Time]()
)

func appendTime(buf []byte, tm time.Time) []byte {
Expand Down
2 changes: 1 addition & 1 deletion dialect/pgdialect/append_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestHStoreAppender(t *testing.T) {
{map[string]string{"{1}": "{2}", "{3}": "{4}"}, []string{`'"{1}"=>"{2}","{3}"=>"{4}"'`, `'"{3}"=>"{4}","{1}"=>"{2}"'`}},
}

appendFunc := pgDialect.hstoreAppender(reflect.TypeOf(map[string]string{}))
appendFunc := pgDialect.hstoreAppender(reflect.TypeFor[map[string]string]())

for i, test := range tests {
t.Run(fmt.Sprint(i), func(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ const (
)

var (
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
ipType = reflect.TypeFor[net.IP]()
ipNetType = reflect.TypeFor[net.IPNet]()
jsonRawMessageType = reflect.TypeFor[json.RawMessage]()
nullStringType = reflect.TypeFor[sql.NullString]()
)

func (d *Dialect) DefaultVarcharLen() int {
Expand Down
2 changes: 1 addition & 1 deletion internal/dbtest/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ type Employee struct {
}

func createTestSchema(t *testing.T, db *bun.DB) {
_ = db.Table(reflect.TypeOf((*BookGenre)(nil)).Elem())
_ = db.Table(reflect.TypeFor[BookGenre]())

models := []interface{}{
(*Image)(nil),
Expand Down
2 changes: 1 addition & 1 deletion internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func TestPostgresMultiTenant(t *testing.T) {
db := pg(t)

db = db.WithNamedArg("tenant", bun.Safe("public"))
_ = db.Table(reflect.TypeOf((*IngredientRecipe)(nil)).Elem())
_ = db.Table(reflect.TypeFor[IngredientRecipe]())

models := []interface{}{
(*Recipe)(nil),
Expand Down
2 changes: 1 addition & 1 deletion internal/map_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package internal

import "reflect"

var ifaceType = reflect.TypeOf((*interface{})(nil)).Elem()
var ifaceType = reflect.TypeFor[interface{}]()

type MapKey struct {
iface interface{}
Expand Down
4 changes: 2 additions & 2 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
var errNilModel = errors.New("bun: Model(nil)")

var (
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
timeType = reflect.TypeFor[time.Time]()
bytesType = reflect.TypeFor[[]byte]()
)

type Model = schema.Model
Expand Down
6 changes: 3 additions & 3 deletions schema/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ type BeforeAppendModelHook interface {
BeforeAppendModel(ctx context.Context, query Query) error
}

var beforeAppendModelHookType = reflect.TypeOf((*BeforeAppendModelHook)(nil)).Elem()
var beforeAppendModelHookType = reflect.TypeFor[BeforeAppendModelHook]()

//------------------------------------------------------------------------------

type BeforeScanRowHook interface {
BeforeScanRow(context.Context) error
}

var beforeScanRowHookType = reflect.TypeOf((*BeforeScanRowHook)(nil)).Elem()
var beforeScanRowHookType = reflect.TypeFor[BeforeScanRowHook]()

//------------------------------------------------------------------------------

type AfterScanRowHook interface {
AfterScanRow(context.Context) error
}

var afterScanRowHookType = reflect.TypeOf((*AfterScanRowHook)(nil)).Elem()
var afterScanRowHookType = reflect.TypeFor[AfterScanRowHook]()
22 changes: 11 additions & 11 deletions schema/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ import (
)

var (
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
timePtrType = reflect.TypeOf((*time.Time)(nil))
timeType = timePtrType.Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
netipPrefixType = reflect.TypeOf((*netip.Prefix)(nil)).Elem()
netipAddrType = reflect.TypeOf((*netip.Addr)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
bytesType = reflect.TypeFor[[]byte]()
timePtrType = reflect.TypeFor[*time.Time]()
timeType = reflect.TypeFor[time.Time]()
ipType = reflect.TypeFor[net.IP]()
ipNetType = reflect.TypeFor[net.IPNet]()
netipPrefixType = reflect.TypeFor[netip.Prefix]()
netipAddrType = reflect.TypeFor[netip.Addr]()
jsonRawMessageType = reflect.TypeFor[json.RawMessage]()

driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
driverValuerType = reflect.TypeFor[driver.Valuer]()
queryAppenderType = reflect.TypeFor[QueryAppender]()
jsonMarshalerType = reflect.TypeFor[json.Marshaler]()
)

func indirectType(t reflect.Type) reflect.Type {
Expand Down
2 changes: 1 addition & 1 deletion schema/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/uptrace/bun/internal"
)

var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
var scannerType = reflect.TypeFor[sql.Scanner]()

type ScannerFunc func(dest reflect.Value, src interface{}) error

Expand Down
12 changes: 6 additions & 6 deletions schema/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ import (
)

var (
bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem()
nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem()
nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem()
nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem()
nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem()
nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
bunNullTimeType = reflect.TypeFor[NullTime]()
nullTimeType = reflect.TypeFor[sql.NullTime]()
nullBoolType = reflect.TypeFor[sql.NullBool]()
nullFloatType = reflect.TypeFor[sql.NullFloat64]()
nullIntType = reflect.TypeFor[sql.NullInt64]()
nullStringType = reflect.TypeFor[sql.NullString]()
)

var sqlTypes = []string{
Expand Down
2 changes: 1 addition & 1 deletion schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (
)

var (
baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem()
baseModelType = reflect.TypeFor[BaseModel]()
tableNameInflector = inflection.Plural
)

Expand Down
26 changes: 13 additions & 13 deletions schema/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestTable(t *testing.T) {
Bar string
}

table := tables.Get(reflect.TypeOf((*Model)(nil)))
table := tables.Get(reflect.TypeFor[*Model]())

require.Len(t, table.allFields, 3)
require.Len(t, table.Fields, 3)
Expand All @@ -37,7 +37,7 @@ func TestTable(t *testing.T) {
Foo string
}

table := tables.Get(reflect.TypeOf((*Model1)(nil)))
table := tables.Get(reflect.TypeFor[*Model1]())

foo, ok := table.FieldMap["foo"]
require.True(t, ok)
Expand All @@ -54,7 +54,7 @@ func TestTable(t *testing.T) {
Model
}

table := tables.Get(reflect.TypeOf((*Model2)(nil)))
table := tables.Get(reflect.TypeFor[*Model2]())

foo, ok := table.FieldMap["foo"]
require.True(t, ok)
Expand All @@ -70,7 +70,7 @@ func TestTable(t *testing.T) {
BaseModel `bun:"custom_name,alias:custom_alias"`
}

table := tables.Get(reflect.TypeOf((*Model)(nil)))
table := tables.Get(reflect.TypeFor[*Model]())
require.Equal(t, "custom_name", table.Name)
require.Equal(t, "custom_alias", table.Alias)
})
Expand All @@ -83,7 +83,7 @@ func TestTable(t *testing.T) {
Model1 `bun:",extend"`
}

table := tables.Get(reflect.TypeOf((*Model2)(nil)))
table := tables.Get(reflect.TypeFor[*Model2]())
require.Equal(t, "custom_name", table.Name)
require.Equal(t, "custom_alias", table.Alias)
})
Expand All @@ -99,7 +99,7 @@ func TestTable(t *testing.T) {
Bar Perms `bun:"embed:bar_"`
}

table := tables.Get(reflect.TypeOf((*Role)(nil)))
table := tables.Get(reflect.TypeFor[*Role]())
require.Nil(t, table.StructMap["foo"])
require.Nil(t, table.StructMap["bar"])

Expand All @@ -125,7 +125,7 @@ func TestTable(t *testing.T) {
Perms
}

table := tables.Get(reflect.TypeOf((*Role)(nil)))
table := tables.Get(reflect.TypeFor[*Role]())
require.Nil(t, table.StructMap["foo"])
require.Nil(t, table.StructMap["bar"])

Expand Down Expand Up @@ -157,7 +157,7 @@ func TestTable(t *testing.T) {
Model1
}

table := tables.Get(reflect.TypeOf((*Model2)(nil)))
table := tables.Get(reflect.TypeFor[*Model2]())
require.Len(t, table.FieldMap, 2)

foo, ok := table.FieldMap["foo"]
Expand All @@ -179,7 +179,7 @@ func TestTable(t *testing.T) {
Baz Model1 `bun:"embed:baz_"`
}

table := tables.Get(reflect.TypeOf((*Model2)(nil)))
table := tables.Get(reflect.TypeFor[*Model2]())
require.Len(t, table.FieldMap, 2)

foo, ok := table.FieldMap["baz_foo"]
Expand All @@ -202,7 +202,7 @@ func TestTable(t *testing.T) {
Baz string `bun:",scanonly"`
}

table := tables.Get(reflect.TypeOf((*Model2)(nil)))
table := tables.Get(reflect.TypeFor[*Model2]())

require.Len(t, table.StructMap, 1)
require.NotNil(t, table.StructMap["xxx"])
Expand All @@ -229,7 +229,7 @@ func TestTable(t *testing.T) {
Bar string
}

table := tables.Get(reflect.TypeOf((*Model)(nil)))
table := tables.Get(reflect.TypeFor[*Model]())

foo, ok := table.FieldMap["foo"]
require.True(t, ok)
Expand All @@ -247,7 +247,7 @@ func TestTable(t *testing.T) {
Item *Item `bun:"rel:belongs-to,join:item_id=id"`
}

table := tables.Get(reflect.TypeOf((*Item)(nil)))
table := tables.Get(reflect.TypeFor[*Item]())

rel, ok := table.Relations["Item"]
require.True(t, ok)
Expand All @@ -268,7 +268,7 @@ func TestTable(t *testing.T) {
Foo string `bun:"alt:alt_name"`
}

table := tables.Get(reflect.TypeOf((*ModelTest)(nil)))
table := tables.Get(reflect.TypeFor[*ModelTest]())

foo, ok := table.FieldMap["foo"]
require.True(t, ok)
Expand Down
2 changes: 1 addition & 1 deletion schema/zerochecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"reflect"
)

var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem()
var isZeroerType = reflect.TypeFor[isZeroer]()

type isZeroer interface {
IsZero() bool
Expand Down

0 comments on commit 1635c14

Please sign in to comment.