Skip to content

Commit

Permalink
fix(scopes): Reduce memory allocations in WithTenantSchema
Browse files Browse the repository at this point in the history
  • Loading branch information
bartventer committed May 24, 2024
1 parent 918fa96 commit 1e5dc1b
Showing 1 changed file with 43 additions and 28 deletions.
71 changes: 43 additions & 28 deletions scopes/scopes.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,67 +20,82 @@ import (
//
// type Book struct { ... } // does not implement Tabler interface, must set TableName manually
//
// db.Table("books").Scopes(scopes.WithTenantSchema("tenant1")).Find(&Book{})
// // SELECT * FROM tenant1.books;
// db.Table("books").Scopes(WithTenantSchema("tenant2")).Find(&Book{})
// // SELECT * FROM "tenant2"."books"
//
// Example with Tabler interface:
//
// type Book struct { ... }
//
// func (u *Book) TableName() string { return "books" } // implements Tabler interface, no need to set TableName manually
//
// db.Scopes(scopes.WithTenantSchema("tenant2")).Find(&Book{})
// // SELECT * FROM tenant2.books;
// db.Scopes(WithTenantSchema("tenant1")).Find(&Book{})
// // SELECT * FROM "tenant1"."books"
//
// Example with model set manually:
//
// type Book struct { ... }
//
// db.Model(&Book{}).Scopes(scopes.WithTenantSchema("tenant3")).Find(&Book{}) // model is set manually.
// // SELECT * FROM tenant3.books;
// db.Model(&Book{}).Scopes(WithTenantSchema("tenant1")).Find(&Book{}) // model is set manually.
// // SELECT * FROM "tenant1"."books"
//
// Example with destination set to a pointer to a struct:
//
// type Book struct { ... }
//
// db.Scopes(WithTenantSchema("tenant1")).Find(&Book{}) // destination is set to a pointer to a struct.
// // SELECT * FROM "tenant1"."books"
//
// Example with destination set to a pointer to an array/slice:
//
// type Book struct { ... }
//
// db.Scopes(WithTenantSchema("tenant1")).Find(&[]Book{}) // destination is set to a pointer to an array/slice.
// // SELECT * FROM "tenant1"."books"
func WithTenantSchema(tenant string) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
var (
tn string
)
var tableName string
switch {
case db.Statement.Table != "":
tn = db.Statement.Table
tableName = db.Statement.Table
case db.Statement.Model != nil:
tn = tableNameFromInterface(db.Statement.Model)
tableName = tableNameFromInterface(db.Statement.Model)
case db.Statement.Dest != nil:
destPtr := reflect.ValueOf(db.Statement.Dest)
if destPtr.Kind() != reflect.Ptr {
_ = db.AddError(errors.New("destination must be a pointer"))
} else {
tn = tableNameFromReflectValue(db.Statement.Dest)
var err error
tableName, err = tableNameFromReflectValue(reflect.ValueOf(db.Statement.Dest))
if err != nil {
_ = db.AddError(err)
return db
}
}
if tn != "" {
return db.Table(tenant + "." + tn)
if tableName != "" {
return db.Table(tenant + "." + tableName)
}
// otherwise, return an error
_ = db.AddError(gorm.ErrModelValueRequired)
return db
}
}

// tableNameFromInterface returns the table name from a interface.
// tableNameFromInterface returns the table name from an interface.
func tableNameFromInterface(val interface{}) string {
if s, ok := val.(schema.Tabler); ok {
return s.TableName()
}
return ""
}

// tableNameFromReflectValue returns the table name from a reflect.Value.
func tableNameFromReflectValue(valPtr interface{}) string {
val := reflect.ValueOf(valPtr).Elem()
if val.Kind() == reflect.Struct {
return tableNameFromInterface(val.Interface())
// tableNameFromReflectValue returns the table name from a [reflect.Value].
func tableNameFromReflectValue(valPtr reflect.Value) (string, error) {
if valPtr.Kind() != reflect.Ptr {
return "", errors.New("destination must be a pointer")
}
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
return tableNameFromInterface(reflect.New(val.Type().Elem()).Interface())
val := valPtr.Elem()
switch val.Kind() { //nolint:exhaustive // only interested in a struct and a array/slice
case reflect.Struct:
return tableNameFromInterface(val.Interface()), nil
case reflect.Slice, reflect.Array:
return tableNameFromInterface(reflect.New(val.Type().Elem()).Interface()), nil
default:
return "", nil
}
return ""
}

0 comments on commit 1e5dc1b

Please sign in to comment.