diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..e4e810749 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +--- +version: 2 +updates: + - package-ecosystem: gomod + directory: / + schedule: + interval: weekly + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly + - package-ecosystem: gomod + directory: /tests + schedule: + interval: weekly diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index 5b0bd9813..868bcc348 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -10,13 +10,13 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v3.0.7 + uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 - days-before-close: 2 + days-before-close: 30 remove-stale-when-updated: true only-labels: "type:invalid question" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index ea3207d68..3efc90f74 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -10,12 +10,12 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v3.0.7 + uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 - days-before-close: 2 + days-before-close: 30 remove-stale-when-updated: true only-labels: "type:missing reproduction steps" diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index 4511c3782..b252dd7ae 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,6 +6,17 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v1 + uses: actions/checkout@v2 - name: golangci-lint - uses: reviewdog/action-golangci-lint@v1 + uses: reviewdog/action-golangci-lint@v2 + + - name: Setup reviewdog + uses: reviewdog/action-setup@v1 + + - name: gofumpt -s with reviewdog + env: + REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + go install mvdan.cc/gofumpt@v0.2.0 + gofumpt -e -d . | \ + reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review \ No newline at end of file diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index f9c1becea..e0be186fa 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -10,12 +10,12 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v3.0.7 + uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days" - days-before-stale: 60 - days-before-close: 30 + stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" + days-before-stale: 360 + days-before-close: 180 stale-issue-label: "status:stale" exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' stale-pr-label: 'status:stale' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4388c31d8..91a0abc9f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,8 +13,8 @@ jobs: sqlite: strategy: matrix: - go: ['1.15', '1.14', '1.13'] - platform: [ubuntu-latest, macos-latest] # can not run in windows OS + go: ['1.17', '1.16'] + platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} steps: @@ -26,45 +26,20 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlite ./tests/tests_all.sh - - sqlite_windows: - strategy: - matrix: - go: ['1.15', '1.14', '1.13'] - platform: [windows-latest] - runs-on: ${{ matrix.platform }} - - steps: - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go }} - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - name: go mod pakcage cache - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - - - name: Tests - run: cd tests && set GORM_DIALECT=sqlite && go test $race -count=1 -v ./... #run the line in widnows's CMD, default GORM_DIALECT is sqlite + run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh mysql: strategy: matrix: - dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] - go: ['1.15', '1.14', '1.13'] + dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] + go: ['1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -95,21 +70,21 @@ jobs: uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh postgres: strategy: matrix: - dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] - go: ['1.15', '1.14', '1.13'] - platform: [ubuntu-latest] # can not run in macOS and widnowsOS + dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] + go: ['1.17', '1.16'] + platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} services: @@ -138,19 +113,19 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: strategy: matrix: - go: ['1.15', '1.14', '1.13'] + go: ['1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} @@ -181,11 +156,11 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: go mod pakcage cache + - name: go mod package cache uses: actions/cache@v2 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh diff --git a/.gitignore b/.gitignore index c14d60050..e1b9ecea1 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ TODO* documents coverage.txt _book +.idea diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..16903ed6c --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,11 @@ +linters: + enable: + - cyclop + - exportloopref + - gocritic + - gosec + - ineffassign + - misspell + - prealloc + - unconvert + - unparam diff --git a/README.md b/README.md index 9c0aded05..a3eabe39a 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Hooks (Before/After Create/Save/Update/Delete/Find) * Eager loading with `Preload`, `Joins` * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point -* Context, Prepared Statment Mode, DryRun Mode +* Context, Prepared Statement Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key diff --git a/association.go b/association.go index 7adb8c914..09e79ca60 100644 --- a/association.go +++ b/association.go @@ -1,7 +1,6 @@ package gorm import ( - "errors" "fmt" "reflect" "strings" @@ -27,7 +26,7 @@ func (db *DB) Association(column string) *Association { association.Relationship = db.Statement.Schema.Relationships.Relations[column] if association.Relationship == nil { - association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) + association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) } db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) @@ -66,7 +65,9 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { // save associations - association.saveAssociation( /*clear*/ true, values...) + if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { + return association.Error + } // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue @@ -78,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } case reflect.Struct: - association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } for _, ref := range rel.References { @@ -95,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error { primaryFields []*schema.Field foreignKeys []string updateMap = map[string]interface{}{} - relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) - if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { tx.Not(clause.IN{Column: column, Values: values}) } @@ -116,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error { } } - if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } @@ -142,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrPrimaryKeyRequired } - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } @@ -185,11 +186,11 @@ func (association *Association) Delete(values ...interface{}) error { case schema.BelongsTo: tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -197,11 +198,11 @@ func (association *Association) Delete(values ...interface{}) error { case schema.HasOne, schema.HasMany: tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -227,11 +228,11 @@ func (association *Association) Delete(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -240,11 +241,11 @@ func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { // clean up deleted values's foreign key - relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { - if _, zero := rel.Field.ValueOf(data); !zero { - fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { @@ -252,7 +253,7 @@ func (association *Association) Delete(values ...interface{}) error { validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { @@ -260,23 +261,23 @@ func (association *Association) Delete(values ...interface{}) error { } } - association.Error = rel.Field.Set(data, validFieldValues.Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { - if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { break } if rel.JoinTable == nil { for _, ref := range rel.References { if ref.OwnPrimaryKey || ref.PrimaryValue != "" { - association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } else { - association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -328,14 +329,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { - association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: - association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) @@ -343,7 +344,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) if clear { fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() } @@ -354,7 +355,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } else if ev.Type().Elem().AssignableTo(elemType) { fieldValue = reflect.Append(fieldValue, ev.Elem()) } else { - association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) + association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) } if elemType.Kind() == reflect.Struct { @@ -372,25 +373,55 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) } } } selectedSaveColumns := []string{association.Relationship.Name} + omitColumns := []string{} + selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, association.Relationship.Name) { + if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" { + columnName = name + } + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selectedSaveColumns = append(selectedSaveColumns, columnName) + } else { + omitColumns = append(omitColumns, columnName) + } + } + } + for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey { selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) } } + associationDB := association.DB.Session(&Session{}).Model(nil) + if !association.DB.FullSaveAssociations { + associationDB.Select(selectedSaveColumns) + } + if len(omitColumns) > 0 { + associationDB.Omit(omitColumns...) + } + associationDB = associationDB.Session(&Session{}) + switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if len(values) != reflectValue.Len() { // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { association.Error = err break } @@ -398,7 +429,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { association.Error = err break } @@ -409,7 +440,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ break } - association.Error = errors.New("invalid association values, length doesn't match") + association.Error = ErrInvalidValueOfLength return } @@ -417,17 +448,17 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) // TODO support save slice data, sql with case? - association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error + association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data if clear && len(values) == 0 { - association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) if association.Relationship.JoinTable == nil && association.Error == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -439,12 +470,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error + association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error } } for _, assignBack := range assignBacks { - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) if assignBack.Index > 0 { reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) } else { @@ -455,7 +486,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ func (association *Association) buildCondition() *DB { var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) @@ -470,7 +501,7 @@ func (association *Association) buildCondition() *DB { tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } - tx.Clauses(clause.From{Joins: []clause.Join{{ + tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: queryConds}, }}}) diff --git a/callbacks.go b/callbacks.go index e21e0718f..f344649ea 100644 --- a/callbacks.go +++ b/callbacks.go @@ -32,6 +32,7 @@ type callbacks struct { type processor struct { db *DB + Clauses []string fns []func(*DB) callbacks []*callback } @@ -71,19 +72,38 @@ func (cs *callbacks) Raw() *processor { return cs.processors["raw"] } -func (p *processor) Execute(db *DB) { - curTime := time.Now() - stmt := db.Statement +func (p *processor) Execute(db *DB) *DB { + // call scopes + for len(db.Statement.scopes) > 0 { + scopes := db.Statement.scopes + db.Statement.scopes = nil + for _, scope := range scopes { + db = scope(db) + } + } + + var ( + curTime = time.Now() + stmt = db.Statement + resetBuildClauses bool + ) + + if len(stmt.BuildClauses) == 0 { + stmt.BuildClauses = p.Clauses + resetBuildClauses = true + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest } else if stmt.Dest == nil { stmt.Dest = stmt.Model } + // parse model values if stmt.Model != nil { - if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { - if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil { db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) } else { db.AddError(err) @@ -91,13 +111,18 @@ func (p *processor) Execute(db *DB) { } } + // assign stmt.ReflectValue if stmt.Dest != nil { stmt.ReflectValue = reflect.ValueOf(stmt.Dest) for stmt.ReflectValue.Kind() == reflect.Ptr { + if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { + stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) + } + stmt.ReflectValue = stmt.ReflectValue.Elem() } if !stmt.ReflectValue.IsValid() { - db.AddError(fmt.Errorf("invalid value")) + db.AddError(ErrInvalidValue) } } @@ -105,14 +130,22 @@ func (p *processor) Execute(db *DB) { f(db) } - db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { - return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected - }, db.Error) + if stmt.SQL.Len() > 0 { + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected + }, db.Error) + } if !stmt.DB.DryRun { stmt.SQL.Reset() stmt.Vars = nil } + + if resetBuildClauses { + stmt.BuildClauses = nil + } + + return db } func (p *processor) Get(name string) func(*DB) { @@ -181,7 +214,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -189,7 +222,7 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -219,7 +252,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } @@ -235,7 +268,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { // if before callback already sorted, append current callback just after it sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) } else if curIdx > sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before) } } else if idx := getRIndex(names, c.before); idx != -1 { // if before callback exists @@ -253,7 +286,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { // if after callback sorted, append current callback to last sorted = append(sorted, c.name) } else if curIdx < sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after) } } else if idx := getRIndex(names, c.after); idx != -1 { // if after callback exists but haven't sorted diff --git a/callbacks/associations.go b/callbacks/associations.go index e66696004..d6fd21ded 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -7,53 +7,58 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) -func SaveBeforeAssociations(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) +func SaveBeforeAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) - // Save Belongs To associations - for _, rel := range db.Statement.Schema.Relationships.BelongsTo { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } - - setupReferences := func(obj reflect.Value, elem reflect.Value) { - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elem) - db.AddError(ref.ForeignKey.Set(obj, pv)) + // Save Belongs To associations + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } - if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { - dest[ref.ForeignKey.DBName] = pv - if _, ok := dest[rel.Name]; ok { - dest[rel.Name] = elem.Interface() + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv)) + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } } } } } - } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - objs []reflect.Value - fieldType = rel.Field.FieldType - isPtr = fieldType.Kind() == reflect.Ptr - ) - - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + rValLen = db.Statement.ReflectValue.Len() + objs = make([]reflect.Value, 0, rValLen) + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + for i := 0; i < rValLen; i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() != reflect.Struct { + break + } - if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value objs = append(objs, obj) if isPtr { elems = reflect.Append(elems, rv) @@ -61,27 +66,25 @@ func SaveBeforeAssociations(db *gorm.DB) { elems = reflect.Append(elems, rv.Addr()) } } - } else { - break } - } - if elems.Len() > 0 { - if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { - for i := 0; i < elems.Len(); i++ { - setupReferences(objs[i], elems.Index(i)) + if elems.Len() > 0 { + if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + for i := 0; i < elems.Len(); i++ { + setupReferences(objs[i], elems.Index(i)) + } } } - } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } - if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { - setupReferences(db.Statement.ReflectValue, rv) + if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + setupReferences(db.Statement.ReflectValue, rv) + } } } } @@ -89,250 +92,253 @@ func SaveBeforeAssociations(db *gorm.DB) { } } -func SaveAfterAssociations(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil { - selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) +func SaveAfterAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) - // Save Has One associations - for _, rel := range db.Statement.Schema.Relationships.HasOne { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - fieldType = rel.Field.FieldType - isPtr = fieldType.Kind() == reflect.Ptr - ) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) - } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue)) + } } - } - elems = reflect.Append(elems, rv) + elems = reflect.Append(elems, rv) + } } } - } - if elems.Len() > 0 { - assignmentColumns := []string{} - for _, ref := range rel.References { - assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) - } + if elems.Len() > 0 { + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) - } - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if f.Kind() != reflect.Ptr { - f = f.Addr() + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } - assignmentColumns := []string{} - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(f, fv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(f, ref.PrimaryValue) + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + ref.ForeignKey.Set(db.Statement.Context, f, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue) + } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) - } - saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + } } } - } - // Save Has Many associations - for _, rel := range db.Statement.Schema.Relationships.HasMany { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } - fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := fieldType.Kind() == reflect.Ptr - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) - for i := 0; i < f.Len(); i++ { - elem := f.Index(i) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(v) - ref.ForeignKey.Set(elem, pv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) + ref.ForeignKey.Set(db.Statement.Context, elem, pv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue) + } + } + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + identityMap[cacheKey] = true + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } } } + } + } - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) } } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) } - } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - appendToElems(obj) + if elems.Len() > 0 { + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } + + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } - case reflect.Struct: - appendToElems(db.Statement.ReflectValue) } - if elems.Len() > 0 { - assignmentColumns := []string{} - for _, ref := range rel.References { - assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) - } - } - - // Save Many2Many associations - for _, rel := range db.Statement.Schema.Relationships.Many2Many { - if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { - continue - } + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) + objs := []reflect.Value{} - fieldType := rel.Field.IndirectFieldType.Elem() - isPtr := fieldType.Kind() == reflect.Ptr - if !isPtr { - fieldType = reflect.PtrTo(fieldType) - } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) - objs := []reflect.Value{} - - appendToJoins := func(obj reflect.Value, elem reflect.Value) { - joinValue := reflect.New(rel.JoinTable.ModelType) - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(joinValue, fv) - } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(joinValue, ref.PrimaryValue) - } else { - fv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(joinValue, fv) + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue) + } else { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) + } } + joins = reflect.Append(joins, joinValue) } - joins = reflect.Append(joins, joinValue) - } - appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) - for i := 0; i < f.Len(); i++ { - elem := f.Index(i) + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) - objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + objs = append(objs, v) + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } } } } - } - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - obj := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(obj).Kind() == reflect.Struct { - appendToElems(obj) + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) } - case reflect.Struct: - appendToElems(db.Statement.ReflectValue) - } - if elems.Len() > 0 { - if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) - } + // optimize elems of reflect value length + if elemLen := elems.Len(); elemLen > 0 { + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + } - for i := 0; i < elems.Len(); i++ { - appendToJoins(objs[i], elems.Index(i)) + for i := 0; i < elemLen; i++ { + appendToJoins(objs[i], elems.Index(i)) + } } - } - if joins.Len() > 0 { - db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) + if joins.Len() > 0 { + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }).Create(joins.Interface()).Error) + } } } } } -func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict { - if stmt.DB.FullSaveAssociations { - defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) - for _, dbName := range s.DBNames { - if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) { - continue - } - - if !s.LookUpField(dbName).PrimaryKey { - defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) - } - } - } - - if len(defaultUpdatingColumns) > 0 { - var columns []clause.Column - if s.PrioritizedPrimaryField != nil { - columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} - } else { - for _, dbName := range s.PrimaryFieldDBNames { - columns = append(columns, clause.Column{Name: dbName}) - } +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { + if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { + onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) + for _, dbName := range s.PrimaryFieldDBNames { + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName}) } - return clause.OnConflict{ - Columns: columns, - DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + onConflict.UpdateAll = stmt.DB.FullSaveAssociations + if !onConflict.UpdateAll { + onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns) } + } else { + onConflict.DoNothing = true } - return clause.OnConflict{DoNothing: true} + return } func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { @@ -346,8 +352,6 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, columnName := "" if strings.HasPrefix(name, refName) { columnName = strings.TrimPrefix(name, refName) - } else if strings.HasPrefix(name, clause.Associations) { - columnName = name } if columnName != "" { @@ -359,10 +363,25 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, } } - tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{ + FullSaveAssociations: db.FullSaveAssociations, + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }) + + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if tx.Statement.FullSaveAssociations { + tx = tx.Set("gorm:update_track_time", true) + } if len(selects) > 0 { tx = tx.Select(selects) + } else if restricted && len(omits) == 0 { + tx = tx.Omit(clause.Associations) } if len(omits) > 0 { diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index dda4b0466..d681aef36 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -4,9 +4,19 @@ import ( "gorm.io/gorm" ) +var ( + createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"} + queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"} + updateClauses = []string{"UPDATE", "SET", "WHERE"} + deleteClauses = []string{"DELETE", "FROM", "WHERE"} +) + type Config struct { LastInsertIDReversed bool - WithReturning bool + CreateClauses []string + QueryClauses []string + UpdateClauses []string + DeleteClauses []string } func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { @@ -14,38 +24,60 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { return !db.SkipDefaultTransaction } + if len(config.CreateClauses) == 0 { + config.CreateClauses = createClauses + } + if len(config.QueryClauses) == 0 { + config.QueryClauses = queryClauses + } + if len(config.DeleteClauses) == 0 { + config.DeleteClauses = deleteClauses + } + if len(config.UpdateClauses) == 0 { + config.UpdateClauses = updateClauses + } + createCallback := db.Callback().Create() createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Register("gorm:before_create", BeforeCreate) - createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) + createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) createCallback.Register("gorm:create", Create(config)) - createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + createCallback.Clauses = config.CreateClauses queryCallback := db.Callback().Query() queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) + queryCallback.Clauses = config.QueryClauses deleteCallback := db.Callback().Delete() deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) - deleteCallback.Register("gorm:delete", Delete) + deleteCallback.Register("gorm:delete", Delete(config)) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + deleteCallback.Clauses = config.DeleteClauses updateCallback := db.Callback().Update() updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:before_update", BeforeUpdate) - updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) - updateCallback.Register("gorm:update", Update) - updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) + updateCallback.Register("gorm:update", Update(config)) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + updateCallback.Clauses = config.UpdateClauses + + rowCallback := db.Callback().Row() + rowCallback.Register("gorm:row", RowQuery) + rowCallback.Clauses = config.QueryClauses - db.Callback().Row().Register("gorm:row", RowQuery) - db.Callback().Raw().Register("gorm:raw", RawExec) + rawCallback := db.Callback().Raw() + rawCallback.Register("gorm:raw", RawExec) + rawCallback.Clauses = config.QueryClauses } diff --git a/callbacks/create.go b/callbacks/create.go index 3ca56d733..b0964e2b6 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func BeforeCreate(db *gorm.DB) { @@ -31,172 +32,115 @@ func BeforeCreate(db *gorm.DB) { } func Create(config *Config) func(db *gorm.DB) { - if config.WithReturning { - return CreateWithReturning - } else { - return func(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } + supportReturning := utils.Contains(config.CreateClauses, "RETURNING") - if db.Statement.SQL.String() == "" { - db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Insert{}) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + return func(db *gorm.DB) { + if db.Error != nil { + return + } - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + if db.Statement.Schema != nil { + if !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) } + } - if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - - if db.RowsAffected > 0 { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } - - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID-- - } - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } - - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID++ - } - } - } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } - } - } else { - db.AddError(err) - } - } - } - } else { - db.AddError(err) + if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + if _, ok := db.Statement.Clauses["RETURNING"]; !ok { + fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) } + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } } } - } -} - -func CreateWithReturning(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { + db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + db.Statement.Build(db.Statement.BuildClauses...) } - if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { - db.Statement.WriteString(" RETURNING ") - - var ( - fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) - values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) - ) + isDryRun := !db.DryRun && db.Error == nil + if !isDryRun { + return + } - for idx, field := range sch.FieldsWithDefaultDBValue { - if idx > 0 { - db.Statement.WriteByte(',') + ok, mode := hasReturning(db, supportReturning) + if ok { + if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + mode |= gorm.ScanOnConflictDoNothing } - - fields[idx] = field - db.Statement.WriteQuoted(field.DBName) } - if !db.DryRun && db.Error == nil { - db.RowsAffected = 0 - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - defer rows.Close() - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - c := db.Statement.Clauses["ON CONFLICT"] - onConflict, _ := c.Expression.(clause.OnConflict) - - for rows.Next() { - BEGIN: - reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) - if reflect.Indirect(reflectValue).Kind() != reflect.Struct { - break - } - - for idx, field := range fields { - fieldValue := field.ReflectValueOf(reflectValue) + rows, err := db.Statement.ConnPool.QueryContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + db.AddError(rows.Close()) + } - if onConflict.DoNothing && !fieldValue.IsZero() { - db.RowsAffected++ + return + } - if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { - return - } + result, err := db.Statement.ConnPool.ExecContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if err != nil { + db.AddError(err) + return + } - goto BEGIN - } + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && + db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } - values[idx] = fieldValue.Addr().Interface() - } + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } - db.RowsAffected++ - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } - case reflect.Struct: - for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break } - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values...)) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } - } else { - db.AddError(err) } - } - } else if !db.DryRun && db.Error == nil { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + case reflect.Struct: + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID) + } } } } @@ -225,6 +169,8 @@ func AfterCreate(db *gorm.DB) { // ConvertToCreateValues convert to create values func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { + curTime := stmt.DB.NowFunc() + switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) @@ -237,14 +183,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) - curTime = stmt.DB.NowFunc() + _, updateTrackTime = stmt.Get("gorm:update_track_time") isZero bool ) + stmt.Settings.Delete("gorm:update_track_time") + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { - if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { values.Columns = append(values.Columns, clause.Column{Name: db}) } } @@ -252,15 +200,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) - values.Values = make([][]interface{}, stmt.ReflectValue.Len()) - defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} - if stmt.ReflectValue.Len() == 0 { + rValLen := stmt.ReflectValue.Len() + stmt.SQL.Grow(rValLen * 18) + values.Values = make([][]interface{}, rValLen) + if rValLen == 0 { stmt.AddError(gorm.ErrEmptySlice) return } - for i := 0; i < stmt.ReflectValue.Len(); i++ { + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + for i := 0; i < rValLen; i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) if !rv.IsValid() { stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) @@ -270,24 +219,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface - field.Set(rv, field.DefaultValueInterface) + field.Set(stmt.Context, rv, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + field.Set(stmt.Context, rv, curTime) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + field.Set(stmt.Context, rv, curTime) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(rv); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { if len(defaultValueFieldsHavingValue[field]) == 0 { - defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) + defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) } - defaultValueFieldsHavingValue[field][i] = v + defaultValueFieldsHavingValue[field][i] = rvOfvalue } } } @@ -307,22 +259,25 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { + if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(stmt.ReflectValue, field.DefaultValueInterface) + field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + field.Set(stmt.Context, stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + field.Set(stmt.Context, stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - values.Values[0] = append(values.Values[0], v) + values.Values[0] = append(values.Values[0], rvOfvalue) } } } @@ -333,25 +288,45 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if c, ok := stmt.Clauses["ON CONFLICT"]; ok { if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { - if stmt.Schema != nil && len(values.Columns) > 1 { + if stmt.Schema != nil && len(values.Columns) >= 1 { + selectColumns, restricted := stmt.SelectAndOmitColumns(true, true) + columns := make([]string, 0, len(values.Columns)-1) for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { - if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { - columns = append(columns, column.Name) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { + if field.AutoUpdateTime > 0 { + assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} + switch field.AutoUpdateTime { + case schema.UnixNanosecond: + assignment.Value = curTime.UnixNano() + case schema.UnixMillisecond: + assignment.Value = curTime.UnixNano() / 1e6 + case schema.UnixSecond: + assignment.Value = curTime.Unix() + } + + onConflict.DoUpdates = append(onConflict.DoUpdates, assignment) + } else { + columns = append(columns, column.Name) + } + } } } } - onConflict := clause.OnConflict{ - Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), - DoUpdates: clause.AssignmentColumns(columns), + onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) + if len(onConflict.DoUpdates) == 0 { + onConflict.DoNothing = true } - for idx, field := range stmt.Schema.PrimaryFields { - onConflict.Columns[idx] = clause.Column{Name: field.DBName} + // use primary fields as default OnConflict columns + if len(onConflict.Columns) == 0 { + for _, field := range stmt.Schema.PrimaryFields { + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName}) + } } - stmt.AddClause(onConflict) } } diff --git a/callbacks/delete.go b/callbacks/delete.go index 867aa6970..1fb5261cb 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func BeforeDelete(db *gorm.DB) { @@ -25,96 +26,104 @@ func BeforeDelete(db *gorm.DB) { func DeleteBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + if !restricted { + return + } - if restricted { - for column, v := range selectColumns { - if v { - if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { - switch rel.Type { - case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) - modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) - withoutConditions := false - - if len(db.Statement.Selects) > 0 { - var selects []string - for _, s := range db.Statement.Selects { - if s == clause.Associations { - selects = append(selects, s) - } else if strings.HasPrefix(s, column+".") { - selects = append(selects, strings.TrimPrefix(s, column+".")) - } - } - - if len(selects) > 0 { - tx = tx.Select(selects) - } - } - - for _, cond := range queryConds { - if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { - withoutConditions = true - break - } - } - - if !withoutConditions { - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } - } - case schema.Many2Many: - var ( - queryConds []clause.Expression - foreignFields []*schema.Field - relForeignKeys []string - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() - table = rel.JoinTable.Table - tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) - ) - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - foreignFields = append(foreignFields, ref.PrimaryKey) - relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) - } else if ref.PrimaryValue != "" { - queryConds = append(queryConds, clause.Eq{ - Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) - } - } - - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) - column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) - queryConds = append(queryConds, clause.IN{Column: column, Values: values}) - - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } + for column, v := range selectColumns { + if !v { + continue + } + + rel, ok := db.Statement.Schema.Relationships.Relations[column] + if !ok { + continue + } + + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } + + if len(db.Statement.Selects) > 0 { + selects := make([]string, 0, len(db.Statement.Selects)) + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { + selects = append(selects, strings.TrimPrefix(s, columnPrefix)) } } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return } } } + } } -func Delete(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } +func Delete(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.DeleteClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return } - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) if db.Statement.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -122,7 +131,7 @@ func Delete(db *gorm.DB) { } if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -132,7 +141,16 @@ func Delete(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("DELETE", "FROM", "WHERE") + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.Build(db.Statement.BuildClauses...) } if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { @@ -141,12 +159,19 @@ func Delete(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + ok, mode := hasReturning(db, supportReturning) + if !ok { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.AddError(err) == nil { + db.RowsAffected, _ = result.RowsAffected() + } + + return + } - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + db.AddError(rows.Close()) } } } diff --git a/callbacks/helper.go b/callbacks/helper.go index 3ac63fa19..a59e1880f 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -12,7 +12,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter values.Columns = make([]clause.Column, 0, len(mapValue)) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) - var keys = make([]string, 0, len(mapValue)) + keys := make([]string, 0, len(mapValue)) for k := range mapValue { keys = append(keys, k) } @@ -40,17 +40,20 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter // ConvertSliceOfMapToValuesForCreate convert slice of map to values func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { - var ( - columns = make([]string, 0, len(mapValues)) - result = map[string][]interface{}{} - selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) - ) + columns := make([]string, 0, len(mapValues)) + // when the length of mapValues is zero,return directly here + // no need to call stmt.SelectAndOmitColumns method if len(mapValues) == 0 { stmt.AddError(gorm.ErrEmptySlice) return } + var ( + result = make(map[string][]interface{}, len(mapValues)) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + ) + for idx, mapValue := range mapValues { for k, v := range mapValue { if stmt.Schema != nil { @@ -88,3 +91,16 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st } return } + +func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { + if supportReturning { + if c, ok := tx.Statement.Clauses["RETURNING"]; ok { + returning, _ := c.Expression.(clause.Returning) + if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") { + return true, 0 + } + return true, gorm.ScanUpdate + } + } + return false, 0 +} diff --git a/callbacks/preload.go b/callbacks/preload.go index 682427c9e..2363a8cab 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "gorm.io/gorm" @@ -9,10 +10,9 @@ import ( "gorm.io/gorm/utils" ) -func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { +func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { var ( reflectValue = db.Statement.ReflectValue - rel = rels[len(rels)-1] tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) relForeignKeys []string relForeignFields []*schema.Field @@ -22,13 +22,18 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { inlineConds []interface{} ) - if len(rels) > 1 { - reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1]) - } + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) if rel.JoinTable != nil { - var joinForeignFields, joinRelForeignFields []*schema.Field - var joinForeignKeys []string + var ( + joinForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinRelForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinForeignKeys = make([]string, 0, len(rel.References)) + ) + for _, ref := range rel.References { if ref.OwnPrimaryKey { joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) @@ -43,25 +48,26 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { return } joinResults := rel.JoinTable.MakeSlice().Elem() - column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) + column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) joinFieldValues := make([]interface{}, len(joinRelForeignFields)) for i := 0; i < joinResults.Len(); i++ { + joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -70,7 +76,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -86,24 +92,31 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { return } } + // nested preload + for p, pvs := range preloads { + tx = tx.Preload(p, pvs...) + } + reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) - for _, cond := range conds { - if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { - tx = fc(tx) - } else { - inlineConds = append(inlineConds, cond) + if len(values) != 0 { + for _, cond := range conds { + if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { + tx = fc(tx) + } else { + inlineConds = append(inlineConds, cond) + } } - } - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + } fieldValues := make([]interface{}, len(relForeignFields)) @@ -112,17 +125,17 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } } } @@ -130,11 +143,18 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(elem) + fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) + } + + datas, ok := identityMap[utils.ToStringKey(fieldValues...)] + if !ok { + db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", + elem.Interface())) + continue } - for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { - reflectFieldValue := rel.Field.ReflectValueOf(data) + for _, data := range datas { + reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } @@ -142,12 +162,12 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(data, reflectResults.Index(i).Interface()) + rel.Field.Set(db.Statement.Context, data, elem.Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index aa4629a2e..03798859d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -8,7 +8,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" - "gorm.io/gorm/schema" ) func Query(db *gorm.DB) { @@ -21,28 +20,27 @@ func Query(db *gorm.DB) { db.AddError(err) return } - defer rows.Close() - - gorm.Scan(rows, db, false) + gorm.Scan(rows, db, 0) + db.AddError(rows.Close()) } } } func BuildQuerySQL(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { + if db.Statement.Schema != nil { for _, c := range db.Statement.Schema.QueryClauses { db.Statement.AddClause(c) } } - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { - if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) } } @@ -96,19 +94,23 @@ func BuildQuerySQL(db *gorm.DB) { } // inline joins - if len(db.Statement.Joins) != 0 { - if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { + joins := []clause.Join{} + if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + joins = fromClause.Joins + } + + if len(db.Statement.Joins) != 0 || len(joins) != 0 { + if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} } } - joins := []clause.Join{} for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name @@ -143,6 +145,21 @@ func BuildQuerySQL(db *gorm.DB) { } } + if join.On != nil { + onStmt := gorm.Statement{Table: tableAliasName, DB: db} + join.On.Build(&onStmt) + onSQL := onStmt.SQL.String() + vars := onStmt.Vars + for idx, v := range onStmt.Vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + joins = append(joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, @@ -150,11 +167,12 @@ func BuildQuerySQL(db *gorm.DB) { }) } else { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } } + db.Statement.Joins = nil db.Statement.AddClause(clause.From{Joins: joins}) } else { db.Statement.AddClauseIfNotExists(clause.From{}) @@ -162,61 +180,56 @@ func BuildQuerySQL(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clauseSelect) - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + db.Statement.Build(db.Statement.BuildClauses...) } } func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { - preloadMap := map[string][]string{} + preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { - if name == clause.Associations { + preloadFields := strings.Split(name, ".") + if preloadFields[0] == clause.Associations { for _, rel := range db.Statement.Schema.Relationships.Relations { if rel.Schema == db.Statement.Schema { - preloadMap[rel.Name] = []string{rel.Name} + if _, ok := preloadMap[rel.Name]; !ok { + preloadMap[rel.Name] = map[string][]interface{}{} + } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[rel.Name][value] = db.Statement.Preloads[name] + } } } } else { - preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + if _, ok := preloadMap[preloadFields[0]]; !ok { + preloadMap[preloadFields[0]] = map[string][]interface{}{} + } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] } } } - preloadNames := make([]string, len(preloadMap)) - idx := 0 + preloadNames := make([]string, 0, len(preloadMap)) for key := range preloadMap { - preloadNames[idx] = key - idx++ + preloadNames = append(preloadNames, key) } sort.Strings(preloadNames) for _, name := range preloadNames { - var ( - curSchema = db.Statement.Schema - preloadFields = preloadMap[name] - rels = make([]*schema.Relationship, len(preloadFields)) - ) - - for idx, preloadField := range preloadFields { - if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - rels[idx] = rel - curSchema = rel.FieldSchema - } else { - db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) - } - } - - if db.Error == nil { - preload(db, rels, db.Statement.Preloads[name]) + if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { + preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) + } else { + db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } } } } func AfterQuery(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) diff --git a/callbacks/raw.go b/callbacks/raw.go index d594ab391..013e638cb 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -9,8 +9,9 @@ func RawExec(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) - } else { - db.RowsAffected, _ = result.RowsAffected() + return } + + db.RowsAffected, _ = result.RowsAffected() } } diff --git a/callbacks/row.go b/callbacks/row.go index 10e880e12..56be742e8 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -7,15 +7,17 @@ import ( func RowQuery(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) + if db.DryRun { + return + } - if !db.DryRun { - if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { - db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } else { - db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } - - db.RowsAffected = -1 + if isRows, ok := db.Get("rows"); ok && isRows.(bool) { + db.Statement.Settings.Delete("rows") + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } + + db.RowsAffected = -1 } } diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 3171b5bba..50887ccce 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -5,12 +5,14 @@ import ( ) func BeginTransaction(db *gorm.DB) { - if !db.Config.SkipDefaultTransaction { + if !db.Config.SkipDefaultTransaction && db.Error == nil { if tx := db.Begin(); tx.Error == nil { db.Statement.ConnPool = tx.Statement.ConnPool db.InstanceSet("gorm:started_transaction", true) - } else { + } else if tx.Error == gorm.ErrInvalidTransaction { tx.Error = nil + } else { + db.Error = tx.Error } } } @@ -18,11 +20,12 @@ func BeginTransaction(db *gorm.DB) { func CommitOrRollbackTransaction(db *gorm.DB) { if !db.Config.SkipDefaultTransaction { if _, ok := db.InstanceGet("gorm:started_transaction"); ok { - if db.Error == nil { - db.Commit() - } else { + if db.Error != nil { db.Rollback() + } else { + db.Commit() } + db.Statement.ConnPool = db.ConnPool } } diff --git a/callbacks/update.go b/callbacks/update.go index c8f3922eb..4f07ca304 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func SetupUpdateReflectValue(db *gorm.DB) { @@ -20,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { - rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]) } } } @@ -50,23 +51,33 @@ func BeforeUpdate(db *gorm.DB) { } } -func Update(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } +func Update(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return } - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) if set := ConvertToAssignments(db.Statement); len(set) != 0 { db.Statement.AddClause(set) - } else { + } else if _, ok := db.Statement.Clauses["SET"]; !ok { return } - db.Statement.Build("UPDATE", "SET", "WHERE") + + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.Build(db.Statement.BuildClauses...) } if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { @@ -75,12 +86,20 @@ func Update(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - db.RowsAffected, _ = result.RowsAffected() + if ok, mode := hasReturning(db, supportReturning); ok { + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + dest := db.Statement.Dest + db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() + gorm.Scan(rows, db, mode) + db.Statement.Dest = dest + db.AddError(rows.Close()) + } } else { - db.AddError(err) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if db.AddError(err) == nil { + db.RowsAffected, _ = result.RowsAffected() + } } } } @@ -118,13 +137,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } case reflect.Struct: assignValue = func(field *schema.Field, value interface{}) { if stmt.ReflectValue.CanAddr() { - field.Set(stmt.ReflectValue, value) + field.Set(stmt.Context, stmt.ReflectValue, value) } } default: @@ -140,23 +159,26 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - var primaryKeyExprs []clause.Expression - for i := 0; i < stmt.ReflectValue.Len(); i++ { - var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) - var notZero bool - for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) - exprs[idx] = clause.Eq{Column: field.DBName, Value: value} - notZero = notZero || !isZero - } - if notZero { - primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + if size := stmt.ReflectValue.Len(); size > 0 { + var primaryKeyExprs []clause.Expression + for i := 0; i < size; i++ { + exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + } } + + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } @@ -202,7 +224,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { now := stmt.DB.NowFunc() assignValue(field, now) @@ -220,16 +242,24 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } default: + updatingSchema := stmt.Schema + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + // different schema + updatingStmt := &gorm.Statement{DB: stmt.DB} + if err := updatingStmt.Parse(stmt.Dest); err == nil { + updatingSchema = updatingStmt.Schema + } + } + switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.LookUpField(dbName) - if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - value, isZero := field.ValueOf(updatingValue) - if !stmt.SkipHooks { - if field.AutoUpdateTime > 0 { + if field := updatingSchema.LookUpField(dbName); field != nil { + if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { + value, isZero := field.ValueOf(stmt.Context, updatingValue) + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { @@ -241,16 +271,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } isZero = false } - } - if ok || !isZero { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) + if (ok || !isZero) && field.Updatable { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignValue(field, value) + } + } + } else { + if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } - } - } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } } diff --git a/chainable_api.go b/chainable_api.go index c3a02d205..173479d30 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -50,15 +50,14 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { tx.Statement.Table = results[1] - return } } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = tables[1] - return + } else { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = name } - - tx.Statement.Table = name return } @@ -93,10 +92,17 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") case string: - fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) - - // normal field names - if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + if strings.Count(v, "?") >= len(args) && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } else if strings.Count(v, "@") > 0 && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.NamedExpr{SQL: v, Vars: args}, + }) + } else { tx.Statement.Selects = []string{v} for _, arg := range args { @@ -115,11 +121,6 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") - } else { - tx.Statement.AddClause(clause.Select{ - Distinct: db.Statement.Distinct, - Expression: clause.Expr{SQL: v, Vars: args}, - }) } default: tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) @@ -170,8 +171,19 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // Joins specify Joins conditions // db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() + + if len(args) == 1 { + if db, ok := args[0].(*DB); ok { + if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) + return + } + } + } + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } @@ -207,12 +219,14 @@ func (db *DB) Order(value interface{}) (tx *DB) { tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{v}, }) - default: - tx.Statement.AddClause(clause.OrderBy{ - Columns: []clause.OrderByColumn{{ - Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, - }}, - }) + case string: + if v != "" { + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{{ + Column: clause.Column{Name: v, Raw: true}, + }}, + }) + } } return } @@ -243,11 +257,10 @@ func (db *DB) Offset(offset int) (tx *DB) { // } // // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - db = f(db) - } - return db +func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { + tx = db.getInstance() + tx.Statement.scopes = append(tx.Statement.scopes, funcs...) + return tx } // Preload preload associations with given conditions diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 88a238e36..e08677ac0 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -32,7 +32,8 @@ func BenchmarkComplexSelect(b *testing.B) { for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{ - clause.Select{}, clause.From{}, + clause.Select{}, + clause.From{}, clause.Where{Exprs: []clause.Expression{ clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, diff --git a/clause/clause.go b/clause/clause.go index d413d0ee2..de19f2e30 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -7,7 +7,7 @@ type Interface interface { MergeClause(*Clause) } -// ClauseBuilder clause builder, allows to custmize how to build clause +// ClauseBuilder clause builder, allows to customize how to build clause type ClauseBuilder func(Clause, Builder) type Writer interface { @@ -62,9 +62,9 @@ func (c Clause) Build(builder Builder) { } const ( - PrimaryKey string = "@@@py@@@" // primary key - CurrentTable string = "@@@ct@@@" // current table - Associations string = "@@@as@@@" // associations + PrimaryKey string = "~~~py~~~" // primary key + CurrentTable string = "~~~ct~~~" // current table + Associations string = "~~~as~~~" // associations ) var ( diff --git a/clause/expression.go b/clause/expression.go index b30c46b03..dde00b1d7 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -67,6 +67,12 @@ func (expr Expr) Build(builder Builder) { builder.WriteByte(v) } } + + if idx < len(expr.Vars) { + for _, v := range expr.Vars[idx:] { + builder.AddVar(builder, sql.NamedArg{Value: v}) + } + } } // NamedExpr raw expression for named expr @@ -78,9 +84,10 @@ type NamedExpr struct { // Build build raw expression func (expr NamedExpr) Build(builder Builder) { var ( - idx int - inName bool - namedMap = make(map[string]interface{}, len(expr.Vars)) + idx int + inName bool + afterParenthesis bool + namedMap = make(map[string]interface{}, len(expr.Vars)) ) for _, v := range expr.Vars { @@ -120,7 +127,7 @@ func (expr NamedExpr) Build(builder Builder) { if v == '@' && !inName { inName = true name = []byte{} - } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' { + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) @@ -131,19 +138,53 @@ func (expr NamedExpr) Build(builder Builder) { inName = false } + afterParenthesis = false builder.WriteByte(v) } else if v == '?' && len(expr.Vars) > idx { - builder.AddVar(builder, expr.Vars[idx]) + if afterParenthesis { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + idx++ } else if inName { name = append(name, v) } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } builder.WriteByte(v) } } if inName { - builder.AddVar(builder, namedMap[string(name)]) + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } } } @@ -175,11 +216,12 @@ func (in IN) Build(builder Builder) { } func (in IN) NegationBuild(builder Builder) { + builder.WriteQuoted(in.Column) switch len(in.Values) { case 0: + builder.WriteString(" IS NOT NULL") case 1: if _, ok := in.Values[0].([]interface{}); !ok { - builder.WriteQuoted(in.Column) builder.WriteString(" <> ") builder.AddVar(builder, in.Values[0]) break @@ -187,7 +229,6 @@ func (in IN) NegationBuild(builder Builder) { fallthrough default: - builder.WriteQuoted(in.Column) builder.WriteString(" NOT IN (") builder.AddVar(builder, in.Values...) builder.WriteByte(')') @@ -203,11 +244,24 @@ type Eq struct { func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) - if eq.Value == nil { - builder.WriteString(" IS NULL") - } else { - builder.WriteString(" = ") - builder.AddVar(builder, eq.Value) + switch eq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" IN (") + rv := reflect.ValueOf(eq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + builder.WriteByte(')') + default: + if eqNil(eq.Value) { + builder.WriteString(" IS NULL") + } else { + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) + } } } @@ -221,11 +275,24 @@ type Neq Eq func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) - if neq.Value == nil { - builder.WriteString(" IS NOT NULL") - } else { - builder.WriteString(" <> ") - builder.AddVar(builder, neq.Value) + switch neq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" NOT IN (") + rv := reflect.ValueOf(neq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + builder.WriteByte(')') + default: + if eqNil(neq.Value) { + builder.WriteString(" IS NOT NULL") + } else { + builder.WriteString(" <> ") + builder.AddVar(builder, neq.Value) + } } } @@ -299,3 +366,16 @@ func (like Like) NegationBuild(builder Builder) { builder.WriteString(" NOT LIKE ") builder.AddVar(builder, like.Value) } + +func eqNil(value interface{}) bool { + if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) { + value, _ = valuer.Value() + } + + return value == nil || eqNilReflect(value) +} + +func eqNilReflect(value interface{}) bool { + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} diff --git a/clause/expression_test.go b/clause/expression_test.go index 830824867..4826db381 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -60,6 +60,11 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name", "jinzhu")}, Result: "name1 = ? AND name2 = ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name AND name2 = @@name", + Vars: []interface{}{map[string]interface{}{"name": "jinzhu"}}, + Result: "name1 = ? AND name2 = @@name", + ExpectedVars: []interface{}{"jinzhu"}, }, { SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, @@ -73,17 +78,50 @@ func TestNamedExpr(t *testing.T) { }, { SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, - Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", - ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, }, { - SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @notexist", Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}}, - Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", - ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, }, { SQL: "create table ? (? ?, ? ?)", Vars: []interface{}{}, Result: "create table ? (? ?, ? ?)", + }, { + SQL: "name1 = @name AND name2 = @name;", + Vars: []interface{}{sql.Named("name", "jinzhu")}, + Result: "name1 = ? AND name2 = ?;", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, + Result: "`table`.`col`", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Raw: true}}, + Result: "table.col", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: clause.PrimaryKey, Raw: true}}, + Result: "table.id", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias"}}, + Result: "`table`.`col` AS `alias`", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias", Raw: true}}, + Result: "table.col AS alias", + }, { + SQL: "?", + Vars: []interface{}{clause.Table{Name: "table", Alias: "alias"}}, + Result: "`table` `alias`", + }, { + SQL: "?", + Vars: []interface{}{clause.Table{Name: "table", Alias: "alias", Raw: true}}, + Result: "table alias", }} for idx, result := range results { @@ -101,3 +139,84 @@ func TestNamedExpr(t *testing.T) { }) } } + +func TestExpression(t *testing.T) { + column := "column-name" + results := []struct { + Expressions []clause.Expression + ExpectedVars []interface{} + Result string + }{{ + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: "column-value"}, + }, + ExpectedVars: []interface{}{"column-value"}, + Result: "`column-name` = ?", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: nil}, + clause.Eq{Column: column, Value: (*string)(nil)}, + clause.Eq{Column: column, Value: (*int)(nil)}, + clause.Eq{Column: column, Value: (*bool)(nil)}, + clause.Eq{Column: column, Value: (interface{})(nil)}, + clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}}, + }, + Result: "`column-name` IS NULL", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: "column-value"}, + }, + ExpectedVars: []interface{}{"column-value"}, + Result: "`column-name` <> ?", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: nil}, + clause.Neq{Column: column, Value: (*string)(nil)}, + clause.Neq{Column: column, Value: (*int)(nil)}, + clause.Neq{Column: column, Value: (*bool)(nil)}, + clause.Neq{Column: column, Value: (interface{})(nil)}, + }, + Result: "`column-name` IS NOT NULL", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{"a", "b"}}, + }, + ExpectedVars: []interface{}{"a", "b"}, + Result: "`column-name` IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: []string{"a", "b"}}, + }, + ExpectedVars: []interface{}{"a", "b"}, + Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, + }, + ExpectedVars: []interface{}{100}, + Result: "SUM(`id`) = ?", + }, { + Expressions: []clause.Expression{ + clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: "id"}}}, Value: 100}, + }, + ExpectedVars: []interface{}{100}, + Result: "SUM(`users`.`id`) >= ?", + }} + + for idx, result := range results { + for idy, expression := range result.Expressions { + t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + expression.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + + if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { + t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) + } + }) + } + } +} diff --git a/clause/group_by.go b/clause/group_by.go index 882319169..84242fb8a 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -39,4 +39,10 @@ func (groupBy GroupBy) MergeClause(clause *Clause) { groupBy.Having = append(copiedHaving, groupBy.Having...) } clause.Expression = groupBy + + if len(groupBy.Columns) == 0 { + clause.Name = "" + } else { + clause.Name = groupBy.Name() + } } diff --git a/clause/group_by_test.go b/clause/group_by_test.go index 589f96130..7c282cb96 100644 --- a/clause/group_by_test.go +++ b/clause/group_by_test.go @@ -18,7 +18,8 @@ func TestGroupBy(t *testing.T) { Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}, }}, - "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"}, + "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", + []interface{}{"admin"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ @@ -28,7 +29,8 @@ func TestGroupBy(t *testing.T) { Columns: []clause.Column{{Name: "gender"}}, Having: []clause.Expression{clause.Neq{"gender", "U"}}, }}, - "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"}, + "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", + []interface{}{"admin", "U"}, }, } diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 47fe169ca..309c5fcd2 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -1,11 +1,13 @@ package clause type OnConflict struct { - Columns []Column - Where Where - DoNothing bool - DoUpdates Set - UpdateAll bool + Columns []Column + Where Where + TargetWhere Where + OnConstraint string + DoNothing bool + DoUpdates Set + UpdateAll bool } func (OnConflict) Name() string { @@ -25,9 +27,15 @@ func (onConflict OnConflict) Build(builder Builder) { builder.WriteString(`) `) } - if len(onConflict.Where.Exprs) > 0 { - builder.WriteString("WHERE ") - onConflict.Where.Build(builder) + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } + + if onConflict.OnConstraint != "" { + builder.WriteString("ON CONSTRAINT ") + builder.WriteString(onConflict.OnConstraint) builder.WriteByte(' ') } @@ -37,6 +45,12 @@ func (onConflict OnConflict) Build(builder Builder) { builder.WriteString("DO UPDATE SET ") onConflict.DoUpdates.Build(builder) } + + if len(onConflict.Where.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.Where.Build(builder) + builder.WriteByte(' ') + } } // MergeClause merge onConflict clauses diff --git a/clause/order_by_test.go b/clause/order_by_test.go index 8fd1e2a86..d8b5dfbf6 100644 --- a/clause/order_by_test.go +++ b/clause/order_by_test.go @@ -45,7 +45,8 @@ func TestOrderBy(t *testing.T) { Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, }, }, - "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3}, + "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", + []interface{}{1, 2, 3}, }, } diff --git a/clause/returning.go b/clause/returning.go index 04bc96dab..d94b7a4ca 100644 --- a/clause/returning.go +++ b/clause/returning.go @@ -11,12 +11,16 @@ func (returning Returning) Name() string { // Build build where clause func (returning Returning) Build(builder Builder) { - for idx, column := range returning.Columns { - if idx > 0 { - builder.WriteByte(',') - } + if len(returning.Columns) > 0 { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } - builder.WriteQuoted(column) + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') } } diff --git a/clause/select.go b/clause/select.go index b93b87690..d8e9f8015 100644 --- a/clause/select.go +++ b/clause/select.go @@ -43,3 +43,17 @@ func (s Select) MergeClause(clause *Clause) { clause.Expression = s } } + +// CommaExpression represents a group of expressions separated by commas. +type CommaExpression struct { + Exprs []Expression +} + +func (comma CommaExpression) Build(builder Builder) { + for idx, expr := range comma.Exprs { + if idx > 0 { + _, _ = builder.WriteString(", ") + } + expr.Build(builder) + } +} diff --git a/clause/select_test.go b/clause/select_test.go index b7296434e..18bc2693b 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -31,6 +31,35 @@ func TestSelect(t *testing.T) { }, clause.From{}}, "SELECT `name` FROM `users`", nil, }, + { + []clause.Interface{clause.Select{ + Expression: clause.CommaExpression{ + Exprs: []clause.Expression{ + clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}}, + clause.NamedExpr{"?", []interface{}{clause.Column{Name: "name"}}}, + clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}}, + }, + }, + }, clause.From{}}, + "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Expression: clause.CommaExpression{ + Exprs: []clause.Expression{ + clause.Expr{ + SQL: "? as name", + Vars: []interface{}{clause.Eq{ + Column: clause.Column{Name: "age"}, + Value: 18, + }, + }, + }, + }, + }, + }, clause.From{}}, + "SELECT `age` = ? as name FROM `users`", []interface{}{18}, + }, } for idx, result := range results { diff --git a/clause/set.go b/clause/set.go index 6a885711a..75eb6bdda 100644 --- a/clause/set.go +++ b/clause/set.go @@ -24,9 +24,9 @@ func (set Set) Build(builder Builder) { builder.AddVar(builder, assignment.Value) } } else { - builder.WriteQuoted(PrimaryColumn) + builder.WriteQuoted(Column{Name: PrimaryKey}) builder.WriteByte('=') - builder.WriteQuoted(PrimaryColumn) + builder.WriteQuoted(Column{Name: PrimaryKey}) } } diff --git a/clause/set_test.go b/clause/set_test.go index 56fac7060..7a9ee895a 100644 --- a/clause/set_test.go +++ b/clause/set_test.go @@ -20,7 +20,8 @@ func TestSet(t *testing.T) { clause.Update{}, clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), }, - "UPDATE `users` SET `users`.`id`=?", []interface{}{1}, + "UPDATE `users` SET `users`.`id`=?", + []interface{}{1}, }, { []clause.Interface{ @@ -28,7 +29,8 @@ func TestSet(t *testing.T) { clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), }, - "UPDATE `users` SET `name`=?", []interface{}{"jinzhu"}, + "UPDATE `users` SET `name`=?", + []interface{}{"jinzhu"}, }, } diff --git a/clause/values_test.go b/clause/values_test.go index 9c02c8a54..1eea8652c 100644 --- a/clause/values_test.go +++ b/clause/values_test.go @@ -21,7 +21,8 @@ func TestValues(t *testing.T) { Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}}, }, }, - "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1}, + "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", + []interface{}{"jinzhu", 18, "josh", 1}, }, } diff --git a/clause/where.go b/clause/where.go index 00b1a40e9..10b6df856 100644 --- a/clause/where.go +++ b/clause/where.go @@ -4,6 +4,11 @@ import ( "strings" ) +const ( + AndWithSpace = " AND " + OrWithSpace = " OR " +) + // Where where clause type Where struct { Exprs []Expression @@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) { } } - buildExprs(where.Exprs, builder, " AND ") + buildExprs(where.Exprs, builder, AndWithSpace) } func buildExprs(exprs []Expression, builder Builder, joinCond string) { @@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { for idx, expr := range exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.WriteString(" OR ") + builder.WriteString(OrWithSpace) } else { builder.WriteString(joinCond) } @@ -46,20 +51,23 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { case OrConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case AndConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case Expr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + case NamedExpr: + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } @@ -89,9 +97,14 @@ func (where Where) MergeClause(clause *Clause) { func And(exprs ...Expression) Expression { if len(exprs) == 0 { return nil - } else if len(exprs) == 1 { - return exprs[0] } + + if len(exprs) == 1 { + if _, ok := exprs[0].(OrConditions); !ok { + return exprs[0] + } + } + return AndConditions{Exprs: exprs} } @@ -102,10 +115,10 @@ type AndConditions struct { func (and AndConditions) Build(builder Builder) { if len(and.Exprs) > 1 { builder.WriteByte('(') - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) builder.WriteByte(')') } else { - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) } } @@ -123,10 +136,10 @@ type OrConditions struct { func (or OrConditions) Build(builder Builder) { if len(or.Exprs) > 1 { builder.WriteByte('(') - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) builder.WriteByte(')') } else { - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) } } @@ -148,7 +161,7 @@ func (not NotConditions) Build(builder Builder) { for idx, c := range not.Exprs { if idx > 0 { - builder.WriteString(" AND ") + builder.WriteString(AndWithSpace) } if negationBuilder, ok := c.(NegationExpressionBuilder); ok { @@ -157,8 +170,8 @@ func (not NotConditions) Build(builder Builder) { builder.WriteString("NOT ") e, wrapInParentheses := c.(Expr) if wrapInParentheses { - sql := strings.ToLower(e.SQL) - if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { builder.WriteByte('(') } } diff --git a/clause/where_test.go b/clause/where_test.go index 2fa11d76e..35e3dbeeb 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -17,25 +17,29 @@ func TestWhere(t *testing.T) { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", + []interface{}{"1", 18, "jinzhu"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", + []interface{}{"1", "jinzhu", 18}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", + []interface{}{"1", "jinzhu", 18}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", + []interface{}{"1", "jinzhu"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ @@ -43,7 +47,8 @@ func TestWhere(t *testing.T) { }, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", + []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ @@ -51,13 +56,54 @@ func TestWhere(t *testing.T) { }, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})}, }}, - "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", + []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, }}, - "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"}, + "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", + []interface{}{18, "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{ + clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), + }, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + []interface{}{"1", 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", + []interface{}{"1", 100}, }, } diff --git a/clause/with.go b/clause/with.go index 7e9eaef17..0768488e5 100644 --- a/clause/with.go +++ b/clause/with.go @@ -1,4 +1,3 @@ package clause -type With struct { -} +type With struct{} diff --git a/errors.go b/errors.go index 087550832..49cbfe64a 100644 --- a/errors.go +++ b/errors.go @@ -2,13 +2,15 @@ package gorm import ( "errors" + + "gorm.io/gorm/logger" ) var ( // ErrRecordNotFound record not found error - ErrRecordNotFound = errors.New("record not found") + ErrRecordNotFound = logger.ErrRecordNotFound // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") + ErrInvalidTransaction = errors.New("invalid transaction") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") // ErrMissingWhereClause missing where clause @@ -31,4 +33,12 @@ var ( ErrEmptySlice = errors.New("empty slice found") // ErrDryRunModeUnsupported dry run mode unsupported ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") + // ErrInvalidDB invalid db + ErrInvalidDB = errors.New("invalid db") + // ErrInvalidValue invalid value + ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") + // ErrInvalidValueOfLength invalid values do not match length + ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") + // ErrPreloadNotAllowed preload is not allowed when count is used + ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") ) diff --git a/finisher_api.go b/finisher_api.go index f2aed8da6..f994ec318 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -15,10 +15,13 @@ import ( // Create insert the value into database func (db *DB) Create(value interface{}) (tx *DB) { + if db.CreateBatchSize > 0 { + return db.CreateInBatches(value, db.CreateBatchSize) + } + tx = db.getInstance() tx.Statement.Dest = value - tx.callbacks.Create().Execute(tx) - return + return tx.callbacks.Create().Execute(tx) } // CreateInBatches insert the value in batches into database @@ -27,19 +30,40 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: + var rowsAffected int64 tx = db.getInstance() - for i := 0; i < reflectValue.Len(); i += batchSize { - tx.AddError(tx.Transaction(func(tx *DB) error { + + callFc := func(tx *DB) error { + // the reflection length judgment of the optimized value + reflectLen := reflectValue.Len() + for i := 0; i < reflectLen; i += batchSize { ends := i + batchSize - if ends > reflectValue.Len() { - ends = reflectValue.Len() + if ends > reflectLen { + ends = reflectLen } - return tx.Create(reflectValue.Slice(i, ends).Interface()).Error - })) + subtx := tx.getInstance() + subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface() + subtx.callbacks.Create().Execute(subtx) + if subtx.Error != nil { + return subtx.Error + } + rowsAffected += subtx.RowsAffected + } + return nil } + + if tx.SkipDefaultTransaction { + tx.AddError(callFc(tx.Session(&Session{}))) + } else { + tx.AddError(tx.Transaction(callFc)) + } + + tx.RowsAffected = rowsAffected default: - return db.Create(value) + tx = db.getInstance() + tx.Statement.Dest = value + tx = tx.callbacks.Create().Execute(tx) } return } @@ -55,13 +79,12 @@ func (db *DB) Save(value interface{}) (tx *DB) { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) } - tx.callbacks.Create().Execute(tx) + tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true)) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { - if _, isZero := pf.ValueOf(reflectValue); isZero { - tx.callbacks.Create().Execute(tx) - return + if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { + return tx.callbacks.Create().Execute(tx) } } } @@ -74,11 +97,11 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Selects = append(tx.Statement.Selects, "*") } - tx.callbacks.Update().Execute(tx) + tx = tx.callbacks.Update().Execute(tx) if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) { + if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } @@ -93,24 +116,26 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Take return a record that match given conditions, the order will depend on the database implementation func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Last find last record that match given conditions, order by primary key @@ -120,23 +145,25 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { Desc: true, }) if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Find find records that match given conditions func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // FindInBatches find records in batches @@ -157,15 +184,23 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat if result.Error == nil && result.RowsAffected != 0 { tx.AddError(fc(result, batch)) + } else if result.Error != nil { + tx.AddError(result.Error) } if tx.Error != nil || int(result.RowsAffected) < batchSize { break - } else { - resultsValue := reflect.Indirect(reflect.ValueOf(dest)) - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } + + // Optimize for-break + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + if result.Statement.Schema.PrioritizedPrimaryField == nil { + tx.AddError(ErrPrimaryKeyRequired) + break + } + + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } tx.RowsAffected = rowsAffected @@ -181,11 +216,11 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { switch column := eq.Column.(type) { case string: if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) } case clause.Column: if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) } } } else if andCond, ok := expr.(clause.AndConditions); ok { @@ -193,8 +228,9 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: - exprs := tx.Statement.BuildCondition(value) - tx.assignInterfacesToValue(exprs) + if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 { + tx.assignInterfacesToValue(exprs) + } default: if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { reflectValue := reflect.Indirect(reflect.ValueOf(value)) @@ -202,23 +238,24 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { case reflect.Struct: for _, f := range s.Fields { if f.Readable { - if v, isZero := f.ValueOf(reflectValue); !isZero { + if v, isZero := f.ValueOf(tx.Statement.Context, reflectValue); !isZero { if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, v)) } } } } } } else if len(values) > 0 { - exprs := tx.Statement.BuildCondition(values[0], values[1:]...) - tx.assignInterfacesToValue(exprs) + if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { + tx.assignInterfacesToValue(exprs) + } return } } } } - +// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -244,95 +281,97 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } +// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - - if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { - if c, ok := tx.Statement.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok { - tx.assignInterfacesToValue(where.Exprs) + if tx = queryTx.Find(dest, conds...); tx.Error == nil { + if tx.RowsAffected == 0 { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignInterfacesToValue(where.Exprs) + } } - } - // initialize with attrs, conds - if len(tx.Statement.attrs) > 0 { - tx.assignInterfacesToValue(tx.Statement.attrs...) - } + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + tx.assignInterfacesToValue(tx.Statement.attrs...) + } - // initialize with attrs, conds - if len(tx.Statement.assigns) > 0 { - tx.assignInterfacesToValue(tx.Statement.assigns...) - } + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + tx.assignInterfacesToValue(tx.Statement.assigns...) + } - return tx.Create(dest) - } else if len(db.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - assigns[column] = eq.Value - case clause.Column: - assigns[column.Name] = eq.Value - default: + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + default: + } } } - } - return tx.Model(dest).Updates(assigns) + return tx.Model(dest).Updates(assigns) + } } - - return db + return tx } // Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.SkipHooks = true - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values tx.Statement.SkipHooks = true - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)}) + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } } tx.Statement.Dest = value - tx.callbacks.Delete().Execute(tx) - return + return tx.callbacks.Delete().Execute(tx) } func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() + if len(tx.Statement.Preloads) > 0 { + tx.AddError(ErrPreloadNotAllowed) + return + } if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest defer func() { @@ -340,51 +379,62 @@ func (db *DB) Count(count *int64) (tx *DB) { }() } - if len(tx.Statement.Selects) == 0 { - tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { + defer func() { + tx.Statement.Clauses["SELECT"] = selectClause + }() + } else { defer delete(tx.Statement.Clauses, "SELECT") - } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { - expr := clause.Expr{SQL: "count(1)"} + } + + if len(tx.Statement.Selects) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}}) + } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") { + expr := clause.Expr{SQL: "count(*)"} if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - if tx.Statement.Parse(tx.Statement.Model) == nil { - if f := tx.Statement.Schema.LookUpField(dbName); f != nil { - dbName = f.DBName + fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName + } } - } - if tx.Statement.Distinct { - expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} - } else { - expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else if dbName != "*" { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } } } tx.Statement.AddClause(clause.Select{Expression: expr}) - defer tx.Statement.AddClause(clause.Select{}) } if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { if _, ok := db.Statement.Clauses["GROUP BY"]; !ok { - delete(db.Statement.Clauses, "ORDER BY") + delete(tx.Statement.Clauses, "ORDER BY") defer func() { - db.Statement.Clauses["ORDER BY"] = orderByClause + tx.Statement.Clauses["ORDER BY"] = orderByClause }() } } tx.Statement.Dest = count - tx.callbacks.Query().Execute(tx) - if tx.RowsAffected != 1 { + tx = tx.callbacks.Query().Execute(tx) + + if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { *count = tx.RowsAffected } + return } func (db *DB) Row() *sql.Row { - tx := db.getInstance().InstanceSet("rows", false) - tx.callbacks.Row().Execute(tx) + tx := db.getInstance().Set("rows", false) + tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) @@ -393,8 +443,8 @@ func (db *DB) Row() *sql.Row { } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.getInstance().InstanceSet("rows", true) - tx.callbacks.Row().Execute(tx) + tx := db.getInstance().Set("rows", true) + tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { tx.Error = ErrDryRunModeUnsupported @@ -411,13 +461,13 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() tx.Config = &config - if rows, err := tx.Rows(); err != nil { - tx.AddError(err) - } else { - defer rows.Close() + if rows, err := tx.Rows(); err == nil { if rows.Next() { tx.ScanRows(rows, dest) + } else { + tx.RowsAffected = 0 } + tx.AddError(rows.Close()) } currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { @@ -429,7 +479,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { // Pluck used to query single column from a model as a map // var ages []int64 -// db.Find(&users).Pluck("age", &ages) +// db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Model != nil { @@ -438,18 +488,17 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { column = f.DBName } } - } else if tx.Statement.Table == "" { - tx.AddError(ErrModelValueRequired) } - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) - tx.Statement.AddClauseIfNotExists(clause.Select{ - Distinct: tx.Statement.Distinct, - Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, - }) + if len(tx.Statement.Selects) != 1 { + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, + }) + } tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { @@ -460,31 +509,65 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx.Statement.Dest = dest tx.Statement.ReflectValue = reflect.ValueOf(dest) for tx.Statement.ReflectValue.Kind() == reflect.Ptr { - tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + elem := tx.Statement.ReflectValue.Elem() + if !elem.IsValid() { + elem = reflect.New(tx.Statement.ReflectValue.Type().Elem()) + tx.Statement.ReflectValue.Set(elem) + } + tx.Statement.ReflectValue = elem } - Scan(rows, tx, true) + Scan(rows, tx, ScanInitialized) return tx.Error } +// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +func (db *DB) Connection(fc func(tx *DB) error) (err error) { + if db.Error != nil { + return db.Error + } + + tx := db.getInstance() + sqlDB, err := tx.DB() + if err != nil { + return + } + + conn, err := sqlDB.Conn(tx.Statement.Context) + if err != nil { + return + } + + defer conn.Close() + tx.Statement.ConnPool = conn + return fc(tx) +} + // Transaction start a transaction as a block, return error will rollback, otherwise to commit. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction - err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - db.RollbackTo(fmt.Sprintf("sp%p", fc)) + if !db.DisableNestedTransaction { + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + if err != nil { + return } - }() - if err == nil { - err = fc(db.Session(&Session{})) + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + db.RollbackTo(fmt.Sprintf("sp%p", fc)) + } + }() } + + err = fc(db.Session(&Session{})) } else { tx := db.Begin(opts...) + if tx.Error != nil { + return tx.Error + } defer func() { // Make sure to rollback when panic, Block error or Commit error @@ -493,12 +576,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - if err = tx.Error; err == nil { - err = fc(tx) - } - - if err == nil { - err = tx.Commit().Error + if err = fc(tx); err == nil { + panicked = false + return tx.Commit().Error } } @@ -510,7 +590,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement - tx = db.Session(&Session{Context: db.Statement.Context}) + tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions err error ) @@ -585,6 +665,5 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) } - tx.callbacks.Raw().Execute(tx) - return + return tx.callbacks.Raw().Execute(tx) } diff --git a/go.mod b/go.mod index faf63a46b..573627455 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.14 require ( github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.1.1 + github.com/jinzhu/now v1.1.4 ) diff --git a/go.sum b/go.sum index 148bd6f53..50fbba2fc 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.1 h1:g39TucaRWyV3dwDO++eEc6qf8TVIQ/Da48WmqjZ3i7E= -github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= +github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go index 1947b4dff..7967b0945 100644 --- a/gorm.go +++ b/gorm.go @@ -3,8 +3,8 @@ package gorm import ( "context" "database/sql" - "errors" "fmt" + "sort" "sync" "time" @@ -13,6 +13,9 @@ import ( "gorm.io/gorm/schema" ) +// for Config.cacheStore store PreparedStmtDB key +const preparedStmtDBKey = "preparedStmt" + // Config GORM config type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity @@ -34,10 +37,14 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // DisableNestedTransaction disable nested transaction + DisableNestedTransaction bool // AllowGlobalUpdate allow global update AllowGlobalUpdate bool // QueryFields executes the SQL query with all fields of the table QueryFields bool + // CreateBatchSize default create batch size + CreateBatchSize int // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -52,6 +59,32 @@ type Config struct { cacheStore *sync.Map } +// Apply update config to new config +func (c *Config) Apply(config *Config) error { + if config != c { + *config = *c + } + return nil +} + +// AfterInitialize initialize plugins after db connected +func (c *Config) AfterInitialize(db *DB) error { + if db != nil { + for _, plugin := range c.Plugins { + if err := plugin.Initialize(db); err != nil { + return err + } + } + } + return nil +} + +// Option gorm option interface +type Option interface { + Apply(*Config) error + AfterInitialize(*DB) error +} + // DB GORM DB definition type DB struct { *Config @@ -63,23 +96,49 @@ type DB struct { // Session session config when create session with Session() method type Session struct { - DryRun bool - PrepareStmt bool - NewDB bool - SkipHooks bool - SkipDefaultTransaction bool - AllowGlobalUpdate bool - FullSaveAssociations bool - QueryFields bool - Context context.Context - Logger logger.Interface - NowFunc func() time.Time + DryRun bool + PrepareStmt bool + NewDB bool + Initialized bool + SkipHooks bool + SkipDefaultTransaction bool + DisableNestedTransaction bool + AllowGlobalUpdate bool + FullSaveAssociations bool + QueryFields bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time + CreateBatchSize int } // Open initialize db session based on dialector -func Open(dialector Dialector, config *Config) (db *DB, err error) { - if config == nil { - config = &Config{} +func Open(dialector Dialector, opts ...Option) (db *DB, err error) { + config := &Config{} + + sort.Slice(opts, func(i, j int) bool { + _, isConfig := opts[i].(*Config) + _, isConfig2 := opts[j].(*Config) + return isConfig && !isConfig2 + }) + + for _, opt := range opts { + if opt != nil { + if err := opt.Apply(config); err != nil { + return nil, err + } + defer func(opt Option) { + if errr := opt.AfterInitialize(db); errr != nil { + err = errr + } + }(opt) + } + } + + if d, ok := dialector.(interface{ Apply(*Config) error }); ok { + if err = d.Apply(config); err != nil { + return + } } if config.NamingStrategy == nil { @@ -120,11 +179,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string]*sql.Stmt{}, + Stmts: map[string]Stmt{}, Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } - db.cacheStore.Store("preparedStmt", preparedStmt) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) if config.PrepareStmt { db.ConnPool = preparedStmt @@ -157,9 +216,13 @@ func (db *DB) Session(config *Session) *DB { tx = &DB{ Config: &txConfig, Statement: db.Statement, + Error: db.Error, clone: 1, } ) + if config.CreateBatchSize > 0 { + tx.Config.CreateBatchSize = config.CreateBatchSize + } if config.SkipDefaultTransaction { tx.Config.SkipDefaultTransaction = true @@ -183,7 +246,7 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { - if v, ok := db.cacheStore.Load("preparedStmt"); ok { + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, @@ -199,6 +262,10 @@ func (db *DB) Session(config *Session) *DB { tx.Statement.SkipHooks = true } + if config.DisableNestedTransaction { + txConfig.DisableNestedTransaction = true + } + if !config.NewDB { tx.clone = 2 } @@ -219,6 +286,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.NowFunc = config.NowFunc } + if config.Initialized { + tx = tx.getInstance() + } + return tx } @@ -277,20 +348,20 @@ func (db *DB) AddError(err error) error { func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - if stmtDB, ok := connPool.(*PreparedStmtDB); ok { - connPool = stmtDB.ConnPool + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() } if sqldb, ok := connPool.(*sql.DB); ok { return sqldb, nil } - return nil, errors.New("invalid db") + return nil, ErrInvalidDB } func (db *DB) getInstance() *DB { if db.clone > 0 { - tx := &DB{Config: db.Config} + tx := &DB{Config: db.Config, Error: db.Error} if db.clone == 1 { // clone with new statement @@ -313,10 +384,12 @@ func (db *DB) getInstance() *DB { return db } +// Expr returns clause.Expr, which can be used to pass SQL expression as params func Expr(expr string, args ...interface{}) clause.Expr { return clause.Expr{SQL: expr, Vars: args} } +// SetupJoinTable setup join table schema func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { var ( tx = db.getInstance() @@ -324,56 +397,73 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac modelSchema, joinSchema *schema.Schema ) - if err := stmt.Parse(model); err == nil { - modelSchema = stmt.Schema - } else { + err := stmt.Parse(model) + if err != nil { return err } + modelSchema = stmt.Schema - if err := stmt.Parse(joinTable); err == nil { - joinSchema = stmt.Schema - } else { + err = stmt.Parse(joinTable) + if err != nil { return err } + joinSchema = stmt.Schema - if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { - for _, ref := range relation.References { - if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { - f.DataType = ref.ForeignKey.DataType - f.GORMDataType = ref.ForeignKey.GORMDataType - if f.Size == 0 { - f.Size = ref.ForeignKey.Size - } - ref.ForeignKey = f - } else { - return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) - } + relation, ok := modelSchema.Relationships.Relations[field] + isRelation := ok && relation.JoinTable != nil + if !isRelation { + return fmt.Errorf("failed to found relation: %s", field) + } + + for _, ref := range relation.References { + f := joinSchema.LookUpField(ref.ForeignKey.DBName) + if f == nil { + return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) } - for name, rel := range relation.JoinTable.Relationships.Relations { - if _, ok := joinSchema.Relationships.Relations[name]; !ok { - rel.Schema = joinSchema - joinSchema.Relationships.Relations[name] = rel - } + f.DataType = ref.ForeignKey.DataType + f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size } + ref.ForeignKey = f + } - relation.JoinTable = joinSchema - } else { - return fmt.Errorf("failed to found relation: %v", field) + for name, rel := range relation.JoinTable.Relationships.Relations { + if _, ok := joinSchema.Relationships.Relations[name]; !ok { + rel.Schema = joinSchema + joinSchema.Relationships.Relations[name] = rel + } } + relation.JoinTable = joinSchema return nil } -func (db *DB) Use(plugin Plugin) (err error) { +// Use use plugin +func (db *DB) Use(plugin Plugin) error { name := plugin.Name() - if _, ok := db.Plugins[name]; !ok { - if err = plugin.Initialize(db); err == nil { - db.Plugins[name] = plugin - } - } else { + if _, ok := db.Plugins[name]; ok { return ErrRegistered } + if err := plugin.Initialize(db); err != nil { + return err + } + db.Plugins[name] = plugin + return nil +} - return err +// ToSQL for generate SQL string. +// +// db.ToSQL(func(tx *gorm.DB) *gorm.DB { +// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) +// .Limit(10).Offset(5) +// .Order("name ASC") +// .First(&User{}) +// }) +func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { + tx := queryFn(db.Session(&Session{DryRun: true})) + stmt := tx.Statement + + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) } diff --git a/interfaces.go b/interfaces.go index e933952bb..44a85cb51 100644 --- a/interfaces.go +++ b/interfaces.go @@ -40,14 +40,17 @@ type SavePointerDialectorInterface interface { RollbackTo(tx *DB, name string) error } +// TxBeginner tx beginner type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } +// ConnPoolBeginner conn pool beginner type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxCommitter tx committer type TxCommitter interface { Commit() error Rollback() error @@ -57,3 +60,8 @@ type TxCommitter interface { type Valuer interface { GormValue(context.Context, *DB) clause.Expr } + +// GetDBConnector SQL db connector +type GetDBConnector interface { + GetDBConn() (*sql.DB, error) +} diff --git a/logger/logger.go b/logger/logger.go index 11619c92f..2ffd28d5a 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "context" + "errors" "fmt" "io/ioutil" "log" @@ -11,6 +12,9 @@ import ( "gorm.io/gorm/utils" ) +// ErrRecordNotFound record not found error +var ErrRecordNotFound = errors.New("record not found") + // Colors const ( Reset = "\033[0m" @@ -27,13 +31,17 @@ const ( YellowBold = "\033[33;1m" ) -// LogLevel +// LogLevel log level type LogLevel int const ( + // Silent silent log level Silent LogLevel = iota + 1 + // Error error log level Error + // Warn warn log level Warn + // Info info log level Info ) @@ -42,10 +50,12 @@ type Writer interface { Printf(string, ...interface{}) } +// Config logger config type Config struct { - SlowThreshold time.Duration - Colorful bool - LogLevel LogLevel + SlowThreshold time.Duration + Colorful bool + IgnoreRecordNotFoundError bool + LogLevel LogLevel } // Interface logger interface @@ -54,19 +64,24 @@ type Interface interface { Info(context.Context, string, ...interface{}) Warn(context.Context, string, ...interface{}) Error(context.Context, string, ...interface{}) - Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) + Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) } var ( + // Discard Discard logger will print any log to ioutil.Discard Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ - SlowThreshold: 200 * time.Millisecond, - LogLevel: Warn, - Colorful: true, + SlowThreshold: 200 * time.Millisecond, + LogLevel: Warn, + IgnoreRecordNotFoundError: false, + Colorful: true, }) + // Recorder Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) +// New initialize logger func New(writer Writer, config Config) Interface { var ( infoStr = "%s\n[info] " @@ -135,31 +150,33 @@ func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { - if l.LogLevel > 0 { - elapsed := time.Since(begin) - switch { - case err != nil && l.LogLevel >= Error: - sql, rows := fc() - if rows == -1 { - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) - } - case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: - sql, rows := fc() - slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) - if rows == -1 { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) - } - case l.LogLevel >= Info: - sql, rows := fc() - if rows == -1 { - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) - } + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + switch { + case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case l.LogLevel == Info: + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) } } } @@ -172,10 +189,12 @@ type traceRecorder struct { Err error } +// New new trace recorder func (l traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } +// Trace implement logger interface func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { l.BeginAt = begin l.SQL, l.RowsAffected = fc() diff --git a/logger/sql.go b/logger/sql.go index d080def21..c8b194c3c 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -13,20 +13,29 @@ import ( "gorm.io/gorm/utils" ) -func isPrintable(s []byte) bool { +const ( + tmFmtWithMS = "2006-01-02 15:04:05.999" + tmFmtZero = "0000-00-00 00:00:00" + nullStr = "NULL" +) + +func isPrintable(s string) bool { for _, r := range s { - if !unicode.IsPrint(rune(r)) { + if !unicode.IsPrint(r) { return false } } return true } -var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { - var convertParams func(interface{}, int) - var vars = make([]string, len(avars)) + var ( + convertParams func(interface{}, int) + vars = make([]string, len(avars)) + ) convertParams = func(v interface{}, idx int) { switch v := v.(type) { @@ -34,26 +43,19 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a vars[idx] = strconv.FormatBool(v) case time.Time: if v.IsZero() { - vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + vars[idx] = escaper + tmFmtZero + escaper } else { - vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper } case *time.Time: if v != nil { if v.IsZero() { - vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + vars[idx] = escaper + tmFmtZero + escaper } else { - vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper } } else { - vars[idx] = "NULL" - } - case fmt.Stringer: - reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper - } else { - vars[idx] = "NULL" + vars[idx] = nullStr } case driver.Valuer: reflectValue := reflect.ValueOf(v) @@ -61,11 +63,29 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a r, _ := v.Value() convertParams(r, idx) } else { - vars[idx] = "NULL" + vars[idx] = nullStr + } + case fmt.Stringer: + reflectValue := reflect.ValueOf(v) + switch reflectValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) + case reflect.Float32, reflect.Float64: + vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) + case reflect.Bool: + vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) + case reflect.String: + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + default: + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + } else { + vars[idx] = nullStr + } } case []byte: - if isPrintable(v) { - vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + if s := string(v); isPrintable(s) { + vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } @@ -74,24 +94,24 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case float64, float32: vars[idx] = fmt.Sprintf("%.6f", v) case string: - vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { - vars[idx] = "NULL" + vars[idx] = nullStr } else if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) } else { - for _, t := range convertableTypes { + for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { convertParams(rv.Convert(t).Interface(), idx) return } } - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper } } } diff --git a/logger/sql_test.go b/logger/sql_test.go index 71aa841af..c5b181a9c 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) { } func format(v []byte, escaper string) string { - return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper } func TestExplainSQL(t *testing.T) { diff --git a/migrator.go b/migrator.go index 28ac35e7d..524438770 100644 --- a/migrator.go +++ b/migrator.go @@ -1,13 +1,26 @@ package gorm import ( + "reflect" + "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator(db.Session(&Session{})) + tx := db.getInstance() + + // apply scopes to migrator + for len(tx.Statement.scopes) > 0 { + scopes := tx.Statement.scopes + tx.Statement.scopes = nil + for _, scope := range scopes { + tx = scope(tx) + } + } + + return tx.Dialector.Migrator(tx.Session(&Session{})) } // AutoMigrate run auto migration for given models @@ -22,14 +35,23 @@ type ViewOption struct { Query *DB } +// ColumnType column type interface type ColumnType interface { Name() string - DatabaseTypeName() string + DatabaseTypeName() string // varchar + ColumnType() (columnType string, ok bool) // varchar(64) + PrimaryKey() (isPrimaryKey bool, ok bool) + AutoIncrement() (isAutoIncrement bool, ok bool) Length() (length int64, ok bool) DecimalSize() (precision int64, scale int64, ok bool) Nullable() (nullable bool, ok bool) + Unique() (unique bool, ok bool) + ScanType() reflect.Type + Comment() (value string, ok bool) + DefaultValue() (value string, ok bool) } +// Migrator migrator interface type Migrator interface { // AutoMigrate AutoMigrate(dst ...interface{}) error @@ -43,6 +65,7 @@ type Migrator interface { DropTable(dst ...interface{}) error HasTable(dst interface{}) bool RenameTable(oldName, newName interface{}) error + GetTables() (tableList []string, err error) // Columns AddColumn(dst interface{}, field string) error diff --git a/migrator/column_type.go b/migrator/column_type.go new file mode 100644 index 000000000..cc1331b92 --- /dev/null +++ b/migrator/column_type.go @@ -0,0 +1,107 @@ +package migrator + +import ( + "database/sql" + "reflect" +) + +// ColumnType column type implements ColumnType interface +type ColumnType struct { + SQLColumnType *sql.ColumnType + NameValue sql.NullString + DataTypeValue sql.NullString + ColumnTypeValue sql.NullString + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + AutoIncrementValue sql.NullBool + LengthValue sql.NullInt64 + DecimalSizeValue sql.NullInt64 + ScaleValue sql.NullInt64 + NullableValue sql.NullBool + ScanTypeValue reflect.Type + CommentValue sql.NullString + DefaultValueValue sql.NullString +} + +// Name returns the name or alias of the column. +func (ct ColumnType) Name() string { + if ct.NameValue.Valid { + return ct.NameValue.String + } + return ct.SQLColumnType.Name() +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned, then the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", +// "INT", and "BIGINT". +func (ct ColumnType) DatabaseTypeName() string { + if ct.DataTypeValue.Valid { + return ct.DataTypeValue.String + } + return ct.SQLColumnType.DatabaseTypeName() +} + +// ColumnType returns the database type of the column. lke `varchar(16)` +func (ct ColumnType) ColumnType() (columnType string, ok bool) { + return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid +} + +// PrimaryKey returns the column is primary key or not. +func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { + return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid +} + +// AutoIncrement returns the column is auto increment or not. +func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { + return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid +} + +// Length returns the column type length for variable length column types +func (ct ColumnType) Length() (length int64, ok bool) { + if ct.LengthValue.Valid { + return ct.LengthValue.Int64, true + } + return ct.SQLColumnType.Length() +} + +// DecimalSize returns the scale and precision of a decimal type. +func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { + if ct.DecimalSizeValue.Valid { + return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true + } + return ct.SQLColumnType.DecimalSize() +} + +// Nullable reports whether the column may be null. +func (ct ColumnType) Nullable() (nullable bool, ok bool) { + if ct.NullableValue.Valid { + return ct.NullableValue.Bool, true + } + return ct.SQLColumnType.Nullable() +} + +// Unique reports whether the column may be unique. +func (ct ColumnType) Unique() (unique bool, ok bool) { + return ct.UniqueValue.Bool, ct.UniqueValue.Valid +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +func (ct ColumnType) ScanType() reflect.Type { + if ct.ScanTypeValue != nil { + return ct.ScanTypeValue + } + return ct.SQLColumnType.ScanType() +} + +// Comment returns the comment of current column. +func (ct ColumnType) Comment() (value string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} + +// DefaultValue returns the default value of current column. +func (ct ColumnType) DefaultValue() (value string, ok bool) { + return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 5de820a83..a50bb3ff8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -2,6 +2,7 @@ package migrator import ( "context" + "database/sql" "fmt" "reflect" "regexp" @@ -12,6 +13,11 @@ import ( "gorm.io/gorm/schema" ) +var ( + regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) + regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) +) + // Migrator m struct type Migrator struct { Config @@ -24,10 +30,12 @@ type Config struct { gorm.Dialector } +// GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string } +// RunWithValue run migration with statement value func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { @@ -37,13 +45,14 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error if table, ok := value.(string); ok { stmt.Table = table - } else if err := stmt.Parse(value); err != nil { + } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { return err } return fc(stmt) } +// DataTypeOf return field's db data type func (m Migrator) DataTypeOf(field *schema.Field) string { fieldValue := reflect.New(field.IndirectFieldType) if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { @@ -55,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { return m.Dialector.DataTypeOf(field) } +// FullDataTypeOf returns field's db full data type func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL = m.DataTypeOf(field) @@ -79,10 +89,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } -// AutoMigrate +// AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{NewDB: true}) + tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { if err := tx.Migrator().CreateTable(value); err != nil { return err @@ -91,11 +101,12 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { columnTypes, _ := m.DB.Migrator().ColumnTypes(value) - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] var foundColumn gorm.ColumnType for _, columnType := range columnTypes { - if columnType.Name() == field.DBName { + if columnType.Name() == dbName { foundColumn = columnType break } @@ -103,7 +114,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { + if err := tx.Migrator().AddColumn(value, dbName); err != nil { return err } } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { @@ -114,13 +125,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { - if constraint := rel.ParseConstraint(); constraint != nil { - if constraint.Schema == stmt.Schema { - if !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err - } - } + if constraint := rel.ParseConstraint(); constraint != nil && + constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err } } } @@ -152,9 +160,17 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return nil } +// GetTables returns tables +func (m Migrator) GetTables() (tableList []string, err error) { + err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). + Scan(&tableList).Error + return +} + +// CreateTable create table in database for values func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { - tx := m.DB.Session(&gorm.Session{NewDB: true}) + tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" @@ -164,10 +180,12 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] - createTableSQL += "? ?" - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) - createTableSQL += "," + if !field.IgnoreMigration { + createTableSQL += "? ?" + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) + createTableSQL += "," + } } if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { @@ -183,7 +201,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { defer func(value interface{}, name string) { - errr = tx.Migrator().CreateIndex(value, name) + if errr == nil { + errr = tx.Migrator().CreateIndex(value, name) + } }(value, idx.Name) } else { if idx.Class != "" { @@ -191,6 +211,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { } createTableSQL += "INDEX ? ?" + if idx.Comment != "" { + createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + if idx.Option != "" { createTableSQL += " " + idx.Option } @@ -234,10 +258,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { return nil } +// DropTable drop table for values func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { - tx := m.DB.Session(&gorm.Session{NewDB: true}) + tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error }); err != nil { @@ -247,6 +272,7 @@ func (m Migrator) DropTable(values ...interface{}) error { return nil } +// HasTable returns table exists or not for value, value could be a struct or string func (m Migrator) HasTable(value interface{}) bool { var count int64 @@ -258,6 +284,7 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } +// RenameTable rename table from oldName to newName func (m Migrator) RenameTable(oldName, newName interface{}) error { var oldTable, newTable interface{} if v, ok := oldName.(string); ok { @@ -285,18 +312,27 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error } -func (m Migrator) AddColumn(value interface{}, field string) error { +// AddColumn create `name` column for value +func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { + // avoid using the same name field + f := stmt.Schema.LookUpField(name) + if f == nil { + return fmt.Errorf("failed to look up field with name: %s", name) + } + + if !f.IgnoreMigration { return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f), ).Error } - return fmt.Errorf("failed to look up field with name: %s", field) + + return nil }) } +// DropColumn drop value's `name` column func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(name); field != nil { @@ -309,10 +345,11 @@ func (m Migrator) DropColumn(value interface{}, name string) error { }) } +// AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - fileType := clause.Expr{SQL: m.DataTypeOf(field)} + fileType := m.FullDataTypeOf(field) return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, @@ -323,6 +360,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { }) } +// HasColumn check has column `field` for value or not func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -341,6 +379,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +// RenameColumn rename value's field name from oldName to newName func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(oldName); field != nil { @@ -358,6 +397,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +// MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // found, smart migrate fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) @@ -366,14 +406,16 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn := false // check size - if length, _ := columnType.Length(); length != int64(field.Size) { + if length, ok := columnType.Length(); length != int64(field.Size) { if length > 0 && field.Size > 0 { alterColumn = true } else { // has size in data type and not equal - matches := regexp.MustCompile(`[^\d](\d+)[^\d]?`).FindAllStringSubmatch(realDataType, -1) - matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]?`).FindAllStringSubmatch(fullDataType, -1) - if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { alterColumn = true } } @@ -381,7 +423,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // check precision if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { alterColumn = true } } @@ -394,35 +436,72 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - if alterColumn { + // check unique + if unique, ok := columnType.Unique(); ok && unique != field.Unique { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check default value + if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check comment + if comment, ok := columnType.Comment(); ok && comment != field.Comment { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + if alterColumn && !field.IgnoreMigration { return m.DB.Migrator().AlterColumn(value, field.Name) } return nil } -func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { - columnTypes = make([]gorm.ColumnType, 0) - err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{NewDB: true}).Table(stmt.Table).Limit(1).Rows() - if err == nil { - defer rows.Close() - rawColumnTypes, err := rows.ColumnTypes() - if err == nil { - for _, c := range rawColumnTypes { - columnTypes = append(columnTypes, c) - } - } +// ColumnTypes return columnTypes []gorm.ColumnType and execErr error +func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { + columnTypes := make([]gorm.ColumnType, 0) + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + if err != nil { + return err } - return err + + defer func() { + err = rows.Close() + }() + + var rawColumnTypes []*sql.ColumnType + rawColumnTypes, err = rows.ColumnTypes() + if err != nil { + return err + } + + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) + } + + return }) - return + + return columnTypes, execErr } +// CreateView create view func (m Migrator) CreateView(name string, option gorm.ViewOption) error { return gorm.ErrNotImplemented } +// DropView drop view func (m Migrator) DropView(name string) error { return gorm.ErrNotImplemented } @@ -449,66 +528,110 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } +// GuessConstraintAndTable guess statement's constraint and it's table based on name +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { + if stmt.Schema == nil { + return nil, nil, stmt.Table + } + + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return nil, &chk, stmt.Table + } + + getTable := func(rel *schema.Relationship) string { + switch rel.Type { + case schema.HasOne, schema.HasMany: + return rel.FieldSchema.Table + case schema.Many2Many: + return rel.JoinTable.Table + } + return stmt.Table + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + return constraint, nil, getTable(rel) + } + } + + if field := stmt.Schema.LookUpField(name); field != nil { + for k := range checkConstraints { + if checkConstraints[k].Field == field { + v := checkConstraints[k] + return nil, &v, stmt.Table + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { + return constraint, nil, getTable(rel) + } + } + } + + return nil, nil, stmt.Schema.Table +} + +// CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - checkConstraints := stmt.Schema.ParseCheckConstraints() - if chk, ok := checkConstraints[name]; ok { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if chk != nil { return m.DB.Exec( "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, ).Error } - for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - sql, values := buildConstraint(constraint) - return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{m.CurrentTable(stmt)}, values...)...).Error - } - } - - err := fmt.Errorf("failed to create constraint with name %v", name) - if field := stmt.Schema.LookUpField(name); field != nil { - for _, cc := range checkConstraints { - if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil { - return err - } - } - - for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { - if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil { - return err - } - } + if constraint != nil { + vars := []interface{}{clause.Table{Name: table}} + if stmt.TableExpr != nil { + vars[0] = stmt.TableExpr } + sql, values := buildConstraint(constraint) + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error } - return err + return nil }) } +// DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Exec( - "ALTER TABLE ? DROP CONSTRAINT ?", - m.CurrentTable(stmt), clause.Column{Name: name}, - ).Error + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error }) } +// HasConstraint check has constraint or not func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", - currentDatabase, stmt.Table, name, + currentDatabase, table, name, ).Row().Scan(&count) }) return count > 0 } +// BuildIndexOptions build index options func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) @@ -530,10 +653,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem return } +// BuildIndexOptionsInterface build index options interface type BuildIndexOptionsInterface interface { BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} } +// CreateIndex create index `name` func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -550,6 +675,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " USING " + idx.Type } + if idx.Comment != "" { + createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + if idx.Option != "" { createIndexSQL += " " + idx.Option } @@ -557,10 +686,11 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { return m.DB.Exec(createIndexSQL, values...).Error } - return fmt.Errorf("failed to create index with name %v", name) + return fmt.Errorf("failed to create index with name %s", name) }) } +// DropIndex drop index `name` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -571,6 +701,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { }) } +// HasIndex check has index `name` or not func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -588,6 +719,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { return count > 0 } +// RenameIndex rename index from oldName to newName func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( @@ -597,6 +729,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error }) } +// CurrentDatabase returns current database name func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) return @@ -667,13 +800,15 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } orderedModelNamesMap[name] = true - dep := valuesMap[name] - for _, d := range dep.Depends { - if _, ok := valuesMap[d.Table]; ok { - insertIntoOrderedList(d.Table) - } else if autoAdd { - parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) - insertIntoOrderedList(d.Table) + if autoAdd { + dep := valuesMap[name] + for _, d := range dep.Depends { + if _, ok := valuesMap[d.Table]; ok { + insertIntoOrderedList(d.Table) + } else { + parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) + insertIntoOrderedList(d.Table) + } } } @@ -698,6 +833,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i return } +// CurrentTable returns current statement's table expression func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { if stmt.TableExpr != nil { return *stmt.TableExpr diff --git a/prepare_stmt.go b/prepare_stmt.go index eddee1f2b..88bec4e95 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -6,48 +6,67 @@ import ( "sync" ) +type Stmt struct { + *sql.Stmt + Transaction bool +} + type PreparedStmtDB struct { - Stmts map[string]*sql.Stmt + Stmts map[string]Stmt PreparedSQL []string Mux *sync.RWMutex ConnPool } +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + + if sqldb, ok := db.ConnPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, ErrInvalidDB +} + func (db *PreparedStmtDB) Close() { db.Mux.Lock() + defer db.Mux.Unlock() + for _, query := range db.PreparedSQL { if stmt, ok := db.Stmts[query]; ok { delete(db.Stmts, query) - stmt.Close() + go stmt.Close() } } - - db.Mux.Unlock() } -func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() - if stmt, ok := db.Stmts[query]; ok { + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() return stmt, nil } db.Mux.RUnlock() db.Mux.Lock() + defer db.Mux.Unlock() + // double check - if stmt, ok := db.Stmts[query]; ok { - db.Mux.Unlock() + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { return stmt, nil + } else if ok { + go stmt.Close() } - stmt, err := db.ConnPool.PrepareContext(ctx, query) + stmt, err := conn.PrepareContext(ctx, query) if err == nil { - db.Stmts[query] = stmt + db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} db.PreparedSQL = append(db.PreparedSQL, query) } - db.Mux.Unlock() - return stmt, err + return db.Stmts[query], err } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { @@ -59,35 +78,36 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn } func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := db.prepare(ctx, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { db.Mux.Lock() - stmt.Close() + defer db.Mux.Unlock() + go stmt.Close() delete(db.Stmts, query) - db.Mux.Unlock() } } return result, err } func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := db.prepare(ctx, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { db.Mux.Lock() - stmt.Close() + defer db.Mux.Unlock() + + go stmt.Close() delete(db.Stmts, query) - db.Mux.Unlock() } } return rows, err } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := db.prepare(ctx, query) + stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { return stmt.QueryRowContext(ctx, args...) } @@ -114,37 +134,39 @@ func (tx *PreparedStmtTX) Rollback() error { } func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) + result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() - stmt.Close() + defer tx.PreparedStmtDB.Mux.Unlock() + + go stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.Mux.Unlock() } } return result, err } func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() - stmt.Close() + defer tx.PreparedStmtDB.Mux.Unlock() + + go stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.Mux.Unlock() } } return rows, err } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := tx.PreparedStmtDB.prepare(ctx, query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) + return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) } return &sql.Row{} } diff --git a/scan.go b/scan.go index c9c8f442e..0da12dafb 100644 --- a/scan.go +++ b/scan.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm/schema" ) +// prepareValues prepare values slice func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { if db.Statement.Schema != nil { for idx, name := range columns { @@ -49,9 +50,78 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func Scan(rows *sql.Rows, db *DB, initialized bool) { - columns, _ := rows.Columns() - values := make([]interface{}, len(columns)) +func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { + for idx, column := range columns { + if sch == nil { + values[idx] = reflectValue.Interface() + } else if field := sch.LookUpField(column); field != nil && field.Readable { + values[idx] = field.NewValuePool.Get() + defer field.NewValuePool.Put(values[idx]) + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = field.NewValuePool.Get() + defer field.NewValuePool.Put(values[idx]) + continue + } + } + values[idx] = &sql.RawBytes{} + } else if len(columns) == 1 { + sch = nil + values[idx] = reflectValue.Interface() + } else { + values[idx] = &sql.RawBytes{} + } + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + if sch != nil { + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + field.Set(db.Statement.Context, reflectValue, values[idx]) + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue) + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue + } + + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(db.Statement.Context, relValue, values[idx]) + } + } + } + } + } +} + +// ScanMode scan data mode +type ScanMode uint8 + +// scan modes +const ( + ScanInitialized ScanMode = 1 << 0 // 1 + ScanUpdate ScanMode = 1 << 1 // 2 + ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 +) + +// Scan scan rows into db statement +func Scan(rows *sql.Rows, db *DB, mode ScanMode) { + var ( + columns, _ = rows.Columns() + values = make([]interface{}, len(columns)) + initialized = mode&ScanInitialized != 0 + update = mode&ScanUpdate != 0 + onConflictDonothing = mode&ScanOnConflictDoNothing != 0 + ) + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { @@ -66,6 +136,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { mapValue, ok := dest.(map[string]interface{}) if !ok { if v, ok := dest.(*map[string]interface{}); ok { + if *v == nil { + *v = map[string]interface{}{} + } mapValue = *v } } @@ -84,159 +157,133 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time: + case *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, + *float32, *float64, + *bool, *string, *time.Time, + *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, + *sql.NullBool, *sql.NullString, *sql.NullTime: for initialized || rows.Next() { initialized = false db.RowsAffected++ db.AddError(rows.Scan(dest)) } default: - Schema := db.Statement.Schema + var ( + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue + ) - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - reflectValueType = db.Statement.ReflectValue.Type().Elem() - isPtr = reflectValueType.Kind() == reflect.Ptr - fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field - ) - - if isPtr { - reflectValueType = reflectValueType.Elem() - } + for reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } - db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) + reflectValueType := reflectValue.Type() + switch reflectValueType.Kind() { + case reflect.Array, reflect.Slice: + reflectValueType = reflectValueType.Elem() + } + isPtr := reflectValueType.Kind() == reflect.Ptr + if isPtr { + reflectValueType = reflectValueType.Elem() + } - if Schema != nil { - if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) - } + if sch != nil { + if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { + sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} } } - // pluck values into slice of data - isPluck := false - if len(fields) == 1 { - if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner + if len(columns) == 1 { + // isPluck + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner reflectValueType.Kind() != reflect.Struct || // is not struct - Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time - isPluck = true + sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time + sch = nil } } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var elem reflect.Value + + if !update || reflectValue.Len() == 0 { + update = false + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } for initialized || rows.Next() { + BEGIN: initialized = false - db.RowsAffected++ - elem := reflect.New(reflectValueType) - if isPluck { - db.AddError(rows.Scan(elem.Interface())) - } else { - for idx, field := range fields { - if field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - } + if update { + if int(db.RowsAffected) >= reflectValue.Len() { + return } - - db.AddError(rows.Scan(values...)) - - for idx, field := range fields { - if len(joinFields) != 0 && joinFields[idx][0] != nil { - value := reflect.ValueOf(values[idx]).Elem() - relValue := joinFields[idx][0].ReflectValueOf(elem) - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) + elem = reflectValue.Index(int(db.RowsAffected)) + if onConflictDonothing { + for _, field := range fields { + if _, ok := field.ValueOf(db.Statement.Context, elem); !ok { + db.RowsAffected++ + goto BEGIN } - - field.Set(relValue, values[idx]) - } else if field != nil { - field.Set(elem, values[idx]) } } - } - - if isPtr { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) } else { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) + elem = reflect.New(reflectValueType) } - } - case reflect.Struct: - if db.Statement.ReflectValue.Type() != Schema.ModelType { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) - } - if initialized || rows.Next() { - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - continue - } - } - values[idx] = &sql.RawBytes{} + db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + + if !update { + if isPtr { + reflectValue = reflect.Append(reflectValue, elem) } else { - values[idx] = &sql.RawBytes{} + reflectValue = reflect.Append(reflectValue, elem.Elem()) } } + } - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - field.Set(db.Statement.ReflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - value := reflect.ValueOf(values[idx]).Elem() - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(relValue, values[idx]) - } - } - } - } + if !update { + db.Statement.ReflectValue.Set(reflectValue) + } + case reflect.Struct, reflect.Ptr: + if initialized || rows.Next() { + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) } + default: + db.AddError(rows.Scan(dest)) } } - if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { + if err := rows.Err(); err != nil && err != db.Error { + db.AddError(err) + } + + if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } } diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go index dec41eba0..4583a2072 100644 --- a/schema/callbacks_test.go +++ b/schema/callbacks_test.go @@ -9,8 +9,7 @@ import ( "gorm.io/gorm/schema" ) -type UserWithCallback struct { -} +type UserWithCallback struct{} func (UserWithCallback) BeforeSave(*gorm.DB) error { return nil diff --git a/schema/check.go b/schema/check.go index 7d31ec70e..89e732d36 100644 --- a/schema/check.go +++ b/schema/check.go @@ -5,6 +5,9 @@ import ( "strings" ) +// reg match english letters and midline +var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") + type Check struct { Name string Constraint string // length(phone) >= 10 @@ -13,11 +16,11 @@ type Check struct { // ParseCheckConstraints parse schema check constraints func (schema *Schema) ParseCheckConstraints() map[string]Check { - var checks = map[string]Check{} + checks := map[string]Check{} for _, field := range schema.FieldsByDBName { if chk := field.TagSettings["CHECK"]; chk != "" { names := strings.Split(chk, ",") - if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) { + if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} } else { if names[0] == "" { diff --git a/schema/field.go b/schema/field.go index b303fb302..319f36934 100644 --- a/schema/field.go +++ b/schema/field.go @@ -1,6 +1,7 @@ package schema import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -14,18 +15,29 @@ import ( "gorm.io/gorm/utils" ) -type DataType string - -type TimeType int64 +// special types' reflect type +var ( + TimeReflectType = reflect.TypeOf(time.Time{}) + TimePtrReflectType = reflect.TypeOf(&time.Time{}) + ByteReflectType = reflect.TypeOf(uint8(0)) +) -var TimeReflectType = reflect.TypeOf(time.Time{}) +type ( + // DataType GORM data type + DataType string + // TimeType GORM time type + TimeType int64 +) +// GORM time types const ( - UnixSecond TimeType = 1 - UnixMillisecond TimeType = 2 - UnixNanosecond TimeType = 3 + UnixTime TimeType = 1 + UnixSecond TimeType = 2 + UnixMillisecond TimeType = 3 + UnixNanosecond TimeType = 4 ) +// GORM fields types const ( Bool DataType = "bool" Int DataType = "int" @@ -36,56 +48,73 @@ const ( Bytes DataType = "bytes" ) +// Field is the representation of model schema's field type Field struct { - Name string - DBName string - BindNames []string - DataType DataType - GORMDataType DataType - PrimaryKey bool - AutoIncrement bool - Creatable bool - Updatable bool - Readable bool - HasDefaultValue bool - AutoCreateTime TimeType - AutoUpdateTime TimeType - DefaultValue string - DefaultValueInterface interface{} - NotNull bool - Unique bool - Comment string - Size int - Precision int - Scale int - FieldType reflect.Type - IndirectFieldType reflect.Type - StructField reflect.StructField - Tag reflect.StructTag - TagSettings map[string]string - Schema *Schema - EmbeddedSchema *Schema - OwnerSchema *Schema - ReflectValueOf func(reflect.Value) reflect.Value - ValueOf func(reflect.Value) (value interface{}, zero bool) - Set func(reflect.Value, interface{}) error + Name string + DBName string + BindNames []string + DataType DataType + GORMDataType DataType + PrimaryKey bool + AutoIncrement bool + AutoIncrementIncrement int64 + Creatable bool + Updatable bool + Readable bool + AutoCreateTime TimeType + AutoUpdateTime TimeType + HasDefaultValue bool + DefaultValue string + DefaultValueInterface interface{} + NotNull bool + Unique bool + Comment string + Size int + Precision int + Scale int + IgnoreMigration bool + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + OwnerSchema *Schema + ReflectValueOf func(context.Context, reflect.Value) reflect.Value + ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) + Set func(context.Context, reflect.Value, interface{}) error + Serializer SerializerInterface + NewValuePool FieldNewValuePool } +// ParseField parses reflect.StructField to Field func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { - var err error + var ( + err error + tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";") + ) field := &Field{ - Name: fieldStruct.Name, - BindNames: []string{fieldStruct.Name}, - FieldType: fieldStruct.Type, - IndirectFieldType: fieldStruct.Type, - StructField: fieldStruct, - Creatable: true, - Updatable: true, - Readable: true, - Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), - Schema: schema, + Name: fieldStruct.Name, + DBName: tagSetting["COLUMN"], + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + IndirectFieldType: fieldStruct.Type, + StructField: fieldStruct, + Tag: fieldStruct.Tag, + TagSettings: tagSetting, + Schema: schema, + Creatable: true, + Updatable: true, + Readable: true, + PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), + AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), + Unique: utils.CheckTruth(tagSetting["UNIQUE"]), + Comment: tagSetting["COMMENT"], + AutoIncrementIncrement: 1, } for field.IndirectFieldType.Kind() == reflect.Ptr { @@ -93,7 +122,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - // if field is valuer, used its value or first fields as data type + // if field is valuer, used its value or first field as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { @@ -101,31 +130,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue = reflect.ValueOf(v) } + // Use the field struct's first field type as data type, e.g: use `string` for sql.NullString var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { - rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { - for i := 0; i < rv.Type().NumField(); i++ { - newFieldType := rv.Type().Field(i).Type + var ( + rv = reflect.Indirect(v) + rvType = rv.Type() + ) + + if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) { + for i := 0; i < rvType.NumField(); i++ { + for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + + for i := 0; i < rvType.NumField(); i++ { + newFieldType := rvType.Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - - if rv.Type() != reflect.Indirect(fieldValue).Type() { + if rvType != reflect.Indirect(fieldValue).Type() { getRealFieldValue(fieldValue) } if fieldValue.IsValid() { return } - - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value - } - } } } } @@ -134,19 +169,27 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if dbName, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = dbName - } - - if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true + if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { + field.DataType = String + field.Serializer = v + } else { + var serializerName = field.TagSettings["JSON"] + if serializerName == "" { + serializerName = field.TagSettings["SERIALIZER"] + } + if serializerName != "" { + if serializer, ok := GetSerializer(serializerName); ok { + // Set default data type to string for serializer + field.DataType = String + field.Serializer = serializer + } else { + schema.err = fmt.Errorf("invalid serializer type %v", serializerName) + } + } } - if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { - field.AutoIncrement = true - field.HasDefaultValue = true + if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { + field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) } if v, ok := field.TagSettings["DEFAULT"]; ok { @@ -168,21 +211,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Scale, _ = strconv.Atoi(s) } - if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } - - if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { - field.Unique = true - } - - if val, ok := field.TagSettings["COMMENT"]; ok { - field.Comment = val - } - // default value is function or null or blank (primary keys) + field.DefaultValue = strings.TrimSpace(field.DefaultValue) skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" switch reflect.Indirect(fieldValue).Kind() { @@ -190,36 +220,35 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Bool if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err) } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err) } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err) } } case reflect.Float32, reflect.Float64: field.DataType = Float if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err) } } case reflect.String: field.DataType = String - if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") - field.DefaultValue = strings.Trim(field.DefaultValue, "\"") + field.DefaultValue = strings.Trim(field.DefaultValue, `"`) field.DefaultValueInterface = field.DefaultValue } case reflect.Struct: @@ -227,23 +256,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { field.DataType = Time } case reflect.Array, reflect.Slice: - if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { + if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { field.DataType = Bytes } } - field.GORMDataType = field.DataType - if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { field.DataType = DataType(dataTyper.GormDataType()) } if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { - if strings.ToUpper(v) == "NANO" { + if field.DataType == Time { + field.AutoCreateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond } else if strings.ToUpper(v) == "MILLI" { field.AutoCreateTime = UnixMillisecond @@ -253,7 +282,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { - if strings.ToUpper(v) == "NANO" { + if field.DataType == Time { + field.AutoUpdateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { field.AutoUpdateTime = UnixNanosecond } else if strings.ToUpper(v) == "MILLI" { field.AutoUpdateTime = UnixMillisecond @@ -289,11 +320,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // setup permission - if _, ok := field.TagSettings["-"]; ok { - field.Creatable = false - field.Updatable = false - field.Readable = false - field.DataType = "" + if val, ok := field.TagSettings["-"]; ok { + val = strings.ToLower(strings.TrimSpace(val)) + switch val { + case "-": + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + case "all": + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + field.IgnoreMigration = true + case "migration": + field.IgnoreMigration = true + } } if v, ok := field.TagSettings["->"]; ok { @@ -321,8 +364,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { - if reflect.Indirect(fieldValue).Kind() == reflect.Struct { + // Normal anonymous field or having `EMBEDDED` tag + if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer && + fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) { + kind := reflect.Indirect(fieldValue).Kind() + switch kind { + case reflect.Struct: var err error field.Creatable = false field.Updatable = false @@ -330,7 +377,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { schema.err = err } @@ -371,8 +418,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - } else { - schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) + case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128: + schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } } @@ -381,131 +429,157 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { - // ValueOf - switch { - case len(field.StructField.Index) == 1: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue.Interface(), fieldValue.IsZero() - } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - return fieldValue.Interface(), fieldValue.IsZero() + // Setup NewValuePool + var fieldValue = reflect.New(field.FieldType).Interface() + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + } + }, } - default: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - v := reflect.Indirect(value) + } else if _, ok := fieldValue.(sql.Scanner); !ok { + // set default NewValuePool + switch field.IndirectFieldType.Kind() { + case reflect.String: + field.NewValuePool = stringPool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.NewValuePool = intPool + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.NewValuePool = uintPool + case reflect.Float32, reflect.Float64: + field.NewValuePool = floatPool + case reflect.Bool: + field.NewValuePool = boolPool + default: + if field.IndirectFieldType == TimeReflectType { + field.NewValuePool = timePool + } + } + } - for _, idx := range field.StructField.Index { - if idx >= 0 { - v = v.Field(idx) - } else { - v = v.Field(-idx - 1) + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } - if v.Type().Elem().Kind() == reflect.Struct { - if !v.IsNil() { - v = v.Elem() - } else { - return nil, true - } - } else { - return nil, true - } + // ValueOf returns field's value and if it is zero + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true } } - return v.Interface(), v.IsZero() } + + fv, zero := v.Interface(), v.IsZero() + return fv, zero } - // ReflectValueOf - switch { - case len(field.StructField.Index) == 1: - if field.FieldType.Kind() == reflect.Ptr { - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue + if field.Serializer != nil { + oldValuerOf := field.ValueOf + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + value, zero := oldValuerOf(ctx, v) + if zero { + return value, zero } - } else { - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]) + + s, ok := value.(SerializerValuerInterface) + if !ok { + s = field.Serializer } + + return serializer{ + Field: field, + SerializeValuer: s, + Destination: v, + Context: ctx, + fieldValue: value, + }, false } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - } - default: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - v := reflect.Indirect(value) - for idx, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) - } + } - if v.Kind() == reflect.Ptr { - if v.Type().Elem().Kind() == reflect.Struct { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - } + // ReflectValueOf returns field's reflect value + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) - if idx < len(field.StructField.Index)-1 { - v = v.Elem() - } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() } } - return v } + return v } - fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { if v == nil { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) + // Optimal value type acquisition for v + reflectValType := reflectV.Type() - if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + if reflectValType.AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV) return - } else if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectValType.ConvertibleTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType)) return } else if field.FieldType.Kind() == reflect.Ptr { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) + fieldType := field.FieldType.Elem() - if reflectV.Type().AssignableTo(field.FieldType.Elem()) { + if reflectValType.AssignableTo(fieldType) { if !fieldValue.IsValid() { - fieldValue = reflect.New(field.FieldType.Elem()) + fieldValue = reflect.New(fieldType) } else if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.FieldType.Elem())) + fieldValue.Set(reflect.New(fieldType)) } fieldValue.Elem().Set(reflectV) return - } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + } else if reflectValType.ConvertibleTo(fieldType) { if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.FieldType.Elem())) + fieldValue.Set(reflect.New(fieldType)) } - fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) + fieldValue.Elem().Set(reflectV.Convert(fieldType)) return } } if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().Elem().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Elem()) + return } else { - err = setter(value, reflectV.Elem().Interface()) + err = setter(ctx, value, reflectV.Elem().Interface()) } } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = setter(value, v) + err = setter(ctx, value, v) } } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) } } @@ -515,191 +589,201 @@ func (field *Field) setupValuerAndSetter() { // Set switch field.FieldType.Kind() { case reflect.Bool: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { - case bool: - field.ReflectValueOf(value).SetBool(data) - case *bool: - if data != nil { - field.ReflectValueOf(value).SetBool(*data) - } else { - field.ReflectValueOf(value).SetBool(false) + case **bool: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetBool(**data) } + case bool: + field.ReflectValueOf(ctx, value).SetBool(data) case int64: - if data > 0 { - field.ReflectValueOf(value).SetBool(true) - } else { - field.ReflectValueOf(value).SetBool(false) - } + field.ReflectValueOf(ctx, value).SetBool(data > 0) case string: b, _ := strconv.ParseBool(data) - field.ReflectValueOf(value).SetBool(b) + field.ReflectValueOf(ctx, value).SetBool(b) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **int64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(**data) + } case int64: - field.ReflectValueOf(value).SetInt(data) + field.ReflectValueOf(ctx, value).SetInt(data) case int: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseInt(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetInt(i) + field.ReflectValueOf(ctx, value).SetInt(i) } else { return err } case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } case *time.Time: if data != nil { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } } else { - field.ReflectValueOf(value).SetInt(0) + field.ReflectValueOf(ctx, value).SetInt(0) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **uint64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(**data) + } case uint64: - field.ReflectValueOf(value).SetUint(data) + field.ReflectValueOf(ctx, value).SetUint(data) case uint: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) } else { - field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetUint(i) + field.ReflectValueOf(ctx, value).SetUint(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Float32, reflect.Float64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **float64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(**data) + } case float64: - field.ReflectValueOf(value).SetFloat(data) + field.ReflectValueOf(ctx, value).SetFloat(data) case float32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseFloat(data, 64); err == nil { - field.ReflectValueOf(value).SetFloat(i) + field.ReflectValueOf(ctx, value).SetFloat(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.String: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **string: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetString(**data) + } case string: - field.ReflectValueOf(value).SetString(data) + field.ReflectValueOf(ctx, value).SetString(data) case []byte: - field.ReflectValueOf(value).SetString(string(data)) + field.ReflectValueOf(ctx, value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValueOf(value).SetString(utils.ToString(data)) + field.ReflectValueOf(ctx, value).SetString(utils.ToString(data)) case float64, float32: - field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } @@ -707,41 +791,49 @@ func (field *Field) setupValuerAndSetter() { fieldValue := reflect.New(field.FieldType) switch fieldValue.Elem().Interface().(type) { case time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil && *data != nil { + field.Set(ctx, value, *data) + } case time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case *time.Time: if data != nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem()) } else { - field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{})) } case string: if t, err := now.Parse(data); err == nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(t)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t)) } else { - return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case *time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) + } case time.Time: - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } fieldValue.Elem().Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { if v == "" { return nil @@ -750,30 +842,30 @@ func (field *Field) setupValuerAndSetter() { } fieldValue.Elem().Set(reflect.ValueOf(t)) } else { - return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } default: if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } @@ -788,32 +880,61 @@ func (field *Field) setupValuerAndSetter() { } } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() } - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else { - field.Set = func(value reflect.Value, v interface{}) (err error) { - return fallbackSetter(value, v, field.Set) + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + return fallbackSetter(ctx, value, v, field.Set) + } + } + } + } + + if field.Serializer != nil { + var ( + oldFieldSetter = field.Set + sameElemType bool + sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type() + ) + + if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr { + sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() + } + + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + if s, ok := v.(*serializer); ok { + if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if sameElemType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } else if sameType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } } + } else { + err = oldFieldSetter(ctx, value, v) } + return } } } diff --git a/schema/field_test.go b/schema/field_test.go index 64f4a909d..300e375b4 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "database/sql" "reflect" "sync" @@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -235,6 +236,7 @@ type UserWithPermissionControl struct { Name5 string `gorm:"<-:update"` Name6 string `gorm:"<-:create,update"` Name7 string `gorm:"->:false;<-:create,update"` + Name8 string `gorm:"->;-:migration"` } func TestParseFieldWithPermission(t *testing.T) { @@ -243,7 +245,7 @@ func TestParseFieldWithPermission(t *testing.T) { t.Fatalf("Failed to parse user with permission, got error %v", err) } - fields := []schema.Field{ + fields := []*schema.Field{ {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, @@ -252,9 +254,81 @@ func TestParseFieldWithPermission(t *testing.T) { {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, + {Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true}, } for _, f := range fields { - checkSchemaField(t, user, &f, func(f *schema.Field) {}) + checkSchemaField(t, user, f, func(f *schema.Field) {}) + } +} + +type ( + ID int64 + INT int + INT8 int8 + INT16 int16 + INT32 int32 + INT64 int64 + UINT uint + UINT8 uint8 + UINT16 uint16 + UINT32 uint32 + UINT64 uint64 + FLOAT32 float32 + FLOAT64 float64 + BOOL bool + STRING string + TIME time.Time + BYTES []byte + + TypeAlias struct { + ID + INT `gorm:"column:fint"` + INT8 `gorm:"column:fint8"` + INT16 `gorm:"column:fint16"` + INT32 `gorm:"column:fint32"` + INT64 `gorm:"column:fint64"` + UINT `gorm:"column:fuint"` + UINT8 `gorm:"column:fuint8"` + UINT16 `gorm:"column:fuint16"` + UINT32 `gorm:"column:fuint32"` + UINT64 `gorm:"column:fuint64"` + FLOAT32 `gorm:"column:ffloat32"` + FLOAT64 `gorm:"column:ffloat64"` + BOOL `gorm:"column:fbool"` + STRING `gorm:"column:fstring"` + TIME `gorm:"column:ftime"` + BYTES `gorm:"column:fbytes"` + } +) + +func TestTypeAliasField(t *testing.T) { + alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err) + } + + fields := []*schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true}, + {Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`}, + {Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`}, + {Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`}, + {Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`}, + {Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`}, + {Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`}, + {Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`}, + {Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`}, + {Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`}, + {Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`}, + {Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`}, + {Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`}, + {Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`}, + {Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`}, + {Name: "TIME", DBName: "ftime", BindNames: []string{"TIME"}, DataType: schema.Time, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:ftime"`}, + {Name: "BYTES", DBName: "fbytes", BindNames: []string{"BYTES"}, DataType: schema.Bytes, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbytes"`}, + } + + for _, f := range fields { + checkSchemaField(t, alias, f, func(f *schema.Field) {}) } } diff --git a/schema/index.go b/schema/index.go index b54e08ad2..5f775f30f 100644 --- a/schema/index.go +++ b/schema/index.go @@ -27,7 +27,7 @@ type IndexOption struct { // ParseIndexes parse schema indexes func (schema *Schema) ParseIndexes() map[string]Index { - var indexes = map[string]Index{} + indexes := map[string]Index{} for _, field := range schema.Fields { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { diff --git a/schema/interfaces.go b/schema/interfaces.go index 98abffbd4..a75a33c0d 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -4,22 +4,33 @@ import ( "gorm.io/gorm/clause" ) +// GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string } +// FieldNewValuePool field new scan value pool +type FieldNewValuePool interface { + Get() interface{} + Put(interface{}) +} + +// CreateClausesInterface create clauses interface type CreateClausesInterface interface { CreateClauses(*Field) []clause.Interface } +// QueryClausesInterface query clauses interface type QueryClausesInterface interface { QueryClauses(*Field) []clause.Interface } +// UpdateClausesInterface update clauses interface type UpdateClausesInterface interface { UpdateClauses(*Field) []clause.Interface } +// DeleteClausesInterface delete clauses interface type DeleteClausesInterface interface { DeleteClauses(*Field) []clause.Interface } diff --git a/schema/model_test.go b/schema/model_test.go index 1f2b09481..9e6c3590f 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -26,9 +26,11 @@ type User struct { Active *bool } -type mytime time.Time -type myint int -type mybool = bool +type ( + mytime time.Time + myint int + mybool = bool +) type AdvancedDataTypeUser struct { ID sql.NullInt64 diff --git a/schema/naming.go b/schema/naming.go index e3b2104af..125094bcf 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -2,9 +2,10 @@ package schema import ( "crypto/sha1" + "encoding/hex" "fmt" + "regexp" "strings" - "sync" "unicode/utf8" "github.com/jinzhu/inflection" @@ -13,6 +14,7 @@ import ( // Namer namer interface type Namer interface { TableName(table string) string + SchemaName(table string) string ColumnName(table, column string) string JoinTableName(joinTable string) string RelationshipFKName(Relationship) string @@ -20,82 +22,115 @@ type Namer interface { IndexName(table, column string) string } +// Replacer replacer interface like strings.Replacer +type Replacer interface { + Replace(name string) string +} + // NamingStrategy tables, columns naming strategy type NamingStrategy struct { TablePrefix string SingularTable bool + NameReplacer Replacer + NoLowerCase bool } // TableName convert string to table name func (ns NamingStrategy) TableName(str string) string { if ns.SingularTable { - return ns.TablePrefix + toDBName(str) + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName +func (ns NamingStrategy) SchemaName(table string) string { + table = strings.TrimPrefix(table, ns.TablePrefix) + + if ns.SingularTable { + return ns.toSchemaName(table) } - return ns.TablePrefix + inflection.Plural(toDBName(str)) + return ns.toSchemaName(inflection.Singular(table)) } // ColumnName convert string to column name func (ns NamingStrategy) ColumnName(table, column string) string { - return toDBName(column) + return ns.toDBName(column) } // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { - if strings.ToLower(str) == str { + if !ns.NoLowerCase && strings.ToLower(str) == str { return ns.TablePrefix + str } if ns.SingularTable { - return ns.TablePrefix + toDBName(str) + return ns.TablePrefix + ns.toDBName(str) } - return ns.TablePrefix + inflection.Plural(toDBName(str)) + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) } // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1) + return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name)) } // CheckerName generate checker name func (ns NamingStrategy) CheckerName(table, column string) string { - return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1) + return ns.formatName("chk", table, column) } // IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { - idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) - idxName = strings.Replace(idxName, ".", "_", -1) + return ns.formatName("idx", table, ns.toDBName(column)) +} - if utf8.RuneCountInString(idxName) > 64 { +func (ns NamingStrategy) formatName(prefix, table, name string) string { + formattedName := strings.Replace(strings.Join([]string{ + prefix, table, name, + }, "_"), ".", "_", -1) + + if utf8.RuneCountInString(formattedName) > 64 { h := sha1.New() - h.Write([]byte(idxName)) + h.Write([]byte(formattedName)) bs := h.Sum(nil) - idxName = fmt.Sprintf("idx%v%v", table, column)[0:56] + string(bs)[:8] + formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8] } - return idxName + return formattedName } var ( - smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} commonInitialismsReplacer *strings.Replacer ) func init() { - var commonInitialismsForReplacer []string + commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) for _, initialism := range commonInitialisms { commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) } commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) } -func toDBName(name string) string { +func (ns NamingStrategy) toDBName(name string) string { if name == "" { return "" - } else if v, ok := smap.Load(name); ok { - return v.(string) + } + + if ns.NameReplacer != nil { + tmpName := ns.NameReplacer.Replace(name) + + if tmpName == "" { + return name + } + + name = tmpName + } + + if ns.NoLowerCase { + return name } var ( @@ -135,6 +170,13 @@ func toDBName(name string) string { buf.WriteByte(value[len(value)-1]) } ret := buf.String() - smap.Store(name, ret) return ret } + +func (ns NamingStrategy) toSchemaName(name string) string { + result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") + for _, initialism := range commonInitialisms { + result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") + } + return result +} diff --git a/schema/naming_test.go b/schema/naming_test.go index 26b0dcf6b..1fdab9a06 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,11 +1,12 @@ package schema import ( + "strings" "testing" ) func TestToDBName(t *testing.T) { - var maps = map[string]string{ + maps := map[string]string{ "": "", "x": "x", "X": "x", @@ -26,17 +27,39 @@ func TestToDBName(t *testing.T) { "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", } + ns := NamingStrategy{} for key, value := range maps { - if toDBName(key) != value { - t.Errorf("%v toName should equal %v, but got %v", key, value, toDBName(key)) + if ns.toDBName(key) != value { + t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key)) + } + } + + maps = map[string]string{ + "x": "X", + "user_restrictions": "UserRestriction", + "this_is_a_test": "ThisIsATest", + "abc_and_jkl": "AbcAndJkl", + "employee_id": "EmployeeID", + "field_x": "FieldX", + "http_and_smtp": "HTTPAndSMTP", + "http_server_handler_for_url_id": "HTTPServerHandlerForURLID", + "uuid": "UUID", + "http_url": "HTTPURL", + "sha256_hash": "Sha256Hash", + "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID", + } + for key, value := range maps { + if ns.SchemaName(key) != value { + t.Errorf("%v schema name should equal %v, but got %v", key, value, ns.SchemaName(key)) } } } func TestNamingStrategy(t *testing.T) { - var ns = NamingStrategy{ + ns := NamingStrategy{ TablePrefix: "public.", SingularTable: true, + NameReplacer: strings.NewReplacer("CID", "Cid"), } idxName := ns.IndexName("public.table", "name") @@ -63,4 +86,125 @@ func TestNamingStrategy(t *testing.T) { if tableName != "public.company" { t.Errorf("invalid table name generated, got %v", tableName) } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +type CustomReplacer struct { + f func(string) string +} + +func (r CustomReplacer) Replace(name string) string { + return r.f(name) +} + +func TestCustomReplacer(t *testing.T) { + ns := NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: false, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_replaced_name" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here. + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.replaced_userlanguage" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.replaced_company" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "replaced_name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +func TestCustomReplacerWithNoLowerCase(t *testing.T) { + ns := NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: true, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_REPLACED_NAME" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.REPLACED_USER_LANGUAGES" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.REPLACED_USERLANGUAGE" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.REPLACED_COMPANY" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "REPLACED_NAME_Cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { + ns := NamingStrategy{} + + formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") + if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { + t.Errorf("invalid formatted name generated, got %v", formattedName) + } +} + +func TestReplaceEmptyTableName(t *testing.T) { + ns := NamingStrategy{ + SingularTable: true, + NameReplacer: strings.NewReplacer("Model", ""), + } + tableName := ns.TableName("Model") + if tableName != "Model" { + t.Errorf("invalid table name generated, got %v", tableName) + } } diff --git a/schema/pool.go b/schema/pool.go new file mode 100644 index 000000000..f5c73153d --- /dev/null +++ b/schema/pool.go @@ -0,0 +1,62 @@ +package schema + +import ( + "reflect" + "sync" + "time" +) + +// sync pools +var ( + normalPool sync.Map + stringPool = &sync.Pool{ + New: func() interface{} { + var v string + ptrV := &v + return &ptrV + }, + } + intPool = &sync.Pool{ + New: func() interface{} { + var v int64 + ptrV := &v + return &ptrV + }, + } + uintPool = &sync.Pool{ + New: func() interface{} { + var v uint64 + ptrV := &v + return &ptrV + }, + } + floatPool = &sync.Pool{ + New: func() interface{} { + var v float64 + ptrV := &v + return &ptrV + }, + } + boolPool = &sync.Pool{ + New: func() interface{} { + var v bool + ptrV := &v + return &ptrV + }, + } + timePool = &sync.Pool{ + New: func() interface{} { + var v time.Time + ptrV := &v + return &ptrV + }, + } + poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { + v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ + New: func() interface{} { + return reflect.New(reflectType).Interface() + }, + }) + return v.(FieldNewValuePool) + } +) diff --git a/schema/relationship.go b/schema/relationship.go index 35af111f3..eae8ab0b1 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -1,9 +1,9 @@ package schema import ( + "context" "fmt" "reflect" - "regexp" "strings" "github.com/jinzhu/inflection" @@ -18,6 +18,7 @@ const ( HasMany RelationshipType = "has_many" // HasManyRel has many relationship BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship + has RelationshipType = "has" ) type Relationships struct { @@ -53,7 +54,7 @@ type Reference struct { OwnPrimaryKey bool } -func (schema *Schema) parseRelation(field *Field) { +func (schema *Schema) parseRelation(field *Field) *Relationship { var ( err error fieldValue = reflect.New(field.IndirectFieldType).Interface() @@ -67,32 +68,32 @@ func (schema *Schema) parseRelation(field *Field) { ) cacheStore := schema.cacheStore - if field.OwnerSchema != nil { - cacheStore = field.OwnerSchema.cacheStore - } - if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { schema.err = err - return + return nil } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { schema.buildPolymorphicRelation(relation, field, polymorphic) } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) + } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { + schema.guessRelation(relation, field, guessBelongs) } else { switch field.IndirectFieldType.Kind() { case reflect.Struct: - schema.guessRelation(relation, field, guessBelongs) + schema.guessRelation(relation, field, guessGuess) case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: - schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) } } - if relation.Type == "has" { - if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil { + if relation.Type == has { + // don't add relations to embedded schema, which might be shared + if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation } @@ -117,6 +118,8 @@ func (schema *Schema) parseRelation(field *Field) { schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) } } + + return relation } // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` @@ -142,11 +145,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -158,12 +161,14 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) } } // use same data type for foreign keys - relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + if copyableDataType(primaryKeyField.DataType) { + relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + } relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType if relation.Polymorphic.PolymorphicID.Size == 0 { relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size @@ -176,7 +181,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi }) } - relation.Type = "has" + relation.Type = has } func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { @@ -200,7 +205,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel if field := schema.LookUpField(foreignKey); field != nil { ownForeignFields = append(ownForeignFields, field) } else { - schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) return } } @@ -212,14 +217,14 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { refForeignFields = append(refForeignFields, field) } else { - schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) return } } } for idx, ownField := range ownForeignFields { - joinFieldName := schema.Name + ownField.Name + joinFieldName := strings.Title(schema.Name) + ownField.Name if len(joinForeignKeys) > idx { joinFieldName = strings.Title(joinForeignKeys[idx]) } @@ -235,7 +240,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } for idx, relField := range refForeignFields { - joinFieldName := relation.FieldSchema.Name + relField.Name + joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name if len(joinReferences) > idx { joinFieldName = strings.Title(joinReferences[idx]) } @@ -258,7 +263,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } joinTableFields = append(joinTableFields, reflect.StructField{ - Name: schema.Name + field.Name, + Name: strings.Title(schema.Name) + field.Name, Type: schema.ModelType, Tag: `gorm:"-"`, }) @@ -302,15 +307,17 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for _, f := range relation.JoinTable.Fields { if f.Creatable || f.Readable || f.Updatable { // use same data type for foreign keys - f.DataType = fieldsMap[f.Name].DataType + if copyableDataType(fieldsMap[f.Name].DataType) { + f.DataType = fieldsMap[f.Name].DataType + } f.GORMDataType = fieldsMap[f.Name].GORMDataType if f.Size == 0 { f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) - ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPriamryField { + if ownPrimaryField { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ @@ -331,7 +338,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel relation.References = append(relation.References, &Reference{ PrimaryKey: fieldsMap[f.Name], ForeignKey: f, - OwnPrimaryKey: ownPriamryField, + OwnPrimaryKey: ownPrimaryField, }) } } @@ -340,20 +347,32 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel type guessLevel int const ( - guessBelongs guessLevel = iota + guessGuess guessLevel = iota + guessBelongs guessEmbeddedBelongs guessHas guessEmbeddedHas ) -func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { +func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) { var ( primaryFields, foreignFields []*Field primarySchema, foreignSchema = schema, relation.FieldSchema + gl = cgl ) + if gl == guessGuess { + if field.Schema == relation.FieldSchema { + gl = guessBelongs + } else { + gl = guessHas + } + } + reguessOrErr := func() { - switch gl { + switch cgl { + case guessGuess: + schema.guessRelation(relation, field, guessBelongs) case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) case guessEmbeddedBelongs: @@ -362,7 +381,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) } } @@ -396,15 +415,35 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } } } else { - for _, primaryField := range primarySchema.PrimaryFields { + var primaryFields []*Field + + if len(relation.primaryKeys) > 0 { + for _, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + primaryFields = append(primaryFields, f) + } + } + } else { + primaryFields = primarySchema.PrimaryFields + } + + for _, primaryField := range primaryFields { lookUpName := primarySchema.Name + primaryField.Name if gl == guessBelongs { lookUpName = field.Name + primaryField.Name } - if f := foreignSchema.LookUpField(lookUpName); f != nil { - foreignFields = append(foreignFields, f) - primaryFields = append(primaryFields, primaryField) + lookUpNames := []string{lookUpName} + if len(primaryFields) == 1 { + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + } + + for _, name := range lookUpNames { + if f := foreignSchema.LookUpField(name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + break + } } } } @@ -427,7 +466,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } } } else if len(primaryFields) == 0 { - if len(foreignFields) == 1 { + if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) } else if len(primarySchema.PrimaryFields) == len(foreignFields) { primaryFields = append(primaryFields, primarySchema.PrimaryFields...) @@ -440,7 +479,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys - foreignField.DataType = primaryFields[idx].DataType + if copyableDataType(primaryFields[idx].DataType) { + foreignField.DataType = primaryFields[idx].DataType + } foreignField.GORMDataType = primaryFields[idx].GORMDataType if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size @@ -454,7 +495,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } if gl == guessHas || gl == guessEmbeddedHas { - relation.Type = "has" + relation.Type = has } else { relation.Type = BelongsTo } @@ -477,13 +518,35 @@ func (rel *Relationship) ParseConstraint() *Constraint { return nil } + if rel.Type == BelongsTo { + for _, r := range rel.FieldSchema.Relationships.Relations { + if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { + matched := true + for idx, ref := range r.References { + if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && + rel.References[idx].PrimaryValue == ref.PrimaryValue) { + matched = false + } + } + + if matched { + return nil + } + } + } + } + var ( name string idx = strings.Index(str, ",") settings = ParseTagSetting(str, ",") ) - if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) { + // optimize match english letters and midline + // The following code is basically called in for. + // In order to avoid the performance problems caused by repeated compilation of regular expressions, + // it only needs to be done once outside, so optimization is done here. + if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) { name = str[0:idx] } else { name = rel.Schema.namer.RelationshipFKName(*rel) @@ -497,7 +560,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { } for _, ref := range rel.References { - if ref.PrimaryKey != nil { + if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) { constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) constraint.References = append(constraint.References, ref.PrimaryKey) @@ -511,14 +574,10 @@ func (rel *Relationship) ParseConstraint() *Constraint { } } - if rel.JoinTable != nil { - return nil - } - return &constraint } -func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { +func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { table := rel.FieldSchema.Table foreignFields := []*Field{} relForeignKeys := []string{} @@ -558,9 +617,18 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] } } - _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) column, values := ToQueryValues(table, relForeignKeys, foreignValues) conds = append(conds, clause.IN{Column: column, Values: values}) return } + +func copyableDataType(str DataType) bool { + for _, s := range []string{"auto_increment", "primary key"} { + if strings.Contains(strings.ToLower(string(str)), s) { + return false + } + } + return true +} diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 7d7fd9c9b..e2cf11a91 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -55,6 +55,58 @@ func TestBelongsToOverrideReferences(t *testing.T) { }) } +func TestBelongsToWithOnlyReferences(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"References:Refer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + +func TestBelongsToWithOnlyReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"References:Refer"` + ProfileID int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, + }) +} + +func TestSelfReferentialBelongsTo(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatorID *int32 + Creator *User + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatorID", "User", "", false}}, + }) +} + func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { type User struct { ID int32 `gorm:"primaryKey"` @@ -106,6 +158,62 @@ func TestHasOneOverrideReferences(t *testing.T) { }) } +func TestHasOneOverrideReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + ProfileID uint `gorm:"column:profile_id"` + Profile *Profile `gorm:"foreignKey:ID;references:ProfileID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ProfileID", "User", "ID", "Profile", "", true}}, + }) +} + +func TestHasOneWithOnlyReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasOneWithOnlyReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + func TestHasManyOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model @@ -321,3 +429,150 @@ func TestMultipleMany2Many(t *testing.T) { }, ) } + +func TestSelfReferentialMany2Many(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy int32 + Creators []User `gorm:"foreignKey:CreatedBy"` + AnotherPro interface{} `gorm:"-"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creators", Type: schema.HasMany, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", true}}, + }) + + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema") + } + + relSchema := user.Relationships.Relations["Creators"].FieldSchema + if user != relSchema { + t.Fatalf("schema should be same, expects %p but got %p", user, relSchema) + } +} + +type CreatedByModel struct { + CreatedByID uint + CreatedBy *CreatedUser +} + +type CreatedUser struct { + gorm.Model + CreatedByModel +} + +func TestEmbeddedRelation(t *testing.T) { + checkStructRelation(t, &CreatedUser{}, Relation{ + Name: "CreatedBy", Type: schema.BelongsTo, Schema: "CreatedUser", FieldSchema: "CreatedUser", + References: []Reference{ + {"ID", "CreatedUser", "CreatedByID", "CreatedUser", "", false}, + }, + }) + + userSchema, err := schema.Parse(&CreatedUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema, got error %v", err) + } + + if len(userSchema.Relationships.Relations) != 1 { + t.Fatalf("expects 1 relations, but got %v", len(userSchema.Relationships.Relations)) + } + + if createdByRel, ok := userSchema.Relationships.Relations["CreatedBy"]; ok { + if createdByRel.FieldSchema != userSchema { + t.Fatalf("expects same field schema, but got new %p, old %p", createdByRel.FieldSchema, userSchema) + } + } else { + t.Fatalf("expects created by relations, but not found") + } +} + +func TestSameForeignKey(t *testing.T) { + type UserAux struct { + gorm.Model + Aux string + UUID string + } + + type User struct { + gorm.Model + Name string + UUID string + Aux *UserAux `gorm:"foreignkey:UUID;references:UUID"` + } + + checkStructRelation(t, &User{}, + Relation{ + Name: "Aux", Type: schema.HasOne, Schema: "User", FieldSchema: "UserAux", + References: []Reference{ + {"UUID", "User", "UUID", "UserAux", "", true}, + }, + }, + ) +} + +func TestBelongsToSameForeignKey(t *testing.T) { + type User struct { + gorm.Model + Name string + UUID string + } + + type UserAux struct { + gorm.Model + Aux string + UUID string + User User `gorm:"ForeignKey:UUID;references:UUID;belongsTo"` + } + + checkStructRelation(t, &UserAux{}, + Relation{ + Name: "User", Type: schema.BelongsTo, Schema: "UserAux", FieldSchema: "User", + References: []Reference{ + {"UUID", "User", "UUID", "UserAux", "", false}, + }, + }, + ) +} + +func TestHasOneWithSameForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + ProfileRefer int // not used in relationship + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ID;references:ProfileRefer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ProfileRefer", "User", "ID", "Profile", "", true}}, + }) +} + +func TestHasManySameForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + UserRefer uint + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} diff --git a/schema/schema.go b/schema/schema.go index 05db641f0..eca113e96 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -38,15 +38,16 @@ type Schema struct { BeforeSave, AfterSave bool AfterFind bool err error + initialized chan struct{} namer Namer cacheStore *sync.Map } func (schema Schema) String() string { if schema.ModelType.Name() == "" { - return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) + return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } - return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) + return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } func (schema Schema) MakeSlice() reflect.Value { @@ -70,13 +71,27 @@ type Tabler interface { TableName() string } -// get data type from dialector +// Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + return ParseWithSpecialTableName(dest, cacheStore, namer, "") +} + +// ParseWithSpecialTableName get data type from dialector with extra schema table +func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - modelType := reflect.ValueOf(dest).Type() + value := reflect.ValueOf(dest) + if value.Kind() == reflect.Ptr && value.IsNil() { + value = reflect.New(value.Type().Elem()) + } + modelType := reflect.Indirect(value).Type() + + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } @@ -85,11 +100,24 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema), nil + // Cache the Schema for performance, + // Use the modelType or modelType + schemaTable (if it present) as cache key. + var schemaCacheKey interface{} + if specialTableName != "" { + schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) + } else { + schemaCacheKey = modelType + } + + // Load exist schmema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err } modelValue := reflect.New(modelType) @@ -100,6 +128,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } + if specialTableName != "" && specialTableName != tableName { + tableName = specialTableName + } schema := &Schema{ Name: modelType.Name(), @@ -110,14 +141,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, + initialized: make(chan struct{}), + } + // When the schema initialization is completed, the channel will be closed + defer close(schema.initialized) + + // Load exist schmema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err } - - defer func() { - if schema.err != nil { - logger.Default.Error(context.Background(), schema.err.Error()) - cacheStore.Delete(modelType) - } - }() for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { @@ -157,7 +192,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if _, ok := schema.FieldsByName[field.Name]; !ok { + if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } @@ -187,7 +222,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } - for _, field := range schema.FieldsByDBName { + for _, field := range schema.Fields { if field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } @@ -214,39 +249,75 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) case "func(*gorm.DB) error": // TODO hack reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) } } } - if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err - } - } + // Cache the schema + if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } - fieldValue := reflect.New(field.IndirectFieldType) - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } + defer func() { + if schema.err != nil { + logger.Default.Error(context.Background(), schema.err.Error()) + cacheStore.Delete(modelType) + } + }() - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } else { + schema.FieldsByName[field.Name] = field } + } - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } + fieldValue := reflect.New(field.IndirectFieldType) + fieldInterface := fieldValue.Interface() + if fc, ok := fieldInterface.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + if fc, ok := fieldInterface.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldInterface.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldInterface.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) } } } return schema, schema.err } + +func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + } + + if v, ok := cacheStore.Load(modelType); ok { + return v.(*Schema), nil + } + + return Parse(dest, cacheStore, namer) +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index cc0306e0a..9abaecba3 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "fmt" "reflect" "strings" @@ -29,7 +30,7 @@ func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields } if !found { - t.Errorf("schema %v failed to found priamry key: %v", s, field) + t.Errorf("schema %v failed to found primary key: %v", s, field) } } }) @@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - fv, _ := s.FieldsByDBName[k].ValueOf(value) + fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value) tests.AssertEqual(t, v, fv) }) } diff --git a/schema/schema_test.go b/schema/schema_test.go index a426cd905..8a752fb7b 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -145,8 +145,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { } } -type CustomizeTable struct { -} +type CustomizeTable struct{} func (CustomizeTable) TableName() string { return "customize" @@ -165,7 +164,6 @@ func TestCustomizeTableName(t *testing.T) { func TestNestedModel(t *testing.T) { versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) - if err != nil { t.Fatalf("failed to parse nested user, got error %v", err) } @@ -204,7 +202,6 @@ func TestEmbeddedStruct(t *testing.T) { } cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) - if err != nil { t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) } @@ -273,7 +270,6 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { } cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) - if err != nil { t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) } diff --git a/schema/serializer.go b/schema/serializer.go new file mode 100644 index 000000000..68597538d --- /dev/null +++ b/schema/serializer.go @@ -0,0 +1,125 @@ +package schema + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" +) + +var serializerMap = sync.Map{} + +// RegisterSerializer register serializer +func RegisterSerializer(name string, serializer SerializerInterface) { + serializerMap.Store(strings.ToLower(name), serializer) +} + +// GetSerializer get serializer +func GetSerializer(name string) (serializer SerializerInterface, ok bool) { + v, ok := serializerMap.Load(strings.ToLower(name)) + if ok { + serializer, ok = v.(SerializerInterface) + } + return serializer, ok +} + +func init() { + RegisterSerializer("json", JSONSerializer{}) + RegisterSerializer("unixtime", UnixSecondSerializer{}) +} + +// Serializer field value serializer +type serializer struct { + Field *Field + Serializer SerializerInterface + SerializeValuer SerializerValuerInterface + Destination reflect.Value + Context context.Context + value interface{} + fieldValue interface{} +} + +// Scan implements sql.Scanner interface +func (s *serializer) Scan(value interface{}) error { + s.value = value + return nil +} + +// Value implements driver.Valuer interface +func (s serializer) Value() (driver.Value, error) { + return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) +} + +// SerializerInterface serializer interface +type SerializerInterface interface { + Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error + SerializerValuerInterface +} + +// SerializerValuerInterface serializer valuer interface +type SerializerValuerInterface interface { + Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) +} + +// JSONSerializer json serializer +type JSONSerializer struct { +} + +// Scan implements serializer interface +func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) + } + + err = json.Unmarshal(bytes, fieldValue.Interface()) + } + + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + result, err := json.Marshal(fieldValue) + return string(result), err +} + +// UnixSecondSerializer json serializer +type UnixSecondSerializer struct { +} + +// Scan implements serializer interface +func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + t := sql.NullTime{} + if err = t.Scan(dbValue); err == nil { + err = field.Set(ctx, dst, t.Time) + } + + return +} + +// Value implements serializer interface +func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { + switch v := fieldValue.(type) { + case int64, int, uint, uint64, int32, uint32, int16, uint16: + result = time.Unix(reflect.ValueOf(v).Int(), 0) + default: + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + } + return +} diff --git a/schema/utils.go b/schema/utils.go index 6e5fd5284..2720c5304 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -1,6 +1,7 @@ package schema import ( + "context" "reflect" "regexp" "strings" @@ -59,22 +60,22 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct } // GetRelationsValues get relations's values from a reflect value -func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { +func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) appendToResults := func(value reflect.Value) { - if _, isZero := rel.Field.ValueOf(value); !isZero { - result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value)) switch result.Kind() { case reflect.Struct: reflectResults = reflect.Append(reflectResults, result.Addr()) case reflect.Slice, reflect.Array: for i := 0; i < result.Len(); i++ { - if result.Index(i).Kind() == reflect.Ptr { - reflectResults = reflect.Append(reflectResults, result.Index(i)) + if elem := result.Index(i); elem.Kind() == reflect.Ptr { + reflectResults = reflect.Append(reflectResults, elem) } else { - reflectResults = reflect.Append(reflectResults, result.Index(i).Addr()) + reflectResults = reflect.Append(reflectResults, elem.Addr()) } } } @@ -97,7 +98,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle } // GetIdentityFieldValuesMap get identity map from fields -func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { var ( results = [][]interface{}{} dataResults = map[string][]reflect.Value{} @@ -110,7 +111,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { - results[0][idx], zero = field.ValueOf(reflectValue) + results[0][idx], zero = field.ValueOf(ctx, reflectValue) notZero = notZero || !zero } @@ -135,14 +136,14 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map fieldValues := make([]interface{}, len(fields)) notZero = false for idx, field := range fields { - fieldValues[idx], zero = field.ValueOf(elem) + fieldValues[idx], zero = field.ValueOf(ctx, elem) notZero = notZero || !zero } if notZero { dataKey := utils.ToStringKey(fieldValues...) if _, ok := dataResults[dataKey]; !ok { - results = append(results, fieldValues[:]) + results = append(results, fieldValues) dataResults[dataKey] = []reflect.Value{elem} } else { dataResults[dataKey] = append(dataResults[dataKey], elem) @@ -155,12 +156,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map } // GetIdentityFieldValuesMapFromValues get identity map from fields -func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { resultsMap := map[string][]reflect.Value{} results := [][]interface{}{} for _, v := range values { - rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) for k, v := range rm { resultsMap[k] = append(resultsMap[k], v...) } @@ -178,17 +179,18 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa } return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues - } else { - columns := make([]clause.Column, len(foreignKeys)) - for idx, key := range foreignKeys { - columns[idx] = clause.Column{Table: table, Name: key} - } + } - for idx, r := range foreignValues { - queryValues[idx] = r - } - return columns, queryValues + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} } + + for idx, r := range foreignValues { + queryValues[idx] = r + } + + return columns, queryValues } type embeddedNamer struct { diff --git a/soft_delete.go b/soft_delete.go index cb56035d6..ba6d2118d 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -63,9 +63,9 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { } func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { - if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped { if c, ok := stmt.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { + if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 { for _, expr := range where.Exprs { if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { where.Exprs = []clause.Expression{clause.And(where.Exprs...)} @@ -84,6 +84,32 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { } } +func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteUpdateClause{Field: f}} +} + +type SoftDeleteUpdateClause struct { + Field *schema.Field +} + +func (sd SoftDeleteUpdateClause) Name() string { + return "" +} + +func (sd SoftDeleteUpdateClause) Build(clause.Builder) { +} + +func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { + if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { + SoftDeleteQueryClause(sd).ModifyStatement(stmt) + } + } +} + func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { return []clause.Interface{SoftDeleteDeleteClause{Field: f}} } @@ -103,11 +129,13 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { } func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { - if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { + curTime := stmt.DB.NowFunc() + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + stmt.SetColumn(sd.Field.DBName, curTime, true) if stmt.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -115,7 +143,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -127,10 +155,10 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { stmt.DB.AddError(ErrMissingWhereClause) } else { - SoftDeleteQueryClause{Field: sd.Field}.ModifyStatement(stmt) + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } stmt.AddClauseIfNotExists(clause.Update{}) - stmt.Build("UPDATE", "SET", "WHERE") + stmt.Build(stmt.DB.Callback().Update().Clauses...) } } diff --git a/statement.go b/statement.go index 27edf9da8..cb4717766 100644 --- a/statement.go +++ b/statement.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "regexp" "sort" "strconv" "strings" @@ -27,6 +28,7 @@ type Statement struct { Dest interface{} ReflectValue reflect.Value Clauses map[string]clause.Clause + BuildClauses []string Distinct bool Selects []string // selected columns Omits []string // omit columns @@ -43,11 +45,13 @@ type Statement struct { CurDestIndex int attrs []interface{} assigns []interface{} + scopes []func(*DB) *DB } type join struct { Name string Conds []interface{} + On *clause.Where } // StatementModifier statement modifier interface @@ -55,12 +59,12 @@ type StatementModifier interface { ModifyStatement(*Statement) } -// Write write string +// WriteString write string func (stmt *Statement) WriteString(str string) (int, error) { return stmt.SQL.WriteString(str) } -// Write write string +// WriteByte write byte func (stmt *Statement) WriteByte(c byte) error { return stmt.SQL.WriteByte(c) } @@ -72,30 +76,36 @@ func (stmt *Statement) WriteQuoted(value interface{}) { // QuoteTo write quoted value to writer func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { + write := func(raw bool, str string) { + if raw { + writer.WriteString(str) + } else { + stmt.DB.Dialector.QuoteTo(writer, str) + } + } + switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) } else { - stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + write(v.Raw, stmt.Table) } - } else if v.Raw { - writer.WriteString(v.Name) } else { - stmt.DB.Dialector.QuoteTo(writer, v.Name) + write(v.Raw, v.Name) } if v.Alias != "" { writer.WriteByte(' ') - stmt.DB.Dialector.QuoteTo(writer, v.Alias) + write(v.Raw, v.Alias) } case clause.Column: if v.Table != "" { if v.Table == clause.CurrentTable { - stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + write(v.Raw, stmt.Table) } else { - stmt.DB.Dialector.QuoteTo(writer, v.Table) + write(v.Raw, v.Table) } writer.WriteByte('.') } @@ -104,19 +114,17 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if stmt.Schema == nil { stmt.DB.AddError(ErrModelValueRequired) } else if stmt.Schema.PrioritizedPrimaryField != nil { - stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) + write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { - stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0]) + write(v.Raw, stmt.Schema.DBNames[0]) } - } else if v.Raw { - writer.WriteString(v.Name) } else { - stmt.DB.Dialector.QuoteTo(writer, v.Name) + write(v.Raw, v.Name) } if v.Alias != "" { writer.WriteString(" AS ") - stmt.DB.Dialector.QuoteTo(writer, v.Alias) + write(v.Raw, v.Alias) } case []clause.Column: writer.WriteByte('(') @@ -127,6 +135,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { stmt.QuoteTo(writer, d) } writer.WriteByte(')') + case clause.Expr: + v.Build(stmt) case string: stmt.DB.Dialector.QuoteTo(writer, v) case []string: @@ -150,7 +160,7 @@ func (stmt *Statement) Quote(field interface{}) string { return builder.String() } -// Write write string +// AddVar add var func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { for idx, v := range vars { if idx > 0 { @@ -163,18 +173,14 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case clause.Column, clause.Table: stmt.QuoteTo(writer, v) case Valuer: - stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) - case clause.Expr: - var varStr strings.Builder - var sql = v.SQL - for _, arg := range v.Vars { - stmt.Vars = append(stmt.Vars, arg) - stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg) - sql = strings.Replace(sql, "?", varStr.String(), 1) - varStr.Reset() - } - - writer.WriteString(sql) + reflectValue := reflect.ValueOf(v) + if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() { + stmt.AddVar(writer, nil) + } else { + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) + } + case clause.Expression: + v.Build(stmt) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) @@ -191,8 +197,32 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } case *DB: subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() - subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) - subdb.callbacks.Query().Execute(subdb) + if v.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = v.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) + subdb.callbacks.Query().Execute(subdb) + } + writer.WriteString(subdb.Statement.SQL.String()) stmt.Vars = subdb.Statement.Vars default: @@ -200,6 +230,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case reflect.Slice, reflect.Array: if rv.Len() == 0 { writer.WriteString("(NULL)") + } else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) { + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) } else { writer.WriteByte('(') for i := 0; i < rv.Len(); i++ { @@ -245,13 +278,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if _, err := strconv.Atoi(s); err != nil { if s == "" && len(args) == 0 { return nil - } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { + } + + if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { // looks like a where condition return []clause.Expression{clause.Expr{SQL: s, Vars: args}} - } else if len(args) > 0 && strings.Contains(s, "@") { + } + + if len(args) > 0 && strings.Contains(s, "@") { // looks like a named query return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} - } else if len(args) == 1 { + } + + if strings.Contains(strings.TrimSpace(s), " ") { + // looks like a where condition + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } + + if len(args) == 1 { return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} } } @@ -259,7 +303,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds := make([]clause.Expression, 0, 4) args = append([]interface{}{query}, args...) - for _, arg := range args { + for idx, arg := range args { if valuer, ok := arg.(driver.Valuer); ok { arg, _ = valuer.Value() } @@ -270,6 +314,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case *DB: if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { + if len(where.Exprs) == 1 { + if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { + where.Exprs[0] = clause.AndConditions(orConds) + } + } conds = append(conds, clause.And(where.Exprs...)) } else if cs.Expression != nil { conds = append(conds, cs.Expression) @@ -280,7 +329,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds = append(conds, clause.Eq{Column: i, Value: j}) } case map[string]string: - var keys = make([]string, 0, len(v)) + keys := make([]string, 0, len(v)) for i := range v { keys = append(keys, i) } @@ -290,7 +339,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } case map[string]interface{}: - var keys = make([]string, 0, len(v)) + keys := make([]string, 0, len(v)) for i := range v { keys = append(keys, i) } @@ -305,8 +354,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } else if _, ok := v[key].(Valuer); ok { conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } else { - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { + // optimize reflect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { values[i] = reflectValue.Index(i).Interface() } @@ -318,12 +369,27 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true + } + } + } + restricted := len(selectedColumns) != 0 + switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { - if field.Readable { - if v, isZero := field.ValueOf(reflectValue); !isZero { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -335,8 +401,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { - if field.Readable { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -347,12 +414,20 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } } + + if restricted { + break + } + } else if !reflectValue.IsValid() { + stmt.AddError(ErrInvalidData) } else if len(conds) == 0 { if len(args) == 1 { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { + // optimize reflect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { values[i] = reflectValue.Index(i).Interface() } @@ -392,7 +467,11 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + return stmt.ParseWithSpecialTableName(value, "") +} + +func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { + if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] @@ -424,6 +503,12 @@ func (stmt *Statement) clone() *Statement { SkipHooks: stmt.SkipHooks, } + if stmt.SQL.Len() > 0 { + newStmt.SQL.WriteString(stmt.SQL.String()) + newStmt.Vars = make([]interface{}, 0, len(stmt.Vars)) + newStmt.Vars = append(newStmt.Vars, stmt.Vars...) + } + for k, c := range stmt.Clauses { newStmt.Clauses[k] = c } @@ -437,6 +522,11 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.Joins, stmt.Joins) } + if len(stmt.scopes) > 0 { + newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes)) + copy(newStmt.scopes, stmt.scopes) + } + stmt.Settings.Range(func(k, v interface{}) bool { newStmt.Settings.Store(k, v) return true @@ -445,11 +535,16 @@ func (stmt *Statement) clone() *Statement { return newStmt } -// Helpers // SetColumn set column's value -func (stmt *Statement) SetColumn(name string, value interface{}) { +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value + } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { + for _, m := range v { + m[name] = value + } } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { destValue := reflect.ValueOf(stmt.Dest) @@ -467,7 +562,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { switch destValue.Kind() { case reflect.Struct: - field.Set(destValue, value) + field.Set(stmt.Context, destValue, value) default: stmt.AddError(ErrInvalidData) } @@ -475,9 +570,20 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + if len(fromCallbacks) > 0 { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + } + } else { + field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value) + } case reflect.Struct: - field.Set(stmt.ReflectValue, value) + if !stmt.ReflectValue.CanAddr() { + stmt.AddError(ErrInvalidValue) + return + } + + field.Set(stmt.Context, stmt.ReflectValue, value) } } else { stmt.AddError(ErrInvalidField) @@ -497,7 +603,7 @@ func (stmt *Statement) Changed(fields ...string) bool { selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { - fieldValue, _ := field.ValueOf(modelValue) + fieldValue, _ := field.ValueOf(stmt.Context, modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := stmt.Dest.(map[string]interface{}); ok { if fv, ok := v[field.Name]; ok { @@ -511,7 +617,7 @@ func (stmt *Statement) Changed(fields ...string) bool { destValue = destValue.Elem() } - changedValue, zero := field.ValueOf(destValue) + changedValue, zero := field.ValueOf(stmt.Context, destValue) return !zero && !utils.AssertEqual(changedValue, fieldValue) } } @@ -537,6 +643,8 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } +var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`) + // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} @@ -544,17 +652,21 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( // select columns for _, column := range stmt.Selects { - if column == "*" { + if stmt.Schema == nil { + results[column] = true + } else if column == "*" { notRestricted = true for _, dbName := range stmt.Schema.DBNames { results[dbName] = true } - } else if column == clause.Associations && stmt.Schema != nil { + } else if column == clause.Associations { for _, rel := range stmt.Schema.Relationships.Relations { results[rel.Name] = true } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true + } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { + results[matches[1]] = true } else { results[column] = true } @@ -562,21 +674,27 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( // omit columns for _, omit := range stmt.Omits { - if omit == clause.Associations { - if stmt.Schema != nil { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = false - } + if stmt.Schema == nil { + results[omit] = false + } else if omit == "*" { + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = false + } + } else if omit == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false } } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false + } else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 { + results[matches[1]] = false } else { results[omit] = false } } if stmt.Schema != nil { - for _, field := range stmt.Schema.Fields { + for _, field := range stmt.Schema.FieldsByName { name := field.DBName if name == "" { name = field.Name diff --git a/statement_test.go b/statement_test.go index 03ad81dc6..3f099d611 100644 --- a/statement_test.go +++ b/statement_test.go @@ -34,3 +34,16 @@ func TestWhereCloneCorruption(t *testing.T) { }) } } + +func TestNameMatcher(t *testing.T) { + for k, v := range map[string]string{ + "table.name": "name", + "`table`.`name`": "name", + "'table'.'name'": "name", + "'table'.name": "name", + } { + if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + } + } +} diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 3e4de7260..f74799cea 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -7,7 +7,7 @@ import ( ) func TestBelongsToAssociation(t *testing.T) { - var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) + user := *GetUser("belongs-to", Config{Company: true, Manager: true}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -31,8 +31,8 @@ func TestBelongsToAssociation(t *testing.T) { AssertAssociationCount(t, user, "Manager", 1, "") // Append - var company = Company{Name: "company-belongs-to-append"} - var manager = GetUser("manager-belongs-to-append", Config{}) + company := Company{Name: "company-belongs-to-append"} + manager := GetUser("manager-belongs-to-append", Config{}) if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { t.Fatalf("Error happened when append Company, got %v", err) @@ -60,8 +60,8 @@ func TestBelongsToAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") // Replace - var company2 = Company{Name: "company-belongs-to-replace"} - var manager2 = GetUser("manager-belongs-to-replace", Config{}) + company2 := Company{Name: "company-belongs-to-replace"} + manager2 := GetUser("manager-belongs-to-replace", Config{}) if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { t.Fatalf("Error happened when replace Company, got %v", err) @@ -132,10 +132,17 @@ func TestBelongsToAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Company", 0, "after clear") AssertAssociationCount(t, user2, "Manager", 0, "after clear") + + // unexist company id + unexistCompanyID := company.ID + 9999999 + user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} + if err := DB.Create(&user).Error; err == nil { + t.Errorf("should have gotten foreign key violation error") + } } func TestBelongsToAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 173e92319..002ae6364 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -7,7 +7,7 @@ import ( ) func TestHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Pets: 2}) + user := *GetUser("hasmany", Config{Pets: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -42,7 +42,7 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 2, "") // Append - var pet = Pet{Name: "pet-has-many-append"} + pet := Pet{Name: "pet-has-many-append"} if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -57,14 +57,14 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") - var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + pets2 := []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) } for _, pet := range pets2 { - var pet = pet + pet := pet if pet.ID == 0 { t.Fatalf("Pet's ID should be created") } @@ -77,7 +77,7 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") // Replace - var pet2 = Pet{Name: "pet-has-many-replace"} + pet2 := Pet{Name: "pet-has-many-replace"} if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) @@ -119,7 +119,7 @@ func TestHasManyAssociation(t *testing.T) { } func TestSingleTableHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Team: 2}) + user := *GetUser("hasmany", Config{Team: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -137,7 +137,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Team", 2, "") // Append - var team = *GetUser("team", Config{}) + team := *GetUser("team", Config{}) if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -152,14 +152,14 @@ func TestSingleTableHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Team", 3, "AfterAppend") - var teams = []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} + teams := []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { t.Fatalf("Error happened when append team, got %v", err) } for _, team := range teams { - var team = team + team := team if team.ID == 0 { t.Fatalf("Team's ID should be created") } @@ -172,7 +172,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") // Replace - var team2 = *GetUser("team-replace", Config{}) + team2 := *GetUser("team-replace", Config{}) if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { t.Fatalf("Error happened when append team, got %v", err) @@ -214,7 +214,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) { } func TestHasManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-hasmany-1", Config{Pets: 2}), *GetUser("slice-hasmany-2", Config{Pets: 0}), *GetUser("slice-hasmany-3", Config{Pets: 4}), @@ -268,7 +268,7 @@ func TestHasManyAssociationForSlice(t *testing.T) { } func TestSingleTableHasManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-hasmany-1", Config{Team: 2}), *GetUser("slice-hasmany-2", Config{Team: 0}), *GetUser("slice-hasmany-3", Config{Team: 4}), @@ -324,7 +324,7 @@ func TestSingleTableHasManyAssociationForSlice(t *testing.T) { } func TestPolymorphicHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Toys: 2}) + user := *GetUser("hasmany", Config{Toys: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -342,7 +342,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Toys", 2, "") // Append - var toy = Toy{Name: "toy-has-many-append"} + toy := Toy{Name: "toy-has-many-append"} if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -357,14 +357,14 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") - var toys = []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} + toys := []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { t.Fatalf("Error happened when append toy, got %v", err) } for _, toy := range toys { - var toy = toy + toy := toy if toy.ID == 0 { t.Fatalf("Toy's ID should be created") } @@ -377,7 +377,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice") // Replace - var toy2 = Toy{Name: "toy-has-many-replace"} + toy2 := Toy{Name: "toy-has-many-replace"} if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { t.Fatalf("Error happened when append toy, got %v", err) @@ -419,7 +419,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { } func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-hasmany-1", Config{Toys: 2}), *GetUser("slice-hasmany-2", Config{Toys: 0}), *GetUser("slice-hasmany-3", Config{Toys: 4}), diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index a4fc8c4fc..a2c075090 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -7,7 +7,7 @@ import ( ) func TestHasOneAssociation(t *testing.T) { - var user = *GetUser("hasone", Config{Account: true}) + user := *GetUser("hasone", Config{Account: true}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -25,7 +25,7 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user, "Account", 1, "") // Append - var account = Account{Number: "account-has-one-append"} + account := Account{Number: "account-has-one-append"} if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -41,7 +41,7 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user, "Account", 1, "AfterAppend") // Replace - var account2 = Account{Number: "account-has-one-replace"} + account2 := Account{Number: "account-has-one-replace"} if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { t.Fatalf("Error happened when append Account, got %v", err) @@ -84,7 +84,7 @@ func TestHasOneAssociation(t *testing.T) { } func TestHasOneAssociationWithSelect(t *testing.T) { - var user = *GetUser("hasone", Config{Account: true}) + user := *GetUser("hasone", Config{Account: true}) DB.Omit("Account.Number").Create(&user) @@ -98,7 +98,7 @@ func TestHasOneAssociationWithSelect(t *testing.T) { } func TestHasOneAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-hasone-1", Config{Account: true}), *GetUser("slice-hasone-2", Config{Account: false}), *GetUser("slice-hasone-3", Config{Account: true}), @@ -139,7 +139,7 @@ func TestHasOneAssociationForSlice(t *testing.T) { } func TestPolymorphicHasOneAssociation(t *testing.T) { - var pet = Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} + pet := Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} if err := DB.Create(&pet).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -157,7 +157,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { AssertAssociationCount(t, pet, "Toy", 1, "") // Append - var toy = Toy{Name: "toy-has-one-append"} + toy := Toy{Name: "toy-has-one-append"} if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { t.Fatalf("Error happened when append toy, got %v", err) @@ -173,7 +173,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") // Replace - var toy2 = Toy{Name: "toy-has-one-replace"} + toy2 := Toy{Name: "toy-has-one-replace"} if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { t.Fatalf("Error happened when append Toy, got %v", err) @@ -216,7 +216,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { } func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { - var pets = []Pet{ + pets := []Pet{ {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, {Name: "hasone-2", Toy: Toy{}}, {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 1ddd3b858..28b441bd8 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -7,7 +7,7 @@ import ( ) func TestMany2ManyAssociation(t *testing.T) { - var user = *GetUser("many2many", Config{Languages: 2}) + user := *GetUser("many2many", Config{Languages: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -26,7 +26,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 2, "") // Append - var language = Language{Code: "language-many2many-append", Name: "language-many2many-append"} + language := Language{Code: "language-many2many-append", Name: "language-many2many-append"} DB.Create(&language) if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { @@ -38,7 +38,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") - var languages = []Language{ + languages := []Language{ {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, } @@ -55,7 +55,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") // Replace - var language2 = Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} + language2 := Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} DB.Create(&language2) if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { @@ -94,7 +94,7 @@ func TestMany2ManyAssociation(t *testing.T) { } func TestMany2ManyOmitAssociations(t *testing.T) { - var user = *GetUser("many2many_omit_associations", Config{Languages: 2}) + user := *GetUser("many2many_omit_associations", Config{Languages: 2}) if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { t.Fatalf("should raise error when create users without languages reference") @@ -113,10 +113,15 @@ func TestMany2ManyOmitAssociations(t *testing.T) { if DB.Model(&user).Association("Languages").Find(&languages); len(languages) != 2 { t.Errorf("languages count should be %v, but got %v", 2, len(languages)) } + + newLang := Language{Code: "omitmany2many", Name: "omitmany2many"} + if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil { + t.Errorf("should failed to insert languages due to constraint failed, error: %v", err) + } } func TestMany2ManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-many2many-1", Config{Languages: 2}), *GetUser("slice-many2many-2", Config{Languages: 0}), *GetUser("slice-many2many-3", Config{Languages: 4}), @@ -134,11 +139,11 @@ func TestMany2ManyAssociationForSlice(t *testing.T) { } // Append - var languages1 = []Language{ + languages1 := []Language{ {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, } - var languages2 = []Language{} - var languages3 = []Language{ + languages2 := []Language{} + languages3 := []Language{ {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, } @@ -186,7 +191,7 @@ func TestMany2ManyAssociationForSlice(t *testing.T) { } func TestSingleTableMany2ManyAssociation(t *testing.T) { - var user = *GetUser("many2many", Config{Friends: 2}) + user := *GetUser("many2many", Config{Friends: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -205,7 +210,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Friends", 2, "") // Append - var friend = *GetUser("friend", Config{}) + friend := *GetUser("friend", Config{}) if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -216,7 +221,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") - var friends = []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} + friends := []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { t.Fatalf("Error happened when append friend, got %v", err) @@ -229,7 +234,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") // Replace - var friend2 = *GetUser("friend-replace-2", Config{}) + friend2 := *GetUser("friend-replace-2", Config{}) if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { t.Fatalf("Error happened when append friend, got %v", err) @@ -267,7 +272,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) { } func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-many2many-1", Config{Team: 2}), *GetUser("slice-many2many-2", Config{Team: 0}), *GetUser("slice-many2many-3", Config{Team: 4}), @@ -285,17 +290,17 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { } // Append - var teams1 = []User{*GetUser("friend-append-1", Config{})} - var teams2 = []User{} - var teams3 = []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} + teams1 := []User{*GetUser("friend-append-1", Config{})} + teams2 := []User{} + teams3 := []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) AssertAssociationCount(t, users, "Team", 9, "After Append") - var teams2_1 = []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} - var teams2_2 = []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} - var teams2_3 = GetUser("friend-replace-3-1", Config{}) + teams2_1 := []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} + teams2_2 := []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} + teams2_3 := GetUser("friend-replace-3-1", Config{}) // Replace DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) diff --git a/tests/associations_test.go b/tests/associations_test.go index f470338fd..5ce98c7dc 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -27,7 +27,7 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result } func TestInvalidAssociation(t *testing.T) { - var user = *GetUser("invalid", Config{Company: true, Manager: true}) + user := *GetUser("invalid", Config{Company: true, Manager: true}) if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { t.Fatalf("should return errors for invalid association, but got nil") } @@ -64,7 +64,7 @@ func TestAssociationNotNullClear(t *testing.T) { } if err := DB.Model(member).Association("Profiles").Clear(); err == nil { - t.Fatalf("No error occured during clearind not null association") + t.Fatalf("No error occurred during clearind not null association") } } @@ -176,3 +176,47 @@ func TestForeignKeyConstraintsBelongsTo(t *testing.T) { t.Fatalf("Should not find deleted profile") } } + +func TestFullSaveAssociations(t *testing.T) { + coupon := &Coupon{ + AppliesToProduct: []*CouponProduct{ + {ProductId: "full-save-association-product1"}, + }, + AmountOff: 10, + PercentOff: 0.0, + } + + err := DB. + Session(&gorm.Session{FullSaveAssociations: true}). + Create(coupon).Error + if err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if DB.First(&Coupon{}, "id = ?", coupon.ID).Error != nil { + t.Errorf("Failed to query saved coupon") + } + + if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", coupon.ID, "full-save-association-product1").Error != nil { + t.Errorf("Failed to query saved association") + } + + orders := []Order{{Num: "order1", Coupon: coupon}, {Num: "order2", Coupon: coupon}} + if err := DB.Create(&orders).Error; err != nil { + t.Errorf("failed to create orders, got %v", err) + } + + coupon2 := Coupon{ + AppliesToProduct: []*CouponProduct{{Desc: "coupon-description"}}, + } + + DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&coupon2) + var result Coupon + if err := DB.Preload("AppliesToProduct").First(&result, "id = ?", coupon2.ID).Error; err != nil { + t.Errorf("Failed to create coupon w/o name, got error: %v", err) + } + + if len(result.AppliesToProduct) != 1 { + t.Errorf("Failed to preload AppliesToProduct") + } +} diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go index c6ce93a26..d897a6341 100644 --- a/tests/benchmark_test.go +++ b/tests/benchmark_test.go @@ -7,7 +7,7 @@ import ( ) func BenchmarkCreate(b *testing.B) { - var user = *GetUser("bench", Config{}) + user := *GetUser("bench", Config{}) for x := 0; x < b.N; x++ { user.ID = 0 @@ -16,7 +16,7 @@ func BenchmarkCreate(b *testing.B) { } func BenchmarkFind(b *testing.B) { - var user = *GetUser("find", Config{}) + user := *GetUser("find", Config{}) DB.Create(&user) for x := 0; x < b.N; x++ { @@ -25,7 +25,7 @@ func BenchmarkFind(b *testing.B) { } func BenchmarkUpdate(b *testing.B) { - var user = *GetUser("find", Config{}) + user := *GetUser("find", Config{}) DB.Create(&user) for x := 0; x < b.N; x++ { @@ -34,7 +34,7 @@ func BenchmarkUpdate(b *testing.B) { } func BenchmarkDelete(b *testing.B) { - var user = *GetUser("find", Config{}) + user := *GetUser("find", Config{}) for x := 0; x < b.N; x++ { user.ID = 0 diff --git a/tests/connection_test.go b/tests/connection_test.go new file mode 100644 index 000000000..7bc23009d --- /dev/null +++ b/tests/connection_test.go @@ -0,0 +1,46 @@ +package tests_test + +import ( + "fmt" + "testing" + + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func TestWithSingleConnection(t *testing.T) { + expectedName := "test" + var actualName string + + setSQL, getSQL := getSetSQL(DB.Dialector.Name()) + if len(setSQL) == 0 || len(getSQL) == 0 { + return + } + + err := DB.Connection(func(tx *gorm.DB) error { + if err := tx.Exec(setSQL, expectedName).Error; err != nil { + return err + } + + if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil { + return err + } + return nil + }) + if err != nil { + t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) + } + + if actualName != expectedName { + t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName) + } +} + +func getSetSQL(driverName string) (string, string) { + switch driverName { + case mysql.Dialector{}.Name(): + return "SET @testName := ?", "SELECT @testName" + default: + return "", "" + } +} diff --git a/tests/count_test.go b/tests/count_test.go index 55fb71e20..b63a55fcc 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -3,6 +3,8 @@ package tests_test import ( "fmt" "regexp" + "sort" + "strings" "testing" "gorm.io/gorm" @@ -77,4 +79,79 @@ func TestCount(t *testing.T) { if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count) } + + var count6 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other", + ).Count(&count6).Find(&users).Error; err != nil || count6 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects := []User{{Name: "main"}, {Name: "other"}, {Name: "other"}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count7 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other", + ).Count(&count7).Find(&users).Error; err != nil || count7 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects = []User{{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count8 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name", + ).Count(&count8).Find(&users).Error; err != nil || count8 != 3 { + t.Fatalf("Count should work, but got err %v", err) + } + + expects = []User{{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count9 int64 + if err := DB.Scopes(func(tx *gorm.DB) *gorm.DB { + return tx.Table("users") + }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { + t.Fatalf("Count should work, but got err %v", err) + } + + var count10 int64 + if err := DB.Model(&User{}).Select("*").Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count10).Error; err != nil || count10 != 3 { + t.Fatalf("Count should be 3, but got count: %v err %v", count10, err) + } + + var count11 int64 + sameUsers := make([]*User, 0) + for i := 0; i < 3; i++ { + sameUsers = append(sameUsers, GetUser("count-4", Config{})) + } + DB.Create(sameUsers) + + if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { + t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) + } + + var count12 int64 + if err := DB.Table("users"). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count12).Error; err != gorm.ErrPreloadNotAllowed { + t.Errorf("should returns preload not allowed error, but got %v", err) + } + } diff --git a/tests/create_test.go b/tests/create_test.go index 8d005d0b5..2b23d4409 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "regexp" "testing" "time" @@ -12,7 +13,7 @@ import ( ) func TestCreate(t *testing.T) { - var user = *GetUser("create", Config{}) + user := *GetUser("create", Config{}) if results := DB.Create(&user); results.Error != nil { t.Fatalf("errors happened when create: %v", results.Error) @@ -50,7 +51,39 @@ func TestCreateInBatches(t *testing.T) { *GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), } - DB.CreateInBatches(&users, 2) + result := DB.CreateInBatches(&users, 2) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("failed to fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + +func TestCreateInBatchesWithDefaultSize(t *testing.T) { + users := []User{ + *GetUser("create_with_default_batch_size_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_with_default_batch_sizs_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_with_default_batch_sizs_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_with_default_batch_sizs_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_with_default_batch_sizs_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_with_default_batch_sizs_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + result := DB.Session(&gorm.Session{CreateBatchSize: 2}).Create(&users) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } for _, user := range users { if user.ID == 0 { @@ -90,7 +123,7 @@ func TestCreateFromMap(t *testing.T) { {"name": "create_from_map_3", "Age": 20}, } - if err := DB.Model(&User{}).Create(datas).Error; err != nil { + if err := DB.Model(&User{}).Create(&datas).Error; err != nil { t.Fatalf("failed to create data from slice of map, got error: %v", err) } @@ -106,7 +139,7 @@ func TestCreateFromMap(t *testing.T) { } func TestCreateWithAssociations(t *testing.T) { - var user = *GetUser("create_with_associations", Config{ + user := *GetUser("create_with_associations", Config{ Account: true, Pets: 2, Toys: 3, @@ -190,7 +223,7 @@ func TestBulkCreatePtrDataWithAssociations(t *testing.T) { func TestPolymorphicHasOne(t *testing.T) { t.Run("Struct", func(t *testing.T) { - var pet = Pet{ + pet := Pet{ Name: "PolymorphicHasOne", Toy: Toy{Name: "Toy-PolymorphicHasOne"}, } @@ -207,7 +240,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) t.Run("Slice", func(t *testing.T) { - var pets = []Pet{{ + pets := []Pet{{ Name: "PolymorphicHasOne-Slice-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, }, { @@ -236,7 +269,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) t.Run("SliceOfPtr", func(t *testing.T) { - var pets = []*Pet{{ + pets := []*Pet{{ Name: "PolymorphicHasOne-Slice-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, }, { @@ -257,7 +290,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) t.Run("Array", func(t *testing.T) { - var pets = [...]Pet{{ + pets := [...]Pet{{ Name: "PolymorphicHasOne-Array-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, }, { @@ -278,7 +311,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) t.Run("ArrayPtr", func(t *testing.T) { - var pets = [...]*Pet{{ + pets := [...]*Pet{{ Name: "PolymorphicHasOne-Array-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, }, { @@ -315,12 +348,12 @@ func TestCreateEmptyStruct(t *testing.T) { } func TestCreateEmptySlice(t *testing.T) { - var data = []User{} + data := []User{} if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice { t.Errorf("no data should be created, got %v", err) } - var sliceMap = []map[string]interface{}{} + sliceMap := []map[string]interface{}{} if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice { t.Errorf("no data should be created, got %v", err) } @@ -461,3 +494,35 @@ func TestFirstOrCreateWithPrimaryKey(t *testing.T) { t.Errorf("invalid primary key after creating, got %v, %v", companies[0].ID, companies[1].ID) } } + +func TestCreateFromSubQuery(t *testing.T) { + user := User{Name: "jinzhu"} + + DB.Create(&user) + + subQuery := DB.Table("users").Where("name=?", user.Name).Select("id") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&Pet{}).Create([]map[string]interface{}{ + { + "name": "cat", + "user_id": gorm.Expr("(?)", DB.Table("(?) as tmp", subQuery).Select("@uid:=id")), + }, + { + "name": "dog", + "user_id": gorm.Expr("@uid"), + }, + }) + + if !regexp.MustCompile(`INSERT INTO .pets. \(.name.,.user_id.\) .*VALUES \(.+,\(SELECT @uid:=id FROM \(SELECT id FROM .users. WHERE name=.+\) as tmp\)\),\(.+,@uid\)`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid insert SQL, got %v", result.Statement.SQL.String()) + } +} + +func TestCreateNilPointer(t *testing.T) { + var user *User + + err := DB.Create(user).Error + if err == nil || err != gorm.ErrInvalidValue { + t.Fatalf("it is not ErrInvalidValue") + } +} diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 14a0a9774..5e00b1546 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -23,7 +23,7 @@ func TestDefaultValue(t *testing.T) { t.Fatalf("Failed to migrate with default value, got error: %v", err) } - var harumph = Harumph{Email: "hello@gorm.io"} + harumph := Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { diff --git a/tests/delete_test.go b/tests/delete_test.go index 954c70974..5cb4b91e6 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -10,7 +10,7 @@ import ( ) func TestDelete(t *testing.T) { - var users = []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} + users := []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} if err := DB.Create(&users).Error; err != nil { t.Errorf("errors happened when create: %v", err) @@ -22,8 +22,8 @@ func TestDelete(t *testing.T) { } } - if err := DB.Delete(&users[1]).Error; err != nil { - t.Errorf("errors happened when delete: %v", err) + if res := DB.Delete(&users[1]); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when delete: %v, affected: %v", res.Error, res.RowsAffected) } var result User @@ -45,7 +45,7 @@ func TestDelete(t *testing.T) { } } - if err := DB.Delete(users[0]).Error; err != nil { + if err := DB.Delete(&users[0]).Error; err != nil { t.Errorf("errors happened when delete: %v", err) } @@ -153,6 +153,30 @@ func TestDeleteWithAssociations(t *testing.T) { } } +func TestDeleteAssociationsWithUnscoped(t *testing.T) { + user := GetUser("unscoped_delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Unscoped().Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + func TestDeleteSliceWithAssociations(t *testing.T) { users := []User{ *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), @@ -181,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) { } } } + +// only sqlite, postgres support returning +func TestSoftDeleteReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + users := []*User{ + GetUser("delete-returning-1", Config{}), + GetUser("delete-returning-2", Config{}), + GetUser("delete-returning-3", Config{}), + } + DB.Create(&users) + + var results []User + DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results) + if len(results) != 2 { + t.Errorf("failed to return delete data, got %v", results) + } + + var count int64 + DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count) + if count != 1 { + t.Errorf("failed to delete data, current count %v", count) + } +} + +func TestDeleteReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + companies := []Company{ + {Name: "delete-returning-1"}, + {Name: "delete-returning-2"}, + {Name: "delete-returning-3"}, + } + DB.Create(&companies) + + var results []Company + DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results) + if len(results) != 2 { + t.Errorf("failed to return delete data, got %v", results) + } + + var count int64 + DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count) + if count != 1 { + t.Errorf("failed to delete data, current count %v", count) + } +} diff --git a/tests/distinct_test.go b/tests/distinct_test.go index 29a320ff7..8c8298ae2 100644 --- a/tests/distinct_test.go +++ b/tests/distinct_test.go @@ -9,7 +9,7 @@ import ( ) func TestDistinct(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("distinct", Config{}), *GetUser("distinct", Config{}), *GetUser("distinct", Config{}), @@ -31,6 +31,12 @@ func TestDistinct(t *testing.T) { AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"}) + var names2 []string + DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table("users") + }).Where("name like ?", "distinct%").Order("name").Pluck("name", &names2) + AssertEqual(t, names2, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) + var results []User if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil { t.Errorf("failed to query users, got error: %v", err) diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 05e0956ee..9ab4ddb66 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -2,7 +2,7 @@ version: '3' services: mysql: - image: 'mysql:latest' + image: 'mysql/mysql-server:latest' ports: - 9910:3306 environment: @@ -20,7 +20,7 @@ services: - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm mssql: - image: 'mcmoe/mssqldocker:latest' + image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' ports: - 9930:1433 environment: diff --git a/tests/go.mod b/tests/go.mod index 55495de30..1c1fb2389 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,14 +3,17 @@ module gorm.io/gorm/tests go 1.14 require ( - github.com/google/uuid v1.1.1 - github.com/jinzhu/now v1.1.1 - github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.3 - gorm.io/driver/postgres v1.0.5 - gorm.io/driver/sqlite v1.1.3 - gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.5 + github.com/google/uuid v1.3.0 + github.com/jackc/pgx/v4 v4.15.0 // indirect + github.com/jinzhu/now v1.1.4 + github.com/lib/pq v1.10.4 + github.com/mattn/go-sqlite3 v1.14.11 // indirect + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect + gorm.io/driver/mysql v1.3.1 + gorm.io/driver/postgres v1.3.1 + gorm.io/driver/sqlite v1.3.1 + gorm.io/driver/sqlserver v1.3.1 + gorm.io/gorm v1.23.0 ) replace gorm.io/gorm => ../ diff --git a/tests/gorm_test.go b/tests/gorm_test.go new file mode 100644 index 000000000..9827465cc --- /dev/null +++ b/tests/gorm_test.go @@ -0,0 +1,93 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" +) + +func TestReturningWithNullToZeroValues(t *testing.T) { + dialect := DB.Dialector.Name() + switch dialect { + case "mysql", "sqlserver": + // these dialects do not support the "returning" clause + return + default: + // This user struct will leverage the existing users table, but override + // the Name field to default to null. + type user struct { + gorm.Model + Name string `gorm:"default:null"` + } + u1 := user{} + + if results := DB.Create(&u1); results.Error != nil { + t.Fatalf("errors happened on create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if u1.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u1.ID) + } + + got := user{} + results := DB.First(&got, "id = ?", u1.ID) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if got.ID != u1.ID { + t.Fatalf("first expects: %v, got %v", u1, got) + } + + results = DB.Select("id, name").Find(&got) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if got.ID != u1.ID { + t.Fatalf("select expects: %v, got %v", u1, got) + } + + u1.Name = "jinzhu" + if results := DB.Save(&u1); results.Error != nil { + t.Fatalf("errors happened on update: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + u1 = user{} // important to reinitialize this before creating it again + u2 := user{} + db := DB.Session(&gorm.Session{CreateBatchSize: 10}) + + if results := db.Create([]*user{&u1, &u2}); results.Error != nil { + t.Fatalf("errors happened on create: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if u1.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u1.ID) + } else if u2.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u2.ID) + } + + var gotUsers []user + results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, name").Find(&gotUsers) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 2, results.RowsAffected) + } else if gotUsers[0].ID != u1.ID { + t.Fatalf("select expects: %v, got %v", u1.ID, gotUsers[0].ID) + } else if gotUsers[1].ID != u2.ID { + t.Fatalf("select expects: %v, got %v", u2.ID, gotUsers[1].ID) + } + + u1.Name = "Jinzhu" + u2.Name = "Zhang" + if results := DB.Save([]*user{&u1, &u2}); results.Error != nil { + t.Fatalf("errors happened on update: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + } +} diff --git a/tests/group_by_test.go b/tests/group_by_test.go index 7e41e94a2..5335fed14 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -7,7 +7,7 @@ import ( ) func TestGroupBy(t *testing.T) { - var users = []User{{ + users := []User{{ Name: "groupby", Age: 10, Birthday: Now(), @@ -67,7 +67,7 @@ func TestGroupBy(t *testing.T) { t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) } - var result = struct { + result := struct { Name string Total int64 }{} @@ -96,4 +96,14 @@ func TestGroupBy(t *testing.T) { if name != "groupby" || active != true || total != 40 { t.Errorf("group by two columns, name %v, age %v, active: %v", name, total, active) } + + if DB.Dialector.Name() == "mysql" { + if err := DB.Model(&User{}).Select("name, age as total").Where("name LIKE ?", "groupby%").Having("total > ?", 300).Scan(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 330 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + } } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index fe3f7d082..0e6ab2fe2 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -133,6 +133,15 @@ func TestRunCallbacks(t *testing.T) { if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { t.Fatalf("Can't find a deleted record") } + + beforeCallTimes := p.AfterFindCallTimes + if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil { + t.Fatalf("Find don't raise error when record not found") + } + + if p.AfterFindCallTimes != beforeCallTimes { + t.Fatalf("AfterFind should not be called") + } } func TestCallbacksWithErrors(t *testing.T) { diff --git a/tests/joins_test.go b/tests/joins_test.go index f78ddf679..4c9cffae9 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -57,45 +57,78 @@ func TestJoinsForSlice(t *testing.T) { } func TestJoinConds(t *testing.T) { - var user = *GetUser("joins-conds", Config{Account: true, Pets: 3}) + user := *GetUser("joins-conds", Config{Account: true, Pets: 3}) DB.Save(&user) var users1 []User - DB.Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1) + DB.Joins("inner join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1) if len(users1) != 3 { t.Errorf("should find two users using left join, but got %v", len(users1)) } var users2 []User - DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2) + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2) if len(users2) != 1 { t.Errorf("should find one users using left join with conditions, but got %v", len(users2)) } var users3 []User - DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3) + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3) if len(users3) != 1 { t.Errorf("should find one users using multiple left join conditions, but got %v", len(users3)) } var users4 []User - DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4) + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4) if len(users4) != 0 { t.Errorf("should find no user when searching with unexisting credit card, but got %v", len(users4)) } var users5 []User - db5 := DB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5) + db5 := DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5) if db5.Error != nil { t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) } + var users6 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = @Name", user.Pets[0]).Where("users.name = ?", user.Name).First(&users6) + if len(users6) != 1 { + t.Errorf("should find one users using left join with conditions, but got %v", len(users6)) + } + dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) } + + iv := DB.Table(`table_invoices`).Select(`seller, SUM(total) as total, SUM(paid) as paid, SUM(balance) as balance`).Group(`seller`) + stmt = dryDB.Table(`table_employees`).Select(`id, name, iv.total, iv.paid, iv.balance`).Joins(`LEFT JOIN (?) AS iv ON iv.seller = table_employees.id`, iv).Scan(&user).Statement + if !regexp.MustCompile("SELECT id, name, iv.total, iv.paid, iv.balance FROM .table_employees. LEFT JOIN \\(SELECT seller, SUM\\(total\\) as total, SUM\\(paid\\) as paid, SUM\\(balance\\) as balance FROM .table_invoices. GROUP BY .seller.\\) AS iv ON iv.seller = table_employees.id").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } +} + +func TestJoinOn(t *testing.T) { + user := *GetUser("joins-on", Config{Pets: 2}) + DB.Save(&user) + + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"}) + + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") + + onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"}) + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2") } func TestJoinsWithSelect(t *testing.T) { @@ -124,3 +157,46 @@ func TestJoinsWithSelect(t *testing.T) { t.Errorf("Should find all two pets with Join select, got %+v", results) } } + +func TestJoinWithOmit(t *testing.T) { + user := *GetUser("joins_with_omit", Config{Pets: 2}) + DB.Save(&user) + + results := make([]*User, 0) + + if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil { + return + } + + if len(results) != 2 || results[0].Name != "" || results[1].Name != "" { + t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results) + return + } +} + +func TestJoinCount(t *testing.T) { + companyA := Company{Name: "A"} + companyB := Company{Name: "B"} + DB.Create(&companyA) + DB.Create(&companyB) + + user := User{Name: "kingGo", CompanyID: &companyB.ID} + DB.Create(&user) + + query := DB.Model(&User{}).Joins("Company") + // Bug happens when .Count is called on a query. + // Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass. + var total int64 + query.Count(&total) + + var result User + + // Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id + if err := query.First(&result, user.ID).Error; err != nil { + t.Fatalf("Failed, got error: %v", err) + } + + if result.ID != user.ID { + t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 275fe6349..94f562b47 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -2,11 +2,13 @@ package tests_test import ( "math/rand" + "reflect" "strings" "testing" "time" "gorm.io/gorm" + "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) @@ -14,8 +16,7 @@ func TestMigrate(t *testing.T) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - - DB.Migrator().DropTable("user_speaks", "user_friends") + DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") if err := DB.Migrator().DropTable(allModels...); err != nil { t.Fatalf("Failed to drop table, got error %v", err) @@ -25,12 +26,37 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to auto migrate, but got error %v", err) } + if tables, err := DB.Migrator().GetTables(); err != nil { + t.Fatalf("Failed to get database all tables, but got error %v", err) + } else { + for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} { + hasTable := false + for _, t2 := range tables { + if t2 == t1 { + hasTable = true + break + } + } + if !hasTable { + t.Fatalf("Failed to get table %v when GetTables", t1) + } + } + } + for _, m := range allModels { if !DB.Migrator().HasTable(m) { - t.Fatalf("Failed to create table for %#v---", m) + t.Fatalf("Failed to create table for %#v", m) } } + DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table("ccc") + }).Migrator().CreateTable(&Company{}) + + if !DB.Migrator().HasTable("ccc") { + t.Errorf("failed to create table ccc") + } + for _, indexes := range [][2]string{ {"user_speaks", "fk_user_speaks_user"}, {"user_speaks", "fk_user_speaks_language"}, @@ -38,7 +64,6 @@ func TestMigrate(t *testing.T) { {"user_friends", "fk_user_friends_friends"}, {"accounts", "fk_users_account"}, {"users", "fk_users_team"}, - {"users", "fk_users_manager"}, {"users", "fk_users_company"}, } { if !DB.Migrator().HasConstraint(indexes[0], indexes[1]) { @@ -47,6 +72,25 @@ func TestMigrate(t *testing.T) { } } +func TestAutoMigrateSelfReferential(t *testing.T) { + type MigratePerson struct { + ID uint + Name string + ManagerID *uint + Manager *MigratePerson + } + + DB.Migrator().DropTable(&MigratePerson{}) + + if err := DB.AutoMigrate(&MigratePerson{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } + + if !DB.Migrator().HasConstraint("migrate_people", "fk_migrate_people_manager") { + t.Fatalf("Failed to find has one constraint between people and managers") + } +} + func TestSmartMigrateColumn(t *testing.T) { fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] @@ -62,10 +106,11 @@ func TestSmartMigrateColumn(t *testing.T) { DB.AutoMigrate(&UserMigrateColumn{}) type UserMigrateColumn2 struct { - ID uint - Name string `gorm:"size:128"` - Salary float64 `gorm:"precision:2"` - Birthday time.Time `gorm:"precision:2"` + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + Birthday time.Time `gorm:"precision:2"` + NameIgnoreMigration string `gorm:"size:100"` } if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { @@ -95,10 +140,11 @@ func TestSmartMigrateColumn(t *testing.T) { } type UserMigrateColumn3 struct { - ID uint - Name string `gorm:"size:256"` - Salary float64 `gorm:"precision:3"` - Birthday time.Time `gorm:"precision:3"` + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + NameIgnoreMigration string `gorm:"size:128;-:migration"` } if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { @@ -124,22 +170,44 @@ func TestSmartMigrateColumn(t *testing.T) { if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { t.Fatalf("birthday's precision should be 2, but got %v", precision) } + case "name_ignore_migration": + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 { + t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length) + } } } +} + +func TestMigrateWithColumnComment(t *testing.T) { + type UserWithColumnComment struct { + gorm.Model + Name string `gorm:"size:111;comment:this is a 字段"` + } + + if err := DB.Migrator().DropTable(&UserWithColumnComment{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + if err := DB.AutoMigrate(&UserWithColumnComment{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } } -func TestMigrateWithComment(t *testing.T) { - type UserWithComment struct { +func TestMigrateWithIndexComment(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type UserWithIndexComment struct { gorm.Model - Name string `gorm:"size:111;index:,comment:这是一个index;comment:this is a 字段"` + Name string `gorm:"size:111;index:,comment:这是一个index"` } - if err := DB.Migrator().DropTable(&UserWithComment{}); err != nil { + if err := DB.Migrator().DropTable(&UserWithIndexComment{}); err != nil { t.Fatalf("Failed to drop table, got error %v", err) } - if err := DB.AutoMigrate(&UserWithComment{}); err != nil { + if err := DB.AutoMigrate(&UserWithIndexComment{}); err != nil { t.Fatalf("Failed to auto migrate, but got error %v", err) } } @@ -245,9 +313,16 @@ func TestMigrateIndexes(t *testing.T) { } func TestMigrateColumns(t *testing.T) { + sqlite := DB.Dialector.Name() == "sqlite" + sqlserver := DB.Dialector.Name() == "sqlserver" + type ColumnStruct struct { gorm.Model - Name string + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` } DB.Migrator().DropTable(&ColumnStruct{}) @@ -258,13 +333,20 @@ func TestMigrateColumns(t *testing.T) { type ColumnStruct2 struct { gorm.Model - Name string `gorm:"size:100"` + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"unique"` + // Code3 string } - if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { t.Fatalf("no error should happened when alter column, but got %v", err) } + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { t.Fatalf("no error should returns for ColumnTypes") } else { @@ -272,11 +354,45 @@ func TestMigrateColumns(t *testing.T) { stmt.Parse(&ColumnStruct2{}) for _, columnType := range columnTypes { - if columnType.Name() == "name" { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + } + case "age": + if v, ok := columnType.DefaultValue(); !ok || v != "18" { + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code": + if v, ok := columnType.Unique(); !ok || !v { + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { + t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code2": + if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } + case "code3": + // TODO + // if v, ok := columnType.Unique(); !ok || v { + // t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + // } } } } @@ -323,3 +439,138 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("Found deleted column") } } + +func TestMigrateConstraint(t *testing.T) { + names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Team", "fk_users_team", "Languages", "fk_users_languages"} + + for _, name := range names { + if !DB.Migrator().HasConstraint(&User{}, name) { + DB.Migrator().CreateConstraint(&User{}, name) + } + + if err := DB.Migrator().DropConstraint(&User{}, name); err != nil { + t.Fatalf("failed to drop constraint %v, got error %v", name, err) + } + + if DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("constraint %v should been deleted", name) + } + + if err := DB.Migrator().CreateConstraint(&User{}, name); err != nil { + t.Fatalf("failed to create constraint %v, got error %v", name, err) + } + + if !DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("failed to found constraint %v", name) + } + } +} + +type DynamicUser struct { + gorm.Model + Name string + CompanyID string `gorm:"index"` +} + +// To test auto migrate crate indexes for dynamic table name +// https://github.com/go-gorm/gorm/issues/4752 +func TestMigrateIndexesWithDynamicTableName(t *testing.T) { + // Create primary table + if err := DB.AutoMigrate(&DynamicUser{}); err != nil { + t.Fatalf("AutoMigrate create table error: %#v", err) + } + + // Create sub tables + for _, v := range []string{"01", "02", "03"} { + tableName := "dynamic_users_" + v + m := DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table(tableName) + }).Migrator() + + if err := m.AutoMigrate(&DynamicUser{}); err != nil { + t.Fatalf("AutoMigrate create table error: %#v", err) + } + + if !m.HasTable(tableName) { + t.Fatalf("AutoMigrate expected %#v exist, but not.", tableName) + } + + if !m.HasIndex(&DynamicUser{}, "CompanyID") { + t.Fatalf("Should have index on %s", "CompanyI.") + } + + if !m.HasIndex(&DynamicUser{}, "DeletedAt") { + t.Fatalf("Should have index on deleted_at.") + } + } +} + +// check column order after migration, flaky test +// https://github.com/go-gorm/gorm/issues/4351 +func TestMigrateColumnOrder(t *testing.T) { + type UserMigrateColumn struct { + ID uint + } + DB.Migrator().DropTable(&UserMigrateColumn{}) + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + F1 string + F2 string + F3 string + F4 string + F5 string + F6 string + F7 string + F8 string + F9 string + F10 string + F11 string + F12 string + F13 string + F14 string + F15 string + F16 string + F17 string + F18 string + F19 string + F20 string + F21 string + F22 string + F23 string + F24 string + F25 string + F26 string + F27 string + F28 string + F29 string + F30 string + F31 string + F32 string + F33 string + F34 string + F35 string + } + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn2{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + typ := reflect.Indirect(reflect.ValueOf(&UserMigrateColumn2{})).Type() + numField := typ.NumField() + if numField != len(columnTypes) { + t.Fatalf("column's number not match struct and ddl, %d != %d", numField, len(columnTypes)) + } + namer := schema.NamingStrategy{} + for i := 0; i < numField; i++ { + expectName := namer.ColumnName("", typ.Field(i).Name) + if columnTypes[i].Name() != expectName { + t.Fatalf("column order not match struct and ddl, idx %d: %s != %s", + i, columnTypes[i].Name(), expectName) + } + } +} diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index dcc90cd9a..4a7ab9f61 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -71,7 +71,7 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { @@ -95,8 +95,8 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog).Association("Tags").Replace(tag5, tag6) var tags2 []Tag DB.Model(&blog).Association("Tags").Find(&tags2) @@ -170,7 +170,7 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Blog should has three tags after Append") @@ -201,7 +201,7 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { t.Fatalf("Preload many2many relations") } - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + tag4 := &Tag{Locale: "ZH", Value: "tag4"} DB.Model(&blog2).Association("SharedTags").Append(tag4) DB.Model(&blog).Association("SharedTags").Find(&tags) @@ -215,8 +215,8 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) var tags2 []Tag DB.Model(&blog).Association("SharedTags").Find(&tags2) @@ -291,7 +291,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { DB.Create(&blog2) // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Blog should has three tags after Append") @@ -322,7 +322,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Fatalf("Preload many2many relations") } - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + tag4 := &Tag{Locale: "ZH", Value: "tag4"} DB.Model(&blog2).Association("LocaleTags").Append(tag4) DB.Model(&blog).Association("LocaleTags").Find(&tags) @@ -336,8 +336,8 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { } // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) var tags2 []Tag @@ -427,7 +427,7 @@ func TestCompositePrimaryKeysAssociations(t *testing.T) { DB.Migrator().DropTable(&Label{}, &Book{}) if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { - t.Fatalf("failed to migrate") + t.Fatalf("failed to migrate, got %v", err) } book := Book{ diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go index d0a6f915f..a3a25f7ba 100644 --- a/tests/named_argument_test.go +++ b/tests/named_argument_test.go @@ -2,6 +2,7 @@ package tests_test import ( "database/sql" + "errors" "testing" "gorm.io/gorm" @@ -66,4 +67,16 @@ func TestNamedArg(t *testing.T) { } AssertEqual(t, result6, namedUser) + + var result7 NamedUser + if err := DB.Where("name1 = @name OR name2 = @name", sql.Named("name", "jinzhu-new")).Where("name3 = 'jinzhu-new3'").First(&result7).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should return record not found error, but got %v", err) + } + + DB.Delete(&namedUser) + + var result8 NamedUser + if err := DB.Where("name1 = @name OR name2 = @name", map[string]interface{}{"name": "jinzhu-new"}).First(&result8).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should return record not found error, but got %v", err) + } } diff --git a/tests/non_std_test.go b/tests/non_std_test.go index d3561b11e..8ae426911 100644 --- a/tests/non_std_test.go +++ b/tests/non_std_test.go @@ -8,7 +8,7 @@ import ( type Animal struct { Counter uint64 `gorm:"primary_key:yes"` Name string `gorm:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name + From string // test reserved sql keyword as field name Age *time.Time unexported string // unexported value CreatedAt time.Time diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 85cd34d46..85671864f 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -15,6 +15,7 @@ func TestPostgres(t *testing.T) { type Harumph struct { gorm.Model + Name string `gorm:"check:name_checker,name <> ''"` Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` Things pq.StringArray `gorm:"type:text[]"` } @@ -30,10 +31,21 @@ func TestPostgres(t *testing.T) { } harumph := Harumph{} - DB.Create(&harumph) + if err := DB.Create(&harumph).Error; err == nil { + t.Fatalf("should failed to create data, name can't be blank") + } + + harumph = Harumph{Name: "jinzhu"} + if err := DB.Create(&harumph).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } var result Harumph - if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil { + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } } diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index d40309e7d..0ef8890b0 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "reflect" "sort" + "sync/atomic" "testing" "gorm.io/gorm" @@ -1497,10 +1498,9 @@ func TestPreloadManyToManyCallbacks(t *testing.T) { } DB.Save(&lvl) - called := 0 - + var called int64 DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(_ *gorm.DB) { - called = called + 1 + atomic.AddInt64(&called, 1) }) DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) diff --git a/tests/preload_test.go b/tests/preload_test.go index d90356613..adb54ee19 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,6 +5,7 @@ import ( "regexp" "sort" "strconv" + "sync" "testing" "gorm.io/gorm" @@ -13,7 +14,7 @@ import ( ) func TestPreloadWithAssociations(t *testing.T) { - var user = *GetUser("preload_with_associations", Config{ + user := *GetUser("preload_with_associations", Config{ Account: true, Pets: 2, Toys: 3, @@ -34,7 +35,7 @@ func TestPreloadWithAssociations(t *testing.T) { DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) - var user3 = *GetUser("preload_with_associations_new", Config{ + user3 := *GetUser("preload_with_associations_new", Config{ Account: true, Pets: 2, Toys: 3, @@ -50,7 +51,7 @@ func TestPreloadWithAssociations(t *testing.T) { } func TestNestedPreload(t *testing.T) { - var user = *GetUser("nested_preload", Config{Pets: 2}) + user := *GetUser("nested_preload", Config{Pets: 2}) for idx, pet := range user.Pets { pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} @@ -62,12 +63,19 @@ func TestNestedPreload(t *testing.T) { var user2 User DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) - CheckUser(t, user2, user) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) + + var user4 *User + DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) + CheckUser(t, *user4, user) } func TestNestedPreloadForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice_nested_preload_1", Config{Pets: 2}), *GetUser("slice_nested_preload_2", Config{Pets: 0}), *GetUser("slice_nested_preload_3", Config{Pets: 3}), @@ -97,7 +105,7 @@ func TestNestedPreloadForSlice(t *testing.T) { } func TestPreloadWithConds(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice_nested_preload_1", Config{Account: true}), *GetUser("slice_nested_preload_2", Config{Account: false}), *GetUser("slice_nested_preload_3", Config{Account: true}), @@ -139,10 +147,23 @@ func TestPreloadWithConds(t *testing.T) { for i, u := range users3 { CheckUser(t, u, users[i]) } + + var user4 User + DB.Delete(&users3[0].Account) + + if err := DB.Preload(clause.Associations).Take(&user4, "id = ?", users3[0].ID).Error; err != nil || user4.Account.ID != 0 { + t.Errorf("failed to query, got error %v, account: %#v", err, user4.Account) + } + + if err := DB.Preload(clause.Associations, func(tx *gorm.DB) *gorm.DB { + return tx.Unscoped() + }).Take(&user4, "id = ?", users3[0].ID).Error; err != nil || user4.Account.ID == 0 { + t.Errorf("failed to query, got error %v, account: %#v", err, user4.Account) + } } func TestNestedPreloadWithConds(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice_nested_preload_1", Config{Pets: 2}), *GetUser("slice_nested_preload_2", Config{Pets: 0}), *GetUser("slice_nested_preload_3", Config{Pets: 3}), @@ -192,7 +213,7 @@ func TestNestedPreloadWithConds(t *testing.T) { } func TestPreloadEmptyData(t *testing.T) { - var user = *GetUser("user_without_associations", Config{}) + user := *GetUser("user_without_associations", Config{}) DB.Create(&user) DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) @@ -212,3 +233,21 @@ func TestPreloadEmptyData(t *testing.T) { t.Errorf("json marshal is not empty slice, got %v", string(r)) } } + +func TestPreloadGoroutine(t *testing.T) { + var wg sync.WaitGroup + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + var user2 []User + tx := DB.Where("id = ?", 1).Session(&gorm.Session{}) + + if err := tx.Preload("Team").Find(&user2).Error; err != nil { + t.Error(err) + } + }() + } + wg.Wait() +} diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 6b10b6dc9..8730e5474 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -50,3 +50,41 @@ func TestPreparedStmt(t *testing.T) { t.Fatalf("no error should happen but got %v", err) } } + +func TestPreparedStmtFromTransaction(t *testing.T) { + db := DB.Session(&gorm.Session{PrepareStmt: true, SkipDefaultTransaction: true}) + + tx := db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + if err := tx.Error; err != nil { + t.Errorf("Failed to start transaction, got error %v\n", err) + } + + if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Failed to commit transaction, got error %v\n", err) + } + + if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + + tx2 := db.Begin() + if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + tx2.Commit() +} diff --git a/tests/query_test.go b/tests/query_test.go index c4162bdc6..d10df1807 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -17,7 +17,7 @@ import ( ) func TestFind(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("find", Config{}), *GetUser("find", Config{}), *GetUser("find", Config{}), @@ -57,7 +57,7 @@ func TestFind(t *testing.T) { } t.Run("FirstMap", func(t *testing.T) { - var first = map[string]interface{}{} + first := map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { @@ -88,7 +88,7 @@ func TestFind(t *testing.T) { }) t.Run("FirstMapWithTable", func(t *testing.T) { - var first = map[string]interface{}{} + first := map[string]interface{}{} if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { @@ -120,7 +120,7 @@ func TestFind(t *testing.T) { }) t.Run("FirstPtrMap", func(t *testing.T) { - var first = map[string]interface{}{} + first := map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").First(&first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { @@ -135,7 +135,7 @@ func TestFind(t *testing.T) { }) t.Run("FirstSliceOfMap", func(t *testing.T) { - var allMap = []map[string]interface{}{} + allMap := []map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { t.Errorf("errors happened when query find: %v", err) } else { @@ -170,7 +170,7 @@ func TestFind(t *testing.T) { }) t.Run("FindSliceOfMapWithTable", func(t *testing.T) { - var allMap = []map[string]interface{}{} + allMap := []map[string]interface{}{} if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { t.Errorf("errors happened when query find: %v", err) } else { @@ -241,7 +241,7 @@ func TestQueryWithAssociation(t *testing.T) { } func TestFindInBatches(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("find_in_batches", Config{}), *GetUser("find_in_batches", Config{}), *GetUser("find_in_batches", Config{}), @@ -292,6 +292,38 @@ func TestFindInBatches(t *testing.T) { } } +func TestFindInBatchesWithError(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlserver" { + t.Skip("skip sqlserver due to it will raise data race for invalid sql") + } + + users := []User{ + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Table("wrong_table").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error == nil || result.RowsAffected > 0 { + t.Fatal("expected errors to have occurred, but nothing happened") + } + if totalBatch != 0 { + t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) + } +} + func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user) @@ -404,6 +436,11 @@ func TestNot(t *testing.T) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + result = dryDB.Not(map[string]interface{}{"name": []string{}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IS NOT NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) @@ -475,7 +512,23 @@ func TestNotWithAllFields(t *testing.T) { func TestOr(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) - result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + var count int64 + result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count) + if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin").Or("role = ?", "admin")).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND (.*role.* = .+ OR .*role.* = .+)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) } @@ -612,11 +665,21 @@ func TestSelect(t *testing.T) { t.Fatalf("Build Select with slice, but got %v", r.Statement.SQL.String()) } + // SELECT COALESCE(age,'42') FROM users; r = dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{}) if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) } - // SELECT COALESCE(age,'42') FROM users; + + // named arguments + r = dryDB.Table("users").Select("COALESCE(age, @default)", sql.Named("default", 42)).Find(&User{}) + if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) + } + + if _, err := DB.Table("users").Select("COALESCE(age,?)", "42").Rows(); err != nil { + t.Fatalf("Failed, got error: %v", err) + } r = dryDB.Select("u.*").Table("users as u").First(&User{}, user.ID) if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { @@ -677,7 +740,7 @@ func TestPluckWithSelect(t *testing.T) { DB.Create(&users) var userAges []int - err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error + err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error if err != nil { t.Fatalf("got error when pluck user_age: %v", err) } @@ -790,7 +853,17 @@ func TestSearchWithEmptyChain(t *testing.T) { func TestOrder(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) - result := dryDB.Order("age desc, name").Find(&User{}) + result := dryDB.Order("").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order(nil).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("age desc, name").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc, name").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) } @@ -917,6 +990,30 @@ func TestSearchWithMap(t *testing.T) { } } +func TestSearchWithStruct(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} + func TestSubQuery(t *testing.T) { users := []User{ {Name: "subquery_1", Age: 10}, @@ -953,7 +1050,16 @@ func TestSubQueryWithRaw(t *testing.T) { DB.Create(&users) var count int64 - err := DB.Raw("select count(*) from (?) tmp", + err := DB.Raw("select count(*) from (?) tmp where 1 = ? AND name IN (?)", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"}), 1, DB.Raw("select name from users where age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"})).Scan(&count).Error + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", DB.Table("users"). Select("name"). Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). diff --git a/tests/scan_test.go b/tests/scan_test.go index 785bb97e4..1a188facf 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -28,6 +28,13 @@ func TestScan(t *testing.T) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } + var resPointer *result + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) @@ -38,7 +45,7 @@ func TestScan(t *testing.T) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } - var doubleAgeRes = &result{} + doubleAgeRes := &result{} if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { t.Errorf("Scan to pointer of pointer") } @@ -57,6 +64,53 @@ func TestScan(t *testing.T) { if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { t.Errorf("Scan into struct map, got %#v", results) } + + type ID uint64 + var id ID + DB.Raw("select id from users where id = ?", user2.ID).Scan(&id) + if uint(id) != user2.ID { + t.Errorf("Failed to scan to customized data type") + } + + var resInt interface{} + resInt = &User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3) + } + + var resInt2 interface{} + resInt2 = &User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3) + } + + var resInt3 interface{} + resInt3 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3) + } + + var resInt4 interface{} + resInt4 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3) + } + + var resInt5 interface{} + resInt5 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id IN ?", []uint{user1.ID, user2.ID, user3.ID}).Find(&resInt5).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt5.([]User); len(rus) != 3 { + t.Fatalf("Scan into struct should work, got %+v, len %v", resInt5, len(rus)) + } } func TestScanRows(t *testing.T) { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index fb1f57917..14121699e 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -182,11 +182,11 @@ func (data *EncryptedData) Scan(value interface{}) error { func (data EncryptedData) Value() (driver.Value, error) { if len(data) > 0 && data[0] == 'x' { - //needed to test failures + // needed to test failures return nil, errors.New("Should not start with 'x'") } - //prepend asterisks + // prepend asterisks return append([]byte("***"), data...), nil } diff --git a/tests/scopes_test.go b/tests/scopes_test.go index c9787d362..ab3807ea2 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "testing" "gorm.io/gorm" @@ -22,7 +23,7 @@ func NameIn(names []string) func(d *gorm.DB) *gorm.DB { } func TestScopes(t *testing.T) { - var users = []*User{ + users := []*User{ GetUser("ScopeUser1", Config{}), GetUser("ScopeUser2", Config{}), GetUser("ScopeUser3", Config{}), @@ -45,4 +46,29 @@ func TestScopes(t *testing.T) { if len(users3) != 2 { t.Errorf("Should found two users's name in 1, 3, but got %v", len(users3)) } + + db := DB.Scopes(func(tx *gorm.DB) *gorm.DB { + return tx.Table("custom_table") + }).Session(&gorm.Session{}) + + db.AutoMigrate(&User{}) + if db.Find(&User{}).Statement.Table != "custom_table" { + t.Errorf("failed to call Scopes") + } + + result := DB.Scopes(NameIn1And2, func(tx *gorm.DB) *gorm.DB { + return tx.Session(&gorm.Session{}) + }).Find(&users1) + + if result.RowsAffected != 2 { + t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) + } + + var maxId int64 + userTable := func(db *gorm.DB) *gorm.DB { + return db.WithContext(context.Background()).Table("users") + } + if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil { + t.Errorf("select max(id)") + } } diff --git a/tests/serializer_test.go b/tests/serializer_test.go new file mode 100644 index 000000000..3ed733d9f --- /dev/null +++ b/tests/serializer_test.go @@ -0,0 +1,71 @@ +package tests_test + +import ( + "bytes" + "context" + "fmt" + "reflect" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + . "gorm.io/gorm/utils/tests" +) + +type SerializerStruct struct { + gorm.Model + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Contracts map[string]interface{} `gorm:"serializer:json"` + CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + EncryptedString EncryptedString +} + +type Roles []string +type EncryptedString string + +func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + *es = EncryptedString(bytes.TrimPrefix(value, []byte("hello"))) + case string: + *es = EncryptedString(strings.TrimPrefix(value, "hello")) + default: + return fmt.Errorf("unsupported data %v", dbValue) + } + return nil +} + +func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return "hello" + string(es), nil +} + +func TestSerializer(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + data := SerializerStruct{ + Name: []byte("jinzhu"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + } + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result, data) +} diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index f1ea8a517..9ac8da10d 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "database/sql" "encoding/json" "errors" "regexp" @@ -29,6 +30,10 @@ func TestSoftDelete(t *testing.T) { t.Fatalf("No error should happen when soft delete user, but got %v", err) } + if sql.NullTime(user.DeletedAt).Time.IsZero() { + t.Fatalf("user's deleted at is zero") + } + sql := DB.Session(&gorm.Session{DryRun: true}).Delete(&user).Statement.SQL.String() if !regexp.MustCompile(`UPDATE .users. SET .deleted_at.=.* WHERE .users.\..id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) @@ -78,3 +83,13 @@ func TestDeletedAtUnMarshal(t *testing.T) { t.Errorf("Failed, result.DeletedAt: %v is not same as expected.DeletedAt: %v", result.DeletedAt, expected.DeletedAt) } } + +func TestDeletedAtOneOr(t *testing.T) { + actualSQL := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Or("id = ?", 1).Find(&User{}) + }) + + if !regexp.MustCompile(` WHERE id = 1 AND .users.\..deleted_at. IS NULL`).MatchString(actualSQL) { + t.Fatalf("invalid sql generated, got %v", actualSQL) + } +} diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index acb08130b..bc917c32d 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -4,8 +4,10 @@ import ( "regexp" "strings" "testing" + "time" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -166,6 +168,59 @@ func TestDryRun(t *testing.T) { } } +type ageInt int8 + +func (ageInt) String() string { + return "age" +} + +type ageBool bool + +func (ageBool) String() string { + return "age" +} + +type ageUint64 uint64 + +func (ageUint64) String() string { + return "age" +} + +type ageFloat float64 + +func (ageFloat) String() string { + return "age" +} + +func TestExplainSQL(t *testing.T) { + user := *GetUser("explain-sql", Config{}) + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement + sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } +} + func TestGroupConditions(t *testing.T) { type Pizza struct { ID uint @@ -242,3 +297,180 @@ func TestCombineStringConditions(t *testing.T) { t.Fatalf("invalid sql generated, got %v", sql) } } + +func TestFromWithJoins(t *testing.T) { + var result User + + newDB := DB.Session(&gorm.Session{NewDB: true, DryRun: true}).Table("users") + + newDB.Clauses( + clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Table: clause.Table{Name: "companies", Raw: false}, + ON: clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{ + Table: "users", + Name: "company_id", + }, + Value: clause.Column{ + Table: "companies", + Name: "id", + }, + }, + }, + }, + }, + }, + }, + ) + + newDB.Joins("inner join rgs on rgs.id = user.id") + + stmt := newDB.First(&result).Statement + str := stmt.SQL.String() + + if !strings.Contains(str, "rgs.id = user.id") { + t.Errorf("The second join condition is over written instead of combining") + } + + if !strings.Contains(str, "`users`.`company_id` = `companies`.`id`") && !strings.Contains(str, "\"users\".\"company_id\" = \"companies\".\"id\"") { + t.Errorf("The first join condition is over written instead of combining") + } +} + +func TestToSQL(t *testing.T) { + // By default DB.DryRun should false + if DB.DryRun { + t.Fatal("Failed expect DB.DryRun to be false") + } + + if DB.Dialector.Name() == "sqlserver" { + t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") + } + + date, _ := time.Parse("2006-01-02", "2021-10-18") + + // find + sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Limit(10).Order("age desc").Find(&[]User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql) + + // after model chagned + if DB.Statement.DryRun || DB.DryRun { + t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") + } + + if DB.Statement.SQL.String() != "" { + t.Fatal("Failed expect DB.Statement.SQL to be empty") + } + + // first + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) + + // last and unscoped + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Unscoped().Where(&User{Name: "bar", Age: 12}).Limit(10).Offset(5).Order("name ASC").Last(&User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'bar' AND "users"."age" = 12 ORDER BY name ASC,"users"."id" DESC LIMIT 1 OFFSET 5`, sql) + + // create + user := &User{Name: "foo", Age: 20} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Create(user) + }) + assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql) + + // save + user = &User{Name: "foo", Age: 20} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Save(user) + }) + assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql) + + // updates + user = &User{Name: "bar", Age: 22} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Updates(user) + }) + assertEqualSQL(t, `UPDATE "users" SET "created_at"='2021-10-18 00:00:00',"updated_at"='2021-10-18 19:50:09.438',"name"='bar',"age"=22 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // update + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Update("name", "Foo bar") + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo bar',"updated_at"='2021-10-18 19:50:09.438' WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // UpdateColumn + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).UpdateColumn("name", "Foo bar") + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo bar' WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).UpdateColumns(User{Name: "Foo", Age: 100}) + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // after model chagned + if DB.Statement.DryRun || DB.DryRun { + t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") + } +} + +// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect speicals. +func assertEqualSQL(t *testing.T, expected string, actually string) { + t.Helper() + + // replace SQL quote, convert into postgresql like "" + expected = replaceQuoteInSQL(expected) + actually = replaceQuoteInSQL(actually) + + // ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update. + updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`) + actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) + expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) + + // ignore RETURNING "id" (only in PostgreSQL) + returningRe := regexp.MustCompile(`(?i)RETURNING "id"`) + actually = returningRe.ReplaceAllString(actually, ``) + expected = returningRe.ReplaceAllString(expected, ``) + + actually = strings.TrimSpace(actually) + expected = strings.TrimSpace(expected) + + if actually != expected { + t.Fatalf("\nexpected: %s\nactually: %s", expected, actually) + } +} + +func replaceQuoteInSQL(sql string) string { + // convert single quote into double quote + sql = strings.ReplaceAll(sql, `'`, `"`) + + // convert dialect speical quote into double quote + switch DB.Dialector.Name() { + case "postgres": + sql = strings.ReplaceAll(sql, `"`, `"`) + case "mysql", "sqlite": + sql = strings.ReplaceAll(sql, "`", `"`) + case "sqlserver": + sql = strings.ReplaceAll(sql, `'`, `"`) + } + + return sql +} diff --git a/tests/table_test.go b/tests/table_test.go index 0c6b3eb04..0289b7b83 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -30,6 +30,26 @@ func TestTable(t *testing.T) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } + r = dryDB.Table("`people`").Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("people as p").Table("user as u").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("people as p").Table("user").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.people").Table("user").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 744a40e92..5b9bae97a 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -9,11 +9,31 @@ fi if [ -d tests ] then cd tests - cp go.mod go.mod.bak - sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod + go get -u -t ./... + go mod download + go mod tidy cd .. fi +# SqlServer for Mac M1 +if [[ -z $GITHUB_ACTION ]]; then + if [ -d tests ] + then + cd tests + if [[ $(uname -a) == *" arm64" ]]; then + MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true + go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true + else + docker-compose start + fi + cd .. + fi +fi + + for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then @@ -39,9 +59,3 @@ for dialect in "${dialects[@]}" ; do fi fi done - -if [ -d tests ] -then - cd tests - mv go.mod.bak go.mod -fi diff --git a/tests/tests_test.go b/tests/tests_test.go index cb73d267f..11b6f0675 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -25,12 +25,15 @@ func init() { os.Exit(1) } else { sqlDB, err := DB.DB() - if err == nil { - err = sqlDB.Ping() + if err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) } + err = sqlDB.Ping() if err != nil { - log.Printf("failed to connect database, got error %v", err) + log.Printf("failed to ping sqlDB, got error %v", err) + os.Exit(1) } RunMigrations() @@ -59,13 +62,14 @@ func OpenTestConnection() (db *gorm.DB, err error) { PreferSimpleProtocol: true, }), &gorm.Config{}) case "sqlserver": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest + // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 // CREATE DATABASE gorm; - // USE gorm; + // GO + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - // npm install -g sql-cli - // mssql -u gorm -p LoremIpsum86 -d gorm -o 9930 + // ALTER SERVER ROLE sysadmin ADD MEMBER [gorm]; + // GO log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" @@ -76,6 +80,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) } + if err != nil { + return + } + if debug := os.Getenv("DEBUG"); debug == "true" { db.Logger = db.Logger.LogMode(logger.Info) } else if debug == "false" { @@ -87,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 334600b87..4e4b61494 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -41,7 +41,8 @@ func TestTransaction(t *testing.T) { t.Fatalf("Should not find record after rollback, but got %v", err) } - tx2 := DB.Begin() + txDB := DB.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() user2 := *GetUser("transaction-2", Config{}) if err := tx2.Save(&user2).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) @@ -283,6 +284,69 @@ func TestNestedTransactionWithBlock(t *testing.T) { } } +func TestDisabledNestedTransaction(t *testing.T) { + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + func TestTransactionOnClosedConn(t *testing.T) { DB, err := OpenTestConnection() if err != nil { diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 736dfc5b6..8fe0f2897 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -8,7 +8,7 @@ import ( ) func TestUpdateBelongsTo(t *testing.T) { - var user = *GetUser("update-belongs-to", Config{}) + user := *GetUser("update-belongs-to", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 9066cbacc..2ca93e2be 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -8,7 +8,7 @@ import ( ) func TestUpdateHasManyAssociations(t *testing.T) { - var user = *GetUser("update-has-many", Config{}) + user := *GetUser("update-has-many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -44,7 +44,7 @@ func TestUpdateHasManyAssociations(t *testing.T) { CheckUser(t, user4, user) t.Run("Polymorphic", func(t *testing.T) { - var user = *GetUser("update-has-many", Config{}) + user := *GetUser("update-has-many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 545685465..c926fbcf8 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -1,14 +1,16 @@ package tests_test import ( + "database/sql" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) func TestUpdateHasOne(t *testing.T) { - var user = *GetUser("update-has-one", Config{}) + user := *GetUser("update-has-one", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -31,7 +33,10 @@ func TestUpdateHasOne(t *testing.T) { var user3 User DB.Preload("Account").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + lastUpdatedAt := user2.Account.UpdatedAt + time.Sleep(time.Second) if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) @@ -39,10 +44,16 @@ func TestUpdateHasOne(t *testing.T) { var user4 User DB.Preload("Account").Find(&user4, "id = ?", user.ID) - CheckUser(t, user4, user) + + if lastUpdatedAt.Format(time.RFC3339) == user4.Account.UpdatedAt.Format(time.RFC3339) { + t.Fatalf("updated at should be updated, but not, old: %v, new %v", lastUpdatedAt.Format(time.RFC3339), user3.Account.UpdatedAt.Format(time.RFC3339)) + } else { + user.Account.UpdatedAt = user4.Account.UpdatedAt + CheckUser(t, user4, user) + } t.Run("Polymorphic", func(t *testing.T) { - var pet = Pet{Name: "create"} + pet := Pet{Name: "create"} if err := DB.Create(&pet).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -75,4 +86,48 @@ func TestUpdateHasOne(t *testing.T) { DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) CheckPet(t, pet4, pet) }) + + t.Run("Restriction", func(t *testing.T) { + type CustomizeAccount struct { + gorm.Model + UserID sql.NullInt64 + Number string `gorm:"<-:create"` + } + + type CustomizeUser struct { + gorm.Model + Name string + Account CustomizeAccount `gorm:"foreignkey:UserID"` + } + + DB.Migrator().DropTable(&CustomizeUser{}) + DB.Migrator().DropTable(&CustomizeAccount{}) + + if err := DB.AutoMigrate(&CustomizeUser{}); err != nil { + t.Fatalf("failed to migrate, got error: %v", err) + } + if err := DB.AutoMigrate(&CustomizeAccount{}); err != nil { + t.Fatalf("failed to migrate, got error: %v", err) + } + + number := "number-has-one-associations" + cusUser := CustomizeUser{ + Name: "update-has-one-associations", + Account: CustomizeAccount{ + Number: number, + }, + } + + if err := DB.Create(&cusUser).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + cusUser.Account.Number += "-update" + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var account2 CustomizeAccount + DB.Find(&account2, "user_id = ?", cusUser.ID) + AssertEqual(t, account2.Number, number) + }) } diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index d94ef4ab9..f1218cc0f 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -8,7 +8,7 @@ import ( ) func TestUpdateMany2ManyAssociations(t *testing.T) { - var user = *GetUser("update-many2many", Config{}) + user := *GetUser("update-many2many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) diff --git a/tests/update_test.go b/tests/update_test.go index a660647cd..b471ba9be 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -9,6 +9,7 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) @@ -69,8 +70,10 @@ func TestUpdate(t *testing.T) { } values := map[string]interface{}{"Active": true, "age": 5} - if err := DB.Model(user).Updates(values).Error; err != nil { - t.Errorf("errors happened when update: %v", err) + if res := DB.Model(user).Updates(values); res.Error != nil { + t.Errorf("errors happened when update: %v", res.Error) + } else if res.RowsAffected != 1 { + t.Errorf("rows affected should be 1, but got : %v", res.RowsAffected) } else if user.Age != 5 { t.Errorf("Age should equals to 5, but got %v", user.Age) } else if user.Active != true { @@ -122,7 +125,7 @@ func TestUpdate(t *testing.T) { } func TestUpdates(t *testing.T) { - var users = []*User{ + users := []*User{ GetUser("updates_01", Config{}), GetUser("updates_02", Config{}), } @@ -131,7 +134,10 @@ func TestUpdates(t *testing.T) { lastUpdatedAt := users[0].UpdatedAt // update with map - DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}) + if res := DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("Failed to update users") + } + if users[0].Name != "updates_01_newname" || users[0].Age != 100 { t.Errorf("Record should be updated also with map") } @@ -148,13 +154,17 @@ func TestUpdates(t *testing.T) { CheckUser(t, user2, *users[1]) // update with struct + time.Sleep(1 * time.Second) DB.Table("users").Where("name in ?", []string{users[1].Name}).Updates(User{Name: "updates_02_newname"}) var user3 User if err := DB.First(&user3, "name = ?", "updates_02_newname").Error; err != nil { t.Errorf("User2's name should be updated") } - AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt) + + if user2.UpdatedAt.Format(time.RFC1123Z) == user3.UpdatedAt.Format(time.RFC1123Z) { + t.Errorf("User's updated at should be changed, old %v, new %v", user2.UpdatedAt.Format(time.RFC1123Z), user3.UpdatedAt.Format(time.RFC1123Z)) + } // update with gorm exprs if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { @@ -168,7 +178,7 @@ func TestUpdates(t *testing.T) { } func TestUpdateColumn(t *testing.T) { - var users = []*User{ + users := []*User{ GetUser("update_column_01", Config{}), GetUser("update_column_02", Config{}), } @@ -466,7 +476,9 @@ func TestSelectWithUpdateColumn(t *testing.T) { var result2 User DB.First(&result2, user.ID) - AssertEqual(t, lastUpdatedAt, result2.UpdatedAt) + if lastUpdatedAt.Format(time.RFC3339Nano) == result2.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdatedAt should be changed") + } if result2.Name == user.Name || result2.Age != user.Age { t.Errorf("Should only update users with name column") @@ -604,11 +616,78 @@ func TestSave(t *testing.T) { t.Fatalf("failed to find updated user") } + user2 := *GetUser("save2", Config{}) + DB.Create(&user2) + + time.Sleep(time.Second) + user1UpdatedAt := result.UpdatedAt + user2UpdatedAt := user2.UpdatedAt + users := []*User{&result, &user2} + DB.Save(&users) + + if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) + } + + if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt) + } + + DB.First(&result) + if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed after reload, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) + } + + DB.First(&user2) + if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user2's updated at should be changed after reload, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt) + } + dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } + + dryDB = DB.Session(&gorm.Session{DryRun: true}) + stmt = dryDB.Unscoped().Save(&user).Statement + if !regexp.MustCompile(`WHERE .id. = [^ ]+$`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } + + user3 := *GetUser("save3", Config{}) + DB.Create(&user3) + + if err := DB.First(&User{}, "name = ?", "save3").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user3.Name = "save3_" + if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil { + t.Fatalf("failed to save user, got %v", err) + } + + var result2 User + if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { + t.Fatalf("failed to find updated user, got %v", err) + } + + if err := DB.Model(User{Model: user3.Model}).Save(&struct { + gorm.Model + Placeholder string + Name string + }{ + Model: user3.Model, + Placeholder: "placeholder", + Name: "save3__", + }).Error; err != nil { + t.Fatalf("failed to update user, got %v", err) + } + + var result3 User + if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { + t.Fatalf("failed to find updated user") + } } func TestSaveWithPrimaryValue(t *testing.T) { @@ -652,3 +731,35 @@ func TestSaveWithPrimaryValue(t *testing.T) { t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) } } + +// only sqlite, postgres support returning +func TestUpdateReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + users := []*User{ + GetUser("update-returning-1", Config{}), + GetUser("update-returning-2", Config{}), + GetUser("update-returning-3", Config{}), + } + DB.Create(&users) + + var results []User + DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88) + if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 { + t.Errorf("failed to return updated data, got %v", results) + } + + if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + + if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + + if results[1].Age-results[0].Age != 100 { + t.Errorf("failed to return updated age column") + } +} diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 0ba8b9f06..c5d196055 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -1,9 +1,11 @@ package tests_test import ( + "regexp" "testing" "time" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -51,6 +53,39 @@ func TestUpsert(t *testing.T) { if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { t.Fatalf("failed to upsert, got name %v", result.Name) } + + if name := DB.Dialector.Name(); name != "sqlserver" { + type RestrictedLanguage struct { + Code string `gorm:"primarykey"` + Name string + Lang string `gorm:"<-:create"` + } + + r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + + user := *GetUser("upsert_on_conflict", Config{}) + user.Age = 20 + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error %v", err) + } + + var user2 User + DB.First(&user2, user.ID) + user2.Age = 30 + time.Sleep(time.Second) + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil { + t.Fatalf("failed to onconflict create user, got error %v", err) + } else { + var user3 User + DB.First(&user3, user.ID) + if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() { + t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt) + } + } } func TestUpsertSlice(t *testing.T) { @@ -137,29 +172,29 @@ func TestUpsertWithSave(t *testing.T) { } } - // lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} - // if err := DB.Save(&lang).Error; err != nil { - // t.Errorf("Failed to create, got error %v", err) - // } - - // var result Language - // if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { - // t.Errorf("Failed to query lang, got error %v", err) - // } else { - // AssertEqual(t, result, lang) - // } - - // lang.Name += "_new" - // if err := DB.Save(&lang).Error; err != nil { - // t.Errorf("Failed to create, got error %v", err) - // } - - // var result2 Language - // if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { - // t.Errorf("Failed to query lang, got error %v", err) - // } else { - // AssertEqual(t, result2, lang) - // } + lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} + if err := DB.Save(&lang).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } + + lang.Name += "_new" + if err := DB.Save(&lang).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + var result2 Language + if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result2, lang) + } } func TestFindOrInitialize(t *testing.T) { @@ -274,3 +309,20 @@ func TestFindOrCreate(t *testing.T) { t.Errorf("belongs to association should be saved") } } + +func TestUpdateWithMissWhere(t *testing.T) { + type User struct { + ID uint `gorm:"column:id;<-:create"` + Name string `gorm:"column:name"` + } + user := User{ID: 1, Name: "king"} + tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user) + + if err := tx.Error; err != nil { + t.Fatalf("failed to update user,missing where condtion,err=%+v", err) + } + + if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", tx.Statement.SQL.String()) + } +} diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index b8452ef9a..9543f750a 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -7,8 +7,7 @@ import ( "gorm.io/gorm/schema" ) -type DummyDialector struct { -} +type DummyDialector struct{} func (DummyDialector) Name() string { return "dummy" @@ -31,9 +30,51 @@ func (DummyDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v in } func (DummyDialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('`') - writer.WriteString(str) - writer.WriteByte('`') + var ( + underQuoted, selfQuoted bool + continuousBacktick int8 + shiftDelimiter int8 + ) + + for _, v := range []byte(str) { + switch v { + case '`': + continuousBacktick++ + if continuousBacktick == 2 { + writer.WriteString("``") + continuousBacktick = 0 + } + case '.': + if continuousBacktick > 0 || !selfQuoted { + shiftDelimiter = 0 + underQuoted = false + continuousBacktick = 0 + writer.WriteString("`") + } + writer.WriteByte(v) + continue + default: + if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { + writer.WriteByte('`') + underQuoted = true + if selfQuoted = continuousBacktick > 0; selfQuoted { + continuousBacktick -= 1 + } + } + + for ; continuousBacktick > 0; continuousBacktick -= 1 { + writer.WriteString("``") + } + + writer.WriteByte(v) + } + shiftDelimiter++ + } + + if continuousBacktick > 0 && !selfQuoted { + writer.WriteString("``") + } + writer.WriteString("`") } func (DummyDialector) Explain(sql string, vars ...interface{}) string { diff --git a/utils/tests/models.go b/utils/tests/models.go index 2c5e71c0f..c84f9cae9 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,6 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) +// NamedPet is a reference to a Named `Pets` (has many) type User struct { gorm.Model Name string @@ -18,6 +19,7 @@ type User struct { Birthday *time.Time Account Account Pets []*Pet + NamedPet *Pet Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company @@ -58,3 +60,23 @@ type Language struct { Code string `gorm:"primarykey"` Name string } + +type Coupon struct { + ID int `gorm:"primarykey; size:255"` + AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` + AmountOff uint32 `gorm:"amount_off"` + PercentOff float32 `gorm:"percent_off"` +} + +type CouponProduct struct { + CouponId int `gorm:"primarykey;size:255"` + ProductId string `gorm:"primarykey;size:255"` + Desc string +} + +type Order struct { + gorm.Model + Num string + Coupon *Coupon + CouponID string +} diff --git a/utils/utils.go b/utils/utils.go index ecba7fb93..296917b93 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -15,17 +15,20 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) + // compatible solution to get gorm source directory with various operating systems gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") } +// FileWithLineNum return the file name and line number of the current file func FileWithLineNum() string { + // the second caller usually from gorm internal, so set i start from 2 for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { return file + ":" + strconv.FormatInt(int64(line), 10) } } + return "" } @@ -33,17 +36,14 @@ func IsValidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } -func CheckTruth(val interface{}) bool { - if v, ok := val.(bool); ok { - return v - } - - if v, ok := val.(string); ok { - v = strings.ToLower(v) - return v != "false" +// CheckTruth check string true or not +func CheckTruth(vals ...string) bool { + for _, val := range vals { + if val != "" && !strings.EqualFold(val, "false") { + return true + } } - - return !reflect.ValueOf(val).IsZero() + return false } func ToStringKey(values ...interface{}) string { @@ -69,6 +69,15 @@ func ToStringKey(values ...interface{}) string { return strings.Join(results, "_") } +func Contains(elems []string, elem string) bool { + for _, e := range elems { + if elem == e { + return true + } + } + return false +} + func AssertEqual(src, dst interface{}) bool { if !reflect.DeepEqual(src, dst) { if valuer, ok := src.(driver.Valuer); ok {