Skip to content

Commit

Permalink
Add support type for net/netip.Addr and net/netip.Prefix (uptrace#1028)
Browse files Browse the repository at this point in the history
* feat(schema): add support type for net/netip.Addr and net/netip.Prefix

* fix(schema): net.IPNet(not ptr) is not implement fmt.Stringer
  • Loading branch information
Aoang authored and bevzzz committed Oct 19, 2024
1 parent 760de7d commit c92ffbe
Show file tree
Hide file tree
Showing 16 changed files with 874 additions and 9 deletions.
4 changes: 4 additions & 0 deletions dialect/mssqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ func (d *Dialect) AppendSequence(b []byte, _ *schema.Table, _ *schema.Field) []b
return append(b, " IDENTITY"...)
}

func (d *Dialect) DefaultSchema() string {
return "dbo"
}

func sqlType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.Timestamp:
Expand Down
4 changes: 4 additions & 0 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ func (d *Dialect) AppendSequence(b []byte, _ *schema.Table, _ *schema.Field) []b
return append(b, " AUTO_INCREMENT"...)
}

func (d *Dialect) DefaultSchema() string {
return "mydb"
}

func sqlType(field *schema.Field) string {
if field.DiscoveredSQLType == sqltype.Timestamp {
return datetimeType
Expand Down
4 changes: 4 additions & 0 deletions dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/schema"
"github.com/uptrace/bun/schema/inspector"
)

var pgDialect = New()
Expand All @@ -29,6 +30,9 @@ type Dialect struct {
features feature.Feature
}

var _ schema.Dialect = (*Dialect)(nil)
var _ inspector.Dialect = (*Dialect)(nil)

func New() *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
Expand Down
242 changes: 242 additions & 0 deletions dialect/pgdialect/inspector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
package pgdialect

import (
"context"
"fmt"
"strings"

"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/schema"
)

func (d *Dialect) Inspector(db *bun.DB) schema.Inspector {
return newDatabaseInspector(db)
}

type DatabaseInspector struct {
db *bun.DB
}

var _ schema.Inspector = (*DatabaseInspector)(nil)

func newDatabaseInspector(db *bun.DB) *DatabaseInspector {
return &DatabaseInspector{db: db}
}

func (di *DatabaseInspector) Inspect(ctx context.Context) (schema.State, error) {
var state schema.State
var tables []*InformationSchemaTable
if err := di.db.NewRaw(sqlInspectTables).Scan(ctx, &tables); err != nil {
return state, err
}

for _, table := range tables {
var columns []*InformationSchemaColumn
if err := di.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
return state, err
}
colDefs := make(map[string]schema.ColumnDef)
for _, c := range columns {
dataType := fromDatabaseType(c.DataType)
if strings.EqualFold(dataType, sqltype.VarChar) && c.VarcharLen > 0 {
dataType = fmt.Sprintf("%s(%d)", dataType, c.VarcharLen)
}

def := c.Default
if c.IsSerial || c.IsIdentity {
def = ""
}

colDefs[c.Name] = schema.ColumnDef{
SQLType: strings.ToLower(dataType),
IsPK: c.IsPK,
IsNullable: c.IsNullable,
IsAutoIncrement: c.IsSerial,
IsIdentity: c.IsIdentity,
DefaultValue: def,
}
}

state.Tables = append(state.Tables, schema.TableDef{
Schema: table.Schema,
Name: table.Name,
Columns: colDefs,
})
}
return state, nil
}

type InformationSchemaTable struct {
bun.BaseModel

Schema string `bun:"table_schema,pk"`
Name string `bun:"table_name,pk"`

Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"`
}

type InformationSchemaColumn struct {
bun.BaseModel

Schema string `bun:"table_schema"`
Table string `bun:"table_name"`
Name string `bun:"column_name"`
DataType string `bun:"data_type"`
VarcharLen int `bun:"varchar_len"`
IsArray bool `bun:"is_array"`
ArrayDims int `bun:"array_dims"`
Default string `bun:"default"`
IsPK bool `bun:"is_pk"`
IsIdentity bool `bun:"is_identity"`
IndentityType string `bun:"identity_type"`
IsSerial bool `bun:"is_serial"`
IsNullable bool `bun:"is_nullable"`
IsUnique bool `bun:"is_unique"`
UniqueGroup []string `bun:"unique_group,array"`
}

