From c92ffbee28be5e5fc970e41768e751f7d15ad77d Mon Sep 17 00:00:00 2001
From: Aoang
Date: Thu, 17 Oct 2024 13:55:16 +0800
Subject: [PATCH] Add support for net/netip.Addr and net/netip.Prefix (#1028)

* feat(schema): add support for net/netip.Addr and net/netip.Prefix

* fix(schema): net.IPNet (not a pointer) does not implement fmt.Stringer
---
 dialect/mssqldialect/dialect.go  |   4 +
 dialect/mysqldialect/dialect.go  |   4 +
 dialect/pgdialect/dialect.go     |   4 +
 dialect/pgdialect/inspector.go   | 242 +++++++++++++++++++++++++++++++
 dialect/pgdialect/sqltype.go     |  20 ++-
 dialect/sqlitedialect/dialect.go |  11 ++
 internal/dbtest/db_test.go       |  18 +++
 internal/dbtest/inspect_test.go  | 112 ++++++++++++++
 internal/dbtest/migrate_test.go  | 138 ++++++++++++++++++
 migrate/auto.go                  | 212 +++++++++++++++++++++++++++
 schema/append_value.go           |  13 +-
 schema/dialect.go                |   3 +
 schema/inspector.go              |  76 ++++++++++
 schema/inspector/dialect.go      |  11 ++
 schema/reflect.go                |   3 +
 schema/tables.go                 |  12 ++
 16 files changed, 874 insertions(+), 9 deletions(-)
 create mode 100644 dialect/pgdialect/inspector.go
 create mode 100644 internal/dbtest/inspect_test.go
 create mode 100644 migrate/auto.go
 create mode 100644 schema/inspector.go
 create mode 100644 schema/inspector/dialect.go

diff --git a/dialect/mssqldialect/dialect.go b/dialect/mssqldialect/dialect.go
index a5c99a274..bde140963 100755
--- a/dialect/mssqldialect/dialect.go
+++ b/dialect/mssqldialect/dialect.go
@@ -141,6 +141,10 @@ func (d *Dialect) AppendSequence(b []byte, _ *schema.Table, _ *schema.Field) []b
 	return append(b, " IDENTITY"...)
 }
 
+func (d *Dialect) DefaultSchema() string {
+	return "dbo"
+}
+
 func sqlType(field *schema.Field) string {
 	switch field.DiscoveredSQLType {
 	case sqltype.Timestamp:
diff --git a/dialect/mysqldialect/dialect.go b/dialect/mysqldialect/dialect.go
index 881aa7ebf..9b4dfe87c 100644
--- a/dialect/mysqldialect/dialect.go
+++ b/dialect/mysqldialect/dialect.go
@@ -206,6 +206,10 @@ func (d *Dialect) AppendSequence(b []byte, _ *schema.Table, _ *schema.Field) []b
 	return append(b, " AUTO_INCREMENT"...)
 }
 
+func (d *Dialect) DefaultSchema() string {
+	return "mydb"
+}
+
 func sqlType(field *schema.Field) string {
 	if field.DiscoveredSQLType == sqltype.Timestamp {
 		return datetimeType
diff --git a/dialect/pgdialect/dialect.go b/dialect/pgdialect/dialect.go
index 358971f61..766aa1be4 100644
--- a/dialect/pgdialect/dialect.go
+++ b/dialect/pgdialect/dialect.go
@@ -11,6 +11,7 @@ import (
 	"github.com/uptrace/bun/dialect/feature"
 	"github.com/uptrace/bun/dialect/sqltype"
 	"github.com/uptrace/bun/schema"
+	"github.com/uptrace/bun/schema/inspector"
 )
 
 var pgDialect = New()
@@ -29,6 +30,9 @@ type Dialect struct {
 	features feature.Feature
 }
 
+var _ schema.Dialect = (*Dialect)(nil)
+var _ inspector.Dialect = (*Dialect)(nil)
+
 func New() *Dialect {
 	d := new(Dialect)
 	d.tables = schema.NewTables(d)
diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go
new file mode 100644
index 000000000..418140855
--- /dev/null
+++ b/dialect/pgdialect/inspector.go
@@ -0,0 +1,242 @@
+package pgdialect
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"github.com/uptrace/bun"
+	"github.com/uptrace/bun/dialect/sqltype"
+	"github.com/uptrace/bun/schema"
+)
+
+func (d *Dialect) Inspector(db *bun.DB) schema.Inspector {
+	return newDatabaseInspector(db)
+}
+
+type DatabaseInspector struct {
+	db *bun.DB
+}
+
+var _ schema.Inspector = (*DatabaseInspector)(nil)
+
+func newDatabaseInspector(db *bun.DB) *DatabaseInspector {
+	return &DatabaseInspector{db: db}
+}
+
+func (di *DatabaseInspector) Inspect(ctx context.Context) (schema.State, error) {
+	var state schema.State
+	var tables []*InformationSchemaTable
+	if err := di.db.NewRaw(sqlInspectTables).Scan(ctx, &tables); err != nil {
+		return state, err
+	}
+
+	for _, table := range tables {
+		var columns []*InformationSchemaColumn
+		if err := di.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
+			return state, err
+		}
+		colDefs := make(map[string]schema.ColumnDef)
+		for _, c := range columns {
+			dataType := fromDatabaseType(c.DataType)
+			if strings.EqualFold(dataType, sqltype.VarChar) && c.VarcharLen > 0 {
+				dataType = fmt.Sprintf("%s(%d)", dataType, c.VarcharLen)
+			}
+
+			def := c.Default
+			if c.IsSerial || c.IsIdentity {
+				def = ""
+			}
+
+			colDefs[c.Name] = schema.ColumnDef{
+				SQLType:         strings.ToLower(dataType),
+				IsPK:            c.IsPK,
+				IsNullable:      c.IsNullable,
+				IsAutoIncrement: c.IsSerial,
+				IsIdentity:      c.IsIdentity,
+				DefaultValue:    def,
+			}
+		}
+
+		state.Tables = append(state.Tables, schema.TableDef{
+			Schema:  table.Schema,
+			Name:    table.Name,
+			Columns: colDefs,
+		})
+	}
+	return state, nil
+}
+
+type InformationSchemaTable struct {
+	bun.BaseModel
+
+	Schema string `bun:"table_schema,pk"`
+	Name   string `bun:"table_name,pk"`
+
+	Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"`
+}
+
+type InformationSchemaColumn struct {
+	bun.BaseModel
+
+	Schema       string   `bun:"table_schema"`
+	Table        string   `bun:"table_name"`
+	Name         string   `bun:"column_name"`
+	DataType     string   `bun:"data_type"`
+	VarcharLen   int      `bun:"varchar_len"`
+	IsArray      bool     `bun:"is_array"`
+	ArrayDims    int      `bun:"array_dims"`
+	Default      string   `bun:"default"`
+	IsPK         bool     `bun:"is_pk"`
+	IsIdentity   bool     `bun:"is_identity"`
+	IdentityType string   `bun:"identity_type"`
+	IsSerial     bool     `bun:"is_serial"`
+	IsNullable   bool     `bun:"is_nullable"`
+	IsUnique     bool     `bun:"is_unique"`
+	UniqueGroup  []string `bun:"unique_group,array"`
+}
+
+const (
+	// sqlInspectTables retrieves all user-defined tables across all schemas.
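+	// Only relations whose table_type is 'BASE TABLE' qualify.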
+	// It excludes relations from Postgres's reserved "pg_" schemas and views from the "information_schema".
+	sqlInspectTables = `
+SELECT table_schema, table_name
+FROM information_schema.tables
+WHERE table_type = 'BASE TABLE'
+  AND table_schema <> 'information_schema'
+  AND table_schema NOT LIKE 'pg_%'
+	`
+
+	// sqlInspectColumnsQuery retrieves column definitions for the specified table.
+	// Unlike sqlInspectTables and sqlInspectSchema, it should be passed to bun.NewRaw
+	// with additional args for table_schema and table_name.
+	sqlInspectColumnsQuery = `
+SELECT
+	"c".table_schema,
+	"c".table_name,
+	"c".column_name,
+	"c".data_type,
+	"c".character_maximum_length::integer AS varchar_len,
+	"c".data_type = 'ARRAY' AS is_array,
+	COALESCE("c".array_dims, 0) AS array_dims,
+	CASE
+		WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
+		ELSE "c".column_default
+	END AS "default",
+	'p' = ANY("c".constraint_type) AS is_pk,
+	"c".is_identity = 'YES' AS is_identity,
+	"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
+	COALESCE("c".identity_type, '') AS identity_type,
+	"c".is_nullable = 'YES' AS is_nullable,
+	'u' = ANY("c".constraint_type) AS is_unique,
+	"c"."constraint_name" AS unique_group
+FROM (
+	SELECT
+		"table_schema",
+		"table_name",
+		"column_name",
+		"c".data_type,
+		"c".character_maximum_length,
+		"c".column_default,
+		"c".is_identity,
+		"c".is_nullable,
+		att.array_dims,
+		att.identity_type,
+		att."constraint_name",
+		att."constraint_type"
+	FROM information_schema.columns "c"
+	LEFT JOIN (
+		SELECT
+			s.nspname AS "table_schema",
+			"t".relname AS "table_name",
+			"c".attname AS "column_name",
+			"c".attndims AS array_dims,
+			"c".attidentity AS identity_type,
+			ARRAY_AGG(con.conname) AS "constraint_name",
+			ARRAY_AGG(con.contype) AS "constraint_type"
+		FROM (
+			SELECT
+				conname,
+				contype,
+				connamespace,
+				conrelid,
+				conrelid AS attrelid,
+				UNNEST(conkey) AS attnum
+			FROM pg_constraint
+		) con
+		LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
+		LEFT JOIN pg_namespace s ON s.oid = con.connamespace
+		LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
+		GROUP BY 1, 2, 3, 4, 5
+	) att USING ("table_schema", "table_name", "column_name")
+) "c"
+WHERE "table_schema" = ? AND "table_name" = ?
+	`
+
+	// sqlInspectSchema retrieves column type definitions for all user-defined tables.
+	// Other relations, such as views and indices, as well as Postgres's internal relations are excluded.
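+	// Its output has the same shape as that of sqlInspectColumnsQuery: one row per column.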
+	sqlInspectSchema = `
+SELECT
+	"t"."table_schema",
+	"t".table_name,
+	"c".column_name,
+	"c".data_type,
+	"c".character_maximum_length::integer AS varchar_len,
+	"c".data_type = 'ARRAY' AS is_array,
+	COALESCE("c".array_dims, 0) AS array_dims,
+	CASE
+		WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
+		ELSE "c".column_default
+	END AS "default",
+	"c".constraint_type = 'p' AS is_pk,
+	"c".is_identity = 'YES' AS is_identity,
+	"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "t".table_name, "c".column_name) AS is_serial,
+	COALESCE("c".identity_type, '') AS identity_type,
+	"c".is_nullable = 'YES' AS is_nullable,
+	"c".constraint_type = 'u' AS is_unique,
+	"c"."constraint_name" AS unique_group
+FROM information_schema.tables "t"
+	LEFT JOIN (
+		SELECT
+			"table_schema",
+			"table_name",
+			"column_name",
+			"c".data_type,
+			"c".character_maximum_length,
+			"c".column_default,
+			"c".is_identity,
+			"c".is_nullable,
+			att.array_dims,
+			att.identity_type,
+			att."constraint_name",
+			att."constraint_type"
+		FROM information_schema.columns "c"
+		LEFT JOIN (
+			SELECT
+				s.nspname AS table_schema,
+				"t".relname AS "table_name",
+				"c".attname AS "column_name",
+				"c".attndims AS array_dims,
+				"c".attidentity AS identity_type,
+				con.conname AS "constraint_name",
+				con.contype AS "constraint_type"
+			FROM (
+				SELECT
+					conname,
+					contype,
+					connamespace,
+					conrelid,
+					conrelid AS attrelid,
+					UNNEST(conkey) AS attnum
+				FROM pg_constraint
+			) con
+			LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
+			LEFT JOIN pg_namespace s ON s.oid = con.connamespace
+			LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
+		) att USING (table_schema, "table_name", "column_name")
+	) "c" USING (table_schema, "table_name")
+WHERE table_type = 'BASE TABLE'
+  AND table_schema <> 'information_schema'
+  AND table_schema NOT LIKE 'pg_%'
+	`
+)
diff --git a/dialect/pgdialect/sqltype.go b/dialect/pgdialect/sqltype.go
index 40802e51d..822d3207d 100644
--- a/dialect/pgdialect/sqltype.go
+++ b/dialect/pgdialect/sqltype.go
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"net"
 	"reflect"
+	"strings"
 
 	"github.com/uptrace/bun/dialect/sqltype"
 	"github.com/uptrace/bun/schema"
@@ -28,8 +29,10 @@ const (
 	pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer
 
 	// Character Types
-	pgTypeChar = "CHAR" // fixed length string (blank padded)
-	pgTypeText = "TEXT" // variable length string without limit
+	pgTypeChar             = "CHAR"              // fixed length string (blank padded)
+	pgTypeText             = "TEXT"              // variable length string without limit
+	pgTypeVarchar          = "VARCHAR"           // variable length string with optional limit
+	pgTypeCharacterVarying = "CHARACTER VARYING" // alias for VARCHAR
 
 	// JSON Types
 	pgTypeJSON = "JSON" // text representation of json data
@@ -48,6 +51,10 @@ func (d *Dialect) DefaultVarcharLen() int {
 	return 0
 }
 
+func (d *Dialect) DefaultSchema() string {
+	return "public"
+}
+
 func fieldSQLType(field *schema.Field) string {
 	if field.UserSQLType != "" {
 		return field.UserSQLType
@@ -106,3 +113,12 @@ func sqlType(typ reflect.Type) string {
 
 	return sqlType
 }
+
+// fromDatabaseType converts a Postgres-specific type to a more generic `sqltype`.
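+// For example, CHAR, VARCHAR, CHARACTER VARYING, and TEXT are all reported as
+// sqltype.VarChar, so that column types read from the database can be compared
+// with the types generated from the model definitions.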
+func fromDatabaseType(dbType string) string {
+	switch strings.ToUpper(dbType) {
+	case pgTypeChar, pgTypeVarchar, pgTypeCharacterVarying, pgTypeText:
+		return sqltype.VarChar
+	}
+	return dbType
+}
diff --git a/dialect/sqlitedialect/dialect.go b/dialect/sqlitedialect/dialect.go
index 3bfe500ff..c2c676d05 100644
--- a/dialect/sqlitedialect/dialect.go
+++ b/dialect/sqlitedialect/dialect.go
@@ -96,9 +96,13 @@ func (d *Dialect) DefaultVarcharLen() int {
 // AUTOINCREMENT is only valid for INTEGER PRIMARY KEY, and this method will be a noop for other columns.
 //
 // Because this is a valid construct:
+//
 //	CREATE TABLE ("id" INTEGER PRIMARY KEY AUTOINCREMENT);
+//
 // and this is not:
+//
 //	CREATE TABLE ("id" INTEGER AUTOINCREMENT, PRIMARY KEY ("id"));
+//
 // AppendSequence adds a primary key constraint as a *side-effect*. Callers should expect it to avoid building invalid SQL.
 // SQLite also [does not support] AUTOINCREMENT column in composite primary keys.
 //
@@ -111,6 +115,13 @@ func (d *Dialect) AppendSequence(b []byte, table *schema.Table, field *schema.Fi
 	return b
 }
 
+// DefaultSchema returns the "schema-name" of the main database.
+// The details might differ from other dialects, but for all intents and purposes
+// "main" is the default schema in an SQLite database.
+func (d *Dialect) DefaultSchema() string {
+	return "main"
+}
+
 func fieldSQLType(field *schema.Field) string {
 	switch field.DiscoveredSQLType {
 	case sqltype.SmallInt, sqltype.BigInt:
diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go
index 8055d6e4f..943c66d51 100644
--- a/internal/dbtest/db_test.go
+++ b/internal/dbtest/db_test.go
@@ -24,6 +24,7 @@ import (
 	"github.com/uptrace/bun/driver/pgdriver"
 	"github.com/uptrace/bun/driver/sqliteshim"
 	"github.com/uptrace/bun/extra/bundebug"
+	"github.com/uptrace/bun/schema"
 
 	_ "github.com/denisenkom/go-mssqldb"
 	_ "github.com/go-sql-driver/mysql"
@@ -53,6 +54,13 @@ var allDBs = map[string]func(tb testing.TB) *bun.DB{
 	mssql2019Name: mssql2019,
 }
 
+var allDialects = []func() schema.Dialect{
+	func() schema.Dialect { return pgdialect.New() },
+	func() schema.Dialect { return mysqldialect.New() },
+	func() schema.Dialect { return sqlitedialect.New() },
+	func() schema.Dialect { return mssqldialect.New() },
+}
+
 func pg(tb testing.TB) *bun.DB {
 	dsn := os.Getenv("PG")
 	if dsn == "" {
@@ -216,6 +224,16 @@ func testEachDB(t *testing.T, f func(t *testing.T, dbName string, db *bun.DB)) {
 	}
 }
 
+// testEachDialect allows testing dialect-specific functionality that does not require database interactions.
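+// Unlike testEachDB, it does not open a database connection; each test receives
+// a factory that creates a fresh dialect instance.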
+func testEachDialect(t *testing.T, f func(t *testing.T, dialectName string, dialect func() schema.Dialect)) {
+	for _, newDialect := range allDialects {
+		name := newDialect().Name().String()
+		t.Run(name, func(t *testing.T) {
+			f(t, name, newDialect)
+		})
+	}
+}
+
 func funcName(x interface{}) string {
 	s := runtime.FuncForPC(reflect.ValueOf(x).Pointer()).Name()
 	if i := strings.LastIndexByte(s, '.'); i >= 0 {
diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go
new file mode 100644
index 000000000..9b092ef4d
--- /dev/null
+++ b/internal/dbtest/inspect_test.go
@@ -0,0 +1,112 @@
+package dbtest_test
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"github.com/uptrace/bun"
+	"github.com/uptrace/bun/schema"
+	"github.com/uptrace/bun/schema/inspector"
+)
+
+func TestDatabaseInspector_Inspect(t *testing.T) {
+	type Book struct {
+		bun.BaseModel `bun:"table:books"`
+
+		ISBN   int    `bun:",pk,identity"`
+		Author string `bun:",notnull,unique:title_author,default:'john doe'"`
+		Title  string `bun:",notnull,unique:title_author"`
+		Locale string `bun:",type:varchar(5),default:'en-GB'"`
+		Pages  int8   `bun:"page_count,notnull,default:1"`
+		Count  int32  `bun:"book_count,autoincrement"`
+	}
+
+	testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
+		var dialect inspector.Dialect
+		dbDialect := db.Dialect()
+
+		if id, ok := dbDialect.(inspector.Dialect); ok {
+			dialect = id
+		} else {
+			t.Skipf("%q dialect does not implement inspector.Dialect", dbDialect.Name())
+		}
+
+		ctx := context.Background()
+		createTableOrSkip(t, ctx, db, (*Book)(nil))
+
+		dbInspector := dialect.Inspector(db)
+		want := schema.State{
+			Tables: []schema.TableDef{
+				{
+					Schema: "public",
+					Name:   "books",
+					Columns: map[string]schema.ColumnDef{
+						"isbn": {
+							SQLType:         "bigint",
+							IsPK:            true,
+							IsNullable:      false,
+							IsAutoIncrement: false,
+							IsIdentity:      true,
+							DefaultValue:    "",
+						},
+						"author": {
+							SQLType:         "varchar",
+							IsPK:            false,
+							IsNullable:      false,
+							IsAutoIncrement: false,
+							IsIdentity:      false,
+							DefaultValue:    "john doe",
+						},
+						"title": {
+							SQLType:         "varchar",
+							IsPK:            false,
+							IsNullable:      false,
+							IsAutoIncrement: false,
+							IsIdentity:      false,
+							DefaultValue:    "",
+						},
+						"locale": {
+							SQLType:         "varchar(5)",
+							IsPK:            false,
+							IsNullable:      true,
+							IsAutoIncrement: false,
+							IsIdentity:      false,
+							DefaultValue:    "en-GB",
+						},
+						"page_count": {
+							SQLType:         "smallint",
+							IsPK:            false,
+							IsNullable:      false,
+							IsAutoIncrement: false,
+							IsIdentity:      false,
+							DefaultValue:    "1",
+						},
+						"book_count": {
+							SQLType:         "integer",
+							IsPK:            false,
+							IsNullable:      false,
+							IsAutoIncrement: true,
+							IsIdentity:      false,
+							DefaultValue:    "",
+						},
+					},
+				},
+			},
+		}
+
+		got, err := dbInspector.Inspect(ctx)
+		require.NoError(t, err)
+		require.Equal(t, want, got)
+	})
+}
+
+func getDatabaseInspectorOrSkip(tb testing.TB, db *bun.DB) schema.Inspector {
+	dialect := db.Dialect()
+	if id, ok := dialect.(inspector.Dialect); ok {
+		return id.Inspector(db)
+	}
+	tb.Skipf("%q dialect does not implement inspector.Dialect", dialect.Name())
+	return nil
+}
diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go
index 74e33eab2..bab42e9b3 100644
--- a/internal/dbtest/migrate_test.go
+++ b/internal/dbtest/migrate_test.go
@@ -8,6 +8,7 @@ import (
 	"github.com/stretchr/testify/require"
 	"github.com/uptrace/bun"
 	"github.com/uptrace/bun/migrate"
+	"github.com/uptrace/bun/schema"
 )
 
 const (
@@ -158,3 +159,140 @@ func testMigrateUpError(t *testing.T, db *bun.DB) {
 	require.Len(t, group.Migrations, 2)
 	require.Equal(t,
		[]string{"down2", "down1"}, history)
 }
+
+func TestAutoMigrator_Migrate(t *testing.T) {
+	tests := []struct {
+		fn func(t *testing.T, db *bun.DB)
+	}{
+		{testRenameTable},
+	}
+
+	testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
+		for _, tt := range tests {
+			t.Run(funcName(tt.fn), func(t *testing.T) {
+				tt.fn(t, db)
+			})
+		}
+	})
+}
+
+func testRenameTable(t *testing.T, db *bun.DB) {
+	type initial struct {
+		bun.BaseModel `bun:"table:initial"`
+		Foo           int `bun:"foo,notnull"`
+	}
+
+	type changed struct {
+		bun.BaseModel `bun:"table:changed"`
+		Foo           int `bun:"foo,notnull"`
+	}
+
+	// Arrange
+	ctx := context.Background()
+	di := getDatabaseInspectorOrSkip(t, db)
+	createTableOrSkip(t, ctx, db, (*initial)(nil))
+
+	m, err := migrate.NewAutoMigrator(db)
+	require.NoError(t, err)
+	m.SetModels((*changed)(nil))
+
+	// Act
+	err = m.Migrate(ctx)
+	require.NoError(t, err)
+
+	// Assert
+	state, err := di.Inspect(ctx)
+	require.NoError(t, err)
+
+	tables := state.Tables
+	require.Len(t, tables, 1)
+	require.Equal(t, "changed", tables[0].Name)
+}
+
+func createTableOrSkip(tb testing.TB, ctx context.Context, db *bun.DB, model interface{}) {
+	tb.Helper()
+	if _, err := db.NewCreateTable().IfNotExists().Model(model).Exec(ctx); err != nil {
+		tb.Skip("setup failed:", err)
+	}
+	tb.Cleanup(func() {
+		if _, err := db.NewDropTable().IfExists().Model(model).Exec(ctx); err != nil {
+			tb.Log("cleanup:", err)
+		}
+	})
+}
+
+func TestDetector_Diff(t *testing.T) {
+	tests := []struct {
+		name       string
+		states     func(testing.TB, context.Context, func() schema.Dialect) (stateDb schema.State, stateModel schema.State)
+		operations []migrate.Operation
+	}{
+		{
+			name:   "find a renamed table",
+			states: renamedTableStates,
+			operations: []migrate.Operation{
+				&migrate.RenameTable{
+					From: "books",
+					To:   "books_renamed",
+				},
+			},
+		},
+	}
+
+	testEachDialect(t, func(t *testing.T, dialectName string, dialect func() schema.Dialect) {
+		if dialectName != "pg" {
+			t.Skip()
+		}
+
+		for _, tt := range tests {
+			t.Run(tt.name, func(t *testing.T) {
+				ctx := context.Background()
+				var d migrate.Detector
+				stateDb, stateModel := tt.states(t, ctx, dialect)
+
+				diff := d.Diff(stateDb, stateModel)
+
+				require.Equal(t, tt.operations, diff.Operations())
+			})
+		}
+	})
+}
+
+func renamedTableStates(tb testing.TB, ctx context.Context, dialect func() schema.Dialect) (s1, s2 schema.State) {
+	type Book struct {
+		bun.BaseModel
+
+		ISBN  string `bun:"isbn,pk"`
+		Title string `bun:"title,notnull"`
+		Pages int    `bun:"page_count,notnull,default:0"`
+	}
+
+	type Author struct {
+		bun.BaseModel
+		Name string `bun:"name"`
+	}
+
+	type BookRenamed struct {
+		bun.BaseModel `bun:"table:books_renamed"`
+
+		ISBN  string `bun:"isbn,pk"`
+		Title string `bun:"title,notnull"`
+		Pages int    `bun:"page_count,notnull,default:0"`
+	}
+	return getState(tb, ctx, dialect(),
+			(*Author)(nil),
+			(*Book)(nil),
+		), getState(tb, ctx, dialect(),
+			(*Author)(nil),
+			(*BookRenamed)(nil),
+		)
+}
+
+func getState(tb testing.TB, ctx context.Context, dialect schema.Dialect, models ...interface{}) schema.State {
+	inspector := schema.NewInspector(dialect, models...)
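+	// SchemaInspector derives the state from the registered models and does not query the database.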
+	state, err := inspector.Inspect(ctx)
+	if err != nil {
+		tb.Skipf("get state: %v", err)
+	}
+	return state
+}
diff --git a/migrate/auto.go b/migrate/auto.go
new file mode 100644
index 000000000..8453e069d
--- /dev/null
+++ b/migrate/auto.go
@@ -0,0 +1,212 @@
+package migrate
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/uptrace/bun"
+	"github.com/uptrace/bun/schema"
+	"github.com/uptrace/bun/schema/inspector"
+)
+
+type AutoMigrator struct {
+	db *bun.DB
+
+	// models limit the set of tables considered for the migration.
+	models []interface{}
+
+	// dbInspector creates the current state for the target database.
+	dbInspector schema.Inspector
+
+	// modelInspector creates the desired state based on the model definitions.
+	modelInspector schema.Inspector
+}
+
+func NewAutoMigrator(db *bun.DB) (*AutoMigrator, error) {
+	dialect := db.Dialect()
+	withInspector, ok := dialect.(inspector.Dialect)
+	if !ok {
+		return nil, fmt.Errorf("%q dialect does not implement inspector.Dialect", dialect.Name())
+	}
+
+	return &AutoMigrator{
+		db:          db,
+		dbInspector: withInspector.Inspector(db),
+	}, nil
+}
+
+func (am *AutoMigrator) SetModels(models ...interface{}) {
+	am.models = models
+}
+
+func (am *AutoMigrator) diff(ctx context.Context) (Changeset, error) {
+	var changes Changeset
+	var err error
+
+	// TODO: do on "SetModels"
+	am.modelInspector = schema.NewInspector(am.db.Dialect(), am.models...)
+
+	_, err = am.dbInspector.Inspect(ctx)
+	if err != nil {
+		return changes, err
+	}
+
+	_, err = am.modelInspector.Inspect(ctx)
+	if err != nil {
+		return changes, err
+	}
+	return changes, nil
+}
+
+func (am *AutoMigrator) Migrate(ctx context.Context) error {
+	return nil
+}
+
+// INTERNAL -------------------------------------------------------------------
+
+// Operation is an abstraction one level above a MigrationFunc.
+// Apart from storing the function to execute the change,
+// it knows how to *write* the corresponding code, and what the reverse operation is.
+type Operation interface {
+	Func() MigrationFunc
+}
+
+type RenameTable struct {
+	From string
+	To   string
+}
+
+func (rt *RenameTable) Func() MigrationFunc {
+	return func(ctx context.Context, db *bun.DB) error {
+		db.Dialect()
+		return nil
+	}
+}
+
+// Changeset is a set of changes that alter database state.
+type Changeset struct {
+	operations []Operation
+}
+
+func (c Changeset) Operations() []Operation {
+	return c.operations
+}
+
+func (c *Changeset) Add(op Operation) {
+	c.operations = append(c.operations, op)
+}
+
+type Detector struct{}
+
+func (d *Detector) Diff(got, want schema.State) Changeset {
+	var changes Changeset
+
+	// Detect renamed models
+	oldModels := newTableSet(got.Tables...)
+	newModels := newTableSet(want.Tables...)
+
+	addedModels := newModels.Sub(oldModels)
+	for _, added := range addedModels.Values() {
+		removedModels := oldModels.Sub(newModels)
+		for _, removed := range removedModels.Values() {
+			if !haveSameSignature(added, removed) {
+				continue
+			}
+			changes.Add(&RenameTable{
+				From: removed.Name,
+				To:   added.Name,
+			})
+		}
+	}
+
+	return changes
+}
+
+// haveSameSignature determines if two tables have the same "signature".
+func haveSameSignature(t1, t2 schema.TableDef) bool {
+	sig1 := newSignature(t1)
+	sig2 := newSignature(t2)
+	return sig1.Equals(sig2)
+}
+
+// tableSet stores unique table definitions.
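+// Tables are keyed by name, so adding a table with an existing name overwrites the earlier entry.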
+type tableSet struct {
+	underlying map[string]schema.TableDef
+}
+
+func newTableSet(initial ...schema.TableDef) tableSet {
+	set := tableSet{
+		underlying: make(map[string]schema.TableDef),
+	}
+	for _, t := range initial {
+		set.Add(t)
+	}
+	return set
+}
+
+func (set tableSet) Add(t schema.TableDef) {
+	set.underlying[t.Name] = t
+}
+
+func (set tableSet) Remove(s string) {
+	delete(set.underlying, s)
+}
+
+func (set tableSet) Values() (tables []schema.TableDef) {
+	for _, t := range set.underlying {
+		tables = append(tables, t)
+	}
+	return
+}
+
+func (set tableSet) Sub(other tableSet) tableSet {
+	res := set.clone()
+	for v := range other.underlying {
+		if _, ok := set.underlying[v]; ok {
+			res.Remove(v)
+		}
+	}
+	return res
+}
+
+func (set tableSet) clone() tableSet {
+	res := newTableSet()
+	for _, t := range set.underlying {
+		res.Add(t)
+	}
+	return res
+}
+
+// signature is a set of column definitions, which allows "relation/name-agnostic" comparison between them;
+// meaning that two columns are considered equal if their types are the same.
+type signature struct {
+
+	// underlying stores the number of occurrences for each unique column type.
+	// It helps to account for the fact that a table might have multiple columns that have the same type.
+	underlying map[schema.ColumnDef]int
+}
+
+func newSignature(t schema.TableDef) signature {
+	s := signature{
+		underlying: make(map[schema.ColumnDef]int),
+	}
+	s.scan(t)
+	return s
+}
+
+// scan iterates over the table's fields and counts occurrences of each unique column definition.
+func (s *signature) scan(t schema.TableDef) {
+	for _, c := range t.Columns {
+		s.underlying[c]++
+	}
+}
+
+// Equals returns true if 2 signatures share an identical set of columns.
+func (s *signature) Equals(other signature) bool {
+	for k, count := range s.underlying {
+		if countOther, ok := other.underlying[k]; !ok || countOther != count {
+			return false
+		}
+	}
+	return true
+}
diff --git a/schema/append_value.go b/schema/append_value.go
index 48a0761be..a67b41e38 100644
--- a/schema/append_value.go
+++ b/schema/append_value.go
@@ -99,10 +99,10 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
 		return appendTimeValue
 	case timePtrType:
 		return PtrAppender(appendTimeValue)
-	case ipType:
-		return appendIPValue
 	case ipNetType:
 		return appendIPNetValue
+	case ipType, netipPrefixType, netipAddrType:
+		return appendStringer
 	case jsonRawMessageType:
 		return appendJSONRawMessageValue
 	}
@@ -247,16 +247,15 @@ func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
 	return fmter.Dialect().AppendTime(b, tm)
 }
 
-func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
-	ip := v.Interface().(net.IP)
-	return fmter.Dialect().AppendString(b, ip.String())
-}
-
 func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
 	ipnet := v.Interface().(net.IPNet)
 	return fmter.Dialect().AppendString(b, ipnet.String())
 }
 
+func appendStringer(fmter Formatter, b []byte, v reflect.Value) []byte {
+	return fmter.Dialect().AppendString(b, v.Interface().(fmt.Stringer).String())
+}
+
 func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
 	bytes := v.Bytes()
 	if bytes == nil {
diff --git a/schema/dialect.go b/schema/dialect.go
index 330293444..a5e2afb4e 100644
--- a/schema/dialect.go
+++ b/schema/dialect.go
@@ -39,6 +39,9 @@ type Dialect interface {
 	// is mandatory in queries that modify the schema (CREATE TABLE / ADD COLUMN, etc).
 	// Dialects that do not have such requirement may return 0, which should be interpreted as such by the caller.
 	DefaultVarcharLen() int
+
+	// DefaultSchema should return the name of the default database schema.
+	DefaultSchema() string
 }
 
 // ------------------------------------------------------------------------------
diff --git a/schema/inspector.go b/schema/inspector.go
new file mode 100644
index 000000000..464cfa81f
--- /dev/null
+++ b/schema/inspector.go
@@ -0,0 +1,76 @@
+package schema
+
+import (
+	"context"
+	"strings"
+)
+
+type Inspector interface {
+	Inspect(ctx context.Context) (State, error)
+}
+
+type State struct {
+	Tables []TableDef
+}
+
+type TableDef struct {
+	Schema  string
+	Name    string
+	Columns map[string]ColumnDef
+}
+
+type ColumnDef struct {
+	SQLType         string
+	DefaultValue    string
+	IsPK            bool
+	IsNullable      bool
+	IsAutoIncrement bool
+	IsIdentity      bool
+}
+
+type SchemaInspector struct {
+	dialect Dialect
+}
+
+var _ Inspector = (*SchemaInspector)(nil)
+
+func NewInspector(dialect Dialect, models ...interface{}) *SchemaInspector {
+	dialect.Tables().Register(models...)
+	return &SchemaInspector{
+		dialect: dialect,
+	}
+}
+
+func (si *SchemaInspector) Inspect(ctx context.Context) (State, error) {
+	var state State
+	for _, t := range si.dialect.Tables().All() {
+		columns := make(map[string]ColumnDef)
+		for _, f := range t.Fields {
+			columns[f.Name] = ColumnDef{
+				SQLType:         f.CreateTableSQLType,
+				DefaultValue:    f.SQLDefault,
+				IsPK:            f.IsPK,
+				IsNullable:      !f.NotNull,
+				IsAutoIncrement: f.AutoIncrement,
+				IsIdentity:      f.Identity,
+			}
+		}
+
+		schema, table := splitTableNameTag(si.dialect, t.Name)
+		state.Tables = append(state.Tables, TableDef{
+			Schema:  schema,
+			Name:    table,
+			Columns: columns,
+		})
+	}
+	return state, nil
+}
+
+// splitTableNameTag splits the name tag into its schema and table-name parts,
+// falling back to the dialect's default schema if the tag contains no schema.
+func splitTableNameTag(d Dialect, nameTag string) (string, string) {
+	schema, table := d.DefaultSchema(), nameTag
+	if schemaTable := strings.Split(nameTag, "."); len(schemaTable) == 2 {
+		schema, table = schemaTable[0], schemaTable[1]
+	}
+	return schema, table
+}
\ No newline at end of file
diff --git a/schema/inspector/dialect.go b/schema/inspector/dialect.go
new file mode 100644
index 000000000..701300da9
--- /dev/null
+++ b/schema/inspector/dialect.go
@@ -0,0 +1,11 @@
+package inspector
+
+import (
+	"github.com/uptrace/bun"
+	"github.com/uptrace/bun/schema"
+)
+
+type Dialect interface {
+	schema.Dialect
+	Inspector(db *bun.DB) schema.Inspector
+}
diff --git a/schema/reflect.go b/schema/reflect.go
index 89be8eeb6..75980b102 100644
--- a/schema/reflect.go
+++ b/schema/reflect.go
@@ -4,6 +4,7 @@ import (
 	"database/sql/driver"
 	"encoding/json"
 	"net"
+	"net/netip"
 	"reflect"
 	"time"
 )
@@ -14,6 +15,8 @@ var (
 	timeType           = timePtrType.Elem()
 	ipType             = reflect.TypeOf((*net.IP)(nil)).Elem()
 	ipNetType          = reflect.TypeOf((*net.IPNet)(nil)).Elem()
+	netipPrefixType    = reflect.TypeOf((*netip.Prefix)(nil)).Elem()
+	netipAddrType      = reflect.TypeOf((*netip.Addr)(nil)).Elem()
 	jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
 
 	driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
diff --git a/schema/tables.go b/schema/tables.go
index 985093421..58c45cbee 100644
--- a/schema/tables.go
+++ b/schema/tables.go
@@ -77,6 +77,7 @@ func (t *Tables) InProgress(typ reflect.Type) *Table {
 	return table
 }
 
+// ByModel gets the table by its Go name.
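+// The Go name is the name of the model struct; compare ByName below, which looks a table up by its SQL name.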
 func (t *Tables) ByModel(name string) *Table {
 	var found *Table
 	t.tables.Range(func(typ reflect.Type, table *Table) bool {
@@ -89,6 +90,7 @@ func (t *Tables) ByModel(name string) *Table {
 	return found
 }
 
+// ByName gets the table by its SQL name.
 func (t *Tables) ByName(name string) *Table {
 	var found *Table
 	t.tables.Range(func(typ reflect.Type, table *Table) bool {
@@ -100,3 +102,13 @@ func (t *Tables) ByName(name string) *Table {
 	return found
 }
+
+// All returns all registered tables.
+func (t *Tables) All() []*Table {
+	var found []*Table
+	t.tables.Range(func(typ reflect.Type, table *Table) bool {
+		found = append(found, table)
+		return true
+	})
+	return found
+}
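
For reference, a minimal sketch of what the new appender support enables on the model side. The model, table name, and column types below are illustrative, not part of the patch; netip.Addr and netip.Prefix values are rendered through fmt.Stringer (see appendStringer in schema/append_value.go), so they are appended as strings like '10.0.0.1' and '10.0.0.0/24':

	type Endpoint struct {
		bun.BaseModel `bun:"table:endpoints"`

		ID     int64        `bun:"id,pk,autoincrement"`
		Addr   netip.Addr   `bun:"addr,type:inet"`   // appended via fmt.Stringer
		Subnet netip.Prefix `bun:"subnet,type:cidr"` // appended via fmt.Stringer
	}

	func insertEndpoint(ctx context.Context, db *bun.DB) error {
		ep := &Endpoint{
			Addr:   netip.MustParseAddr("10.0.0.1"),
			Subnet: netip.MustParsePrefix("10.0.0.0/24"),
		}
		_, err := db.NewInsert().Model(ep).Exec(ctx)
		return err
	}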
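Similarly, a sketch of how the AutoMigrator added in migrate/auto.go is wired up, reusing the illustrative Endpoint model from above. NewAutoMigrator fails for dialects that do not implement inspector.Dialect (in this patch, only pgdialect does), and Migrate is still a stub, so the call currently applies no changes:

	func autoMigrate(ctx context.Context, db *bun.DB) error {
		m, err := migrate.NewAutoMigrator(db)
		if err != nil {
			return err // the dialect does not implement inspector.Dialect
		}
		m.SetModels((*Endpoint)(nil)) // limit the diff to these models
		return m.Migrate(ctx)         // no-op until diffing and execution land
	}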