diff --git a/scopes/scopes.go b/scopes/scopes.go index 858f6dc..04f999d 100644 --- a/scopes/scopes.go +++ b/scopes/scopes.go @@ -20,8 +20,8 @@ import ( // // type Book struct { ... } // does not implement Tabler interface, must set TableName manually // -// db.Table("books").Scopes(scopes.WithTenantSchema("tenant1")).Find(&Book{}) -// // SELECT * FROM tenant1.books; +// db.Table("books").Scopes(WithTenantSchema("tenant2")).Find(&Book{}) +// // SELECT * FROM "tenant2"."books" // // Example with Tabler interface: // @@ -29,43 +29,54 @@ import ( // // func (u *Book) TableName() string { return "books" } // implements Tabler interface, no need to set TableName manually // -// db.Scopes(scopes.WithTenantSchema("tenant2")).Find(&Book{}) -// // SELECT * FROM tenant2.books; +// db.Scopes(WithTenantSchema("tenant1")).Find(&Book{}) +// // SELECT * FROM "tenant1"."books" // // Example with model set manually: // // type Book struct { ... } // -// db.Model(&Book{}).Scopes(scopes.WithTenantSchema("tenant3")).Find(&Book{}) // model is set manually. -// // SELECT * FROM tenant3.books; +// db.Model(&Book{}).Scopes(WithTenantSchema("tenant1")).Find(&Book{}) // model is set manually. +// // SELECT * FROM "tenant1"."books" +// +// Example with destination set to a pointer to a struct: +// +// type Book struct { ... } +// +// db.Scopes(WithTenantSchema("tenant1")).Find(&Book{}) // destination is set to a pointer to a struct. +// // SELECT * FROM "tenant1"."books" +// +// Example with destination set to a pointer to an array/slice: +// +// type Book struct { ... } +// +// db.Scopes(WithTenantSchema("tenant1")).Find(&[]Book{}) // destination is set to a pointer to an array/slice. +// // SELECT * FROM "tenant1"."books" func WithTenantSchema(tenant string) func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { - var ( - tn string - ) + var tableName string switch { case db.Statement.Table != "": - tn = db.Statement.Table + tableName = db.Statement.Table case db.Statement.Model != nil: - tn = tableNameFromInterface(db.Statement.Model) + tableName = tableNameFromInterface(db.Statement.Model) case db.Statement.Dest != nil: - destPtr := reflect.ValueOf(db.Statement.Dest) - if destPtr.Kind() != reflect.Ptr { - _ = db.AddError(errors.New("destination must be a pointer")) - } else { - tn = tableNameFromReflectValue(db.Statement.Dest) + var err error + tableName, err = tableNameFromReflectValue(reflect.ValueOf(db.Statement.Dest)) + if err != nil { + _ = db.AddError(err) + return db } } - if tn != "" { - return db.Table(tenant + "." + tn) + if tableName != "" { + return db.Table(tenant + "." + tableName) } - // otherwise, return an error _ = db.AddError(gorm.ErrModelValueRequired) return db } } -// tableNameFromInterface returns the table name from a interface. +// tableNameFromInterface returns the table name from an interface. func tableNameFromInterface(val interface{}) string { if s, ok := val.(schema.Tabler); ok { return s.TableName() @@ -73,14 +84,18 @@ func tableNameFromInterface(val interface{}) string { return "" } -// tableNameFromReflectValue returns the table name from a reflect.Value. -func tableNameFromReflectValue(valPtr interface{}) string { - val := reflect.ValueOf(valPtr).Elem() - if val.Kind() == reflect.Struct { - return tableNameFromInterface(val.Interface()) +// tableNameFromReflectValue returns the table name from a [reflect.Value]. +func tableNameFromReflectValue(valPtr reflect.Value) (string, error) { + if valPtr.Kind() != reflect.Ptr { + return "", errors.New("destination must be a pointer") } - if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { - return tableNameFromInterface(reflect.New(val.Type().Elem()).Interface()) + val := valPtr.Elem() + switch val.Kind() { //nolint:exhaustive // only interested in a struct and a array/slice + case reflect.Struct: + return tableNameFromInterface(val.Interface()), nil + case reflect.Slice, reflect.Array: + return tableNameFromInterface(reflect.New(val.Type().Elem()).Interface()), nil + default: + return "", nil } - return "" }