const (
// sqlInspectTables retrieves all user-defined tables across all schemas.
// It excludes relations from Postgres's reserved "pg_" schemas and views from the "information_schema".
sqlInspectTables = `
SELECT table_schema, table_name
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema <> 'information_schema'
AND table_schema NOT LIKE 'pg_%'
`

// sqlInspectColumnsQuery retrieves column definitions for the specified table.
// Unlike sqlInspectTables and sqlInspectSchema, it should be passed to bun.NewRaw
// with additional args for table_schema and table_name.
sqlInspectColumnsQuery = `
SELECT
"c".table_schema,
"c".table_name,
"c".column_name,
"c".data_type,
"c".character_maximum_length::integer AS varchar_len,
"c".data_type = 'ARRAY' AS is_array,
COALESCE("c".array_dims, 0) AS array_dims,
CASE
WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
ELSE "c".column_default
END AS "default",
'p' = ANY("c".constraint_type) AS is_pk,
"c".is_identity = 'YES' AS is_identity,
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
"c".is_nullable = 'YES' AS is_nullable,
'u' = ANY("c".constraint_type) AS is_unique,
"c"."constraint_name" AS unique_group
FROM (
SELECT
"table_schema",
"table_name",
"column_name",
"c".data_type,
"c".character_maximum_length,
"c".column_default,
"c".is_identity,
"c".is_nullable,
att.array_dims,
att.identity_type,
att."constraint_name",
att."constraint_type"
FROM information_schema.columns "c"
LEFT JOIN (
SELECT
s.nspname AS "table_schema",
"t".relname AS "table_name",
"c".attname AS "column_name",
"c".attndims AS array_dims,
"c".attidentity AS identity_type,
ARRAY_AGG(con.conname) AS "constraint_name",
ARRAY_AGG(con.contype) AS "constraint_type"
FROM (
SELECT
conname,
contype,
connamespace,
conrelid,
conrelid AS attrelid,
UNNEST(conkey) AS attnum
FROM pg_constraint
) con
LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
LEFT JOIN pg_namespace s ON s.oid = con.connamespace
LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
GROUP BY 1, 2, 3, 4, 5
) att USING ("table_schema", "table_name", "column_name")
) "c"
WHERE "table_schema" = ? AND "table_name" = ?
`

// sqlInspectSchema retrieves column type definitions for all user-defined tables.
// Other relations, such as views and indices, as well as Posgres's internal relations are excluded.
sqlInspectSchema = `
SELECT
"t"."table_schema",
"t".table_name,
"c".column_name,
"c".data_type,
"c".character_maximum_length::integer AS varchar_len,
"c".data_type = 'ARRAY' AS is_array,
COALESCE("c".array_dims, 0) AS array_dims,
CASE
WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
ELSE "c".column_default
END AS "default",
"c".constraint_type = 'p' AS is_pk,
"c".is_identity = 'YES' AS is_identity,
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "t".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
"c".is_nullable = 'YES' AS is_nullable,
"c".constraint_type = 'u' AS is_unique,
"c"."constraint_name" AS unique_group
FROM information_schema.tables "t"
LEFT JOIN (
SELECT
"table_schema",
"table_name",
"column_name",
"c".data_type,
"c".character_maximum_length,
"c".column_default,
"c".is_identity,
"c".is_nullable,
att.array_dims,
att.identity_type,
att."constraint_name",
att."constraint_type"
FROM information_schema.columns "c"
LEFT JOIN (
SELECT
s.nspname AS table_schema,
"t".relname AS "table_name",
"c".attname AS "column_name",
"c".attndims AS array_dims,
"c".attidentity AS identity_type,
con.conname AS "constraint_name",
con.contype AS "constraint_type"
FROM (
SELECT
conname,
contype,
connamespace,
conrelid,
conrelid AS attrelid,
UNNEST(conkey) AS attnum
FROM pg_constraint
) con
LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
LEFT JOIN pg_namespace s ON s.oid = con.connamespace
LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
) att USING (table_schema, "table_name", "column_name")
) "c" USING (table_schema, "table_name")
WHERE table_type = 'BASE TABLE'
AND table_schema <> 'information_schema'
AND table_schema NOT LIKE 'pg_%'
`
)
20 changes: 18 additions & 2 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"net"
"reflect"
"strings"

