diff --git a/lib/db/schema/schema.go b/lib/db/schema/schema.go index a2aaf9d3..d1fb9ce1 100644 --- a/lib/db/schema/schema.go +++ b/lib/db/schema/schema.go @@ -2,6 +2,7 @@ package schema import ( "fmt" + "github.com/getevo/evo/v2/lib/db/schema/table" "gorm.io/gorm" "gorm.io/gorm/schema" "path/filepath" @@ -11,17 +12,18 @@ import ( var Models []Model type Model struct { - Sample any `json:"sample"` - Value reflect.Value `json:"-"` - Type reflect.Type `json:"-"` - Kind reflect.Kind `json:"-"` - Table string `json:"table"` - Name string `json:"name"` - Package string `json:"package"` - PackagePath string `json:"package_path"` - PrimaryKey []string `json:"primary_key"` - Schema *schema.Schema `json:"-"` - Statement *gorm.Statement `json:"-"` + Sample any `json:"sample"` + Value reflect.Value `json:"-"` + Type reflect.Type `json:"-"` + Kind reflect.Kind `json:"-"` + Table string `json:"table"` + Name string `json:"name"` + Package string `json:"package"` + PackagePath string `json:"package_path"` + PrimaryKey []string `json:"primary_key"` + Joins map[string][]string `json:"joins"` + Schema *schema.Schema `json:"-"` + Statement *gorm.Statement `json:"-"` } func (m Model) Join(joins ...*Model) ([]string, []string, error) { @@ -29,6 +31,10 @@ func (m Model) Join(joins ...*Model) ([]string, []string, error) { var tables = []string{m.Table} for _, join := range joins { tables = append(tables, join.Table) + if v, ok := m.Joins[join.Table]; ok { + where = append(where, quote(m.Table)+"."+quote(v[0])+" = "+quote(join.Table)+"."+quote(v[1])) + continue + } if _, ok := join.Schema.FieldsByDBName[m.PrimaryKey[0]]; ok { where = append(where, quote(m.Table)+"."+quote(m.PrimaryKey[0])+" = "+quote(join.Table)+"."+quote(m.PrimaryKey[0])) continue @@ -47,8 +53,13 @@ func quote(s string) string { return "`" + s + "`" } +var database = "" + func UseModel(db *gorm.DB, values ...any) { migrations = append(migrations, values...) + if database == "" { + db.Raw("SELECT DATABASE();").Scan(&database) + } for index, _ := range values { ref := reflect.ValueOf(values[index]) if ref.Kind() != reflect.Struct { @@ -69,6 +80,13 @@ func UseModel(db *gorm.DB, values ...any) { model.PrimaryKey = stmt.Schema.PrimaryFieldDBNames model.Statement = stmt model.Table = stmt.Table + model.Joins = make(map[string][]string) + + var constraints []table.Constraint + db.Where(table.Constraint{Database: database}).Find(&constraints) + for _, constraint := range constraints { + model.Joins[constraint.ReferencedTable] = []string{constraint.Column, constraint.ReferencedColumn} + } Models = append(Models, model) }