"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/schema"
Expand All @@ -28,8 +29,10 @@ const (
pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer

// Character Types
pgTypeChar = "CHAR" // fixed length string (blank padded)
pgTypeText = "TEXT" // variable length string without limit
pgTypeChar = "CHAR" // fixed length string (blank padded)
pgTypeText = "TEXT" // variable length string without limit
pgTypeVarchar = "VARCHAR" // variable length string with optional limit
pgTypeCharacterVarying = "CHARACTER VARYING" // alias for VARCHAR

// JSON Types
pgTypeJSON = "JSON" // text representation of json data
Expand All @@ -48,6 +51,10 @@ func (d *Dialect) DefaultVarcharLen() int {
return 0
}

func (d *Dialect) DefaultSchema() string {
return "public"
}

func fieldSQLType(field *schema.Field) string {
if field.UserSQLType != "" {
return field.UserSQLType
Expand Down Expand Up @@ -106,3 +113,12 @@ func sqlType(typ reflect.Type) string {

return sqlType
}

// fromDatabaseType converts Postgres-specific type to a more generic `sqltype`.
func fromDatabaseType(dbType string) string {
switch strings.ToUpper(dbType) {
case pgTypeChar, pgTypeVarchar, pgTypeCharacterVarying, pgTypeText:
return sqltype.VarChar
}
return dbType
}
11 changes: 11 additions & 0 deletions dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,13 @@ func (d *Dialect) DefaultVarcharLen() int {
// AUTOINCREMENT is only valid for INTEGER PRIMARY KEY, and this method will be a noop for other columns.
//
// Because this is a valid construct:
//
// CREATE TABLE ("id" INTEGER PRIMARY KEY AUTOINCREMENT);
//
// and this is not:
//
// CREATE TABLE ("id" INTEGER AUTOINCREMENT, PRIMARY KEY ("id"));
//
// AppendSequence adds a primary key constraint as a *side-effect*. Callers should expect it to avoid building invalid SQL.
// SQLite also [does not support] AUTOINCREMENT column in composite primary keys.
//
Expand All @@ -111,6 +115,13 @@ func (d *Dialect) AppendSequence(b []byte, table *schema.Table, field *schema.Fi
return b
}

// DefaultSchemaName is the "schema-name" of the main database.
// The details might differ from other dialects, but for all means and purposes
// "main" is the default schema in an SQLite database.
func (d *Dialect) DefaultSchema() string {
return "main"
}

func fieldSQLType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.SmallInt, sqltype.BigInt:
Expand Down
18 changes: 18 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/uptrace/bun/driver/pgdriver"
"github.com/uptrace/bun/driver/sqliteshim"
"github.com/uptrace/bun/extra/bundebug"
"github.com/uptrace/bun/schema"

_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -53,6 +54,13 @@ var allDBs = map[string]func(tb testing.TB) *bun.DB{
mssql2019Name: mssql2019,
}

var allDialects = []func() schema.Dialect{
func() schema.Dialect { return pgdialect.New() },
func() schema.Dialect { return mysqldialect.New() },
func() schema.Dialect { return sqlitedialect.New() },
func() schema.Dialect { return mssqldialect.New() },
}

func pg(tb testing.TB) *bun.DB {
dsn := os.Getenv("PG")
if dsn == "" {
Expand Down Expand Up @@ -216,6 +224,16 @@ func testEachDB(t *testing.T, f func(t *testing.T, dbName string, db *bun.DB)) {
}
}

// testEachDialect allows testing dialect-specific functionality that does not require database interactions.
func testEachDialect(t *testing.T, f func(t *testing.T, dialectName string, dialect func() schema.Dialect)) {
for _, newDialect := range allDialects {
name := newDialect().Name().String()
t.Run(name, func(t *testing.T) {
f(t, name, newDialect)
})
}
}

func funcName(x interface{}) string {
s := runtime.FuncForPC(reflect.ValueOf(x).Pointer()).Name()
if i := strings.LastIndexByte(s, '.'); i >= 0 {
Expand Down
Loading

0 comments on commit c92ffbe

Please sign in to comment.