From 24fe381fd9bc8a4c81cfaf0173fd2ae2333defda Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Wed, 5 Jul 2017 20:03:11 -0400 Subject: [PATCH 01/10] Bulk Fetch / predicates + foreigns --- bulk_fetch.go | 109 +++++++++++++++++++++++++- logger.go | 83 ++++++++++++++++++++ model.go | 2 + pq_model.go | 213 +++++++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 393 insertions(+), 14 deletions(-) create mode 100644 logger.go diff --git a/bulk_fetch.go b/bulk_fetch.go index 6367576..135b187 100644 --- a/bulk_fetch.go +++ b/bulk_fetch.go @@ -1,6 +1,7 @@ package surf import ( + "strconv" "strings" ) @@ -12,12 +13,94 @@ const ( ORDER_BY_DESC ) +type PredicateType int + +const ( + WHERE_IS_NOT_NULL PredicateType = iota + WHERE_IS_NULL + WHERE_IN + WHERE_LIKE + WHERE_EQUAL + WHERE_NOT_EQUAL + WHERE_GREATER_THAN + WHERE_GREATER_THAN_OR_EQUAL_TO + WHERE_LESS_THAN + WHERE_LESS_THAN_OR_EQUAL_TO +) + // OrderBy is the definition of a single order by clause type OrderBy struct { Field string Type OrderByType } +// Predicate is the definition of a single where predicate +type Predicate struct { + Field string + PredicateType PredicateType + Values []interface{} +} + +func (p *Predicate) ToString(valueIndex int) (string, []interface{}) { + // Field + predicate := p.Field + + // Type + switch p.PredicateType { + case WHERE_IS_NOT_NULL: + predicate += " IS NOT NULL" + break + case WHERE_IS_NULL: + predicate += " IS NULL" + break + case WHERE_IN: + predicate += " IN " + break + case WHERE_LIKE: + predicate += " LIKE " + break + case WHERE_EQUAL: + predicate += " = " + break + case WHERE_NOT_EQUAL: + predicate += " != " + break + case WHERE_GREATER_THAN: + predicate += " > " + break + case WHERE_GREATER_THAN_OR_EQUAL_TO: + predicate += " >= " + break + case WHERE_LESS_THAN: + predicate += " < " + break + case WHERE_LESS_THAN_OR_EQUAL_TO: + predicate += " <= " + break + } + + // Values + values := make([]interface{}, 0) + if p.PredicateType != WHERE_IS_NOT_NULL && p.PredicateType != WHERE_IS_NULL { + if len(p.Values) > 1 { + predicate += "(" + } + for i, value := range p.Values { + values = append(values, value) + predicate += "$" + strconv.Itoa(valueIndex) + valueIndex++ + if i < len(p.Values)-1 { + predicate += ", " + } + } + if len(p.Values) > 1 { + predicate += ")" + } + } + + return predicate, values +} + // ToString converts an OrderBy to SQL func (ob *OrderBy) ToString() string { obType := "" @@ -30,11 +113,31 @@ func (ob *OrderBy) ToString() string { return ob.Field + obType } +func PredicatesToString(valueIndex int, predicates []Predicate) (string, []interface{}) { + values := make([]interface{}, 0) + + predicateStr := "" + if len(predicates) > 0 { + predicateStr += "WHERE " + } + for i, predicate := range predicates { + iPredicateStr, iValues := predicate.ToString(valueIndex) + valueIndex += len(iValues) + values = append(values, iValues...) + predicateStr += iPredicateStr + if i < (len(predicates) - 1) { + predicateStr += " AND " + } + } + return predicateStr, values +} + // BulkFetchConfig is the configuration of a Model.BulkFetch() type BulkFetchConfig struct { - Limit int - Offset int - OrderBys []OrderBy + Limit int + Offset int + OrderBys []OrderBy + Predicates []Predicate } // ConsumeSortQuery consumes a `sort` query parameter diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..f2435cd --- /dev/null +++ b/logger.go @@ -0,0 +1,83 @@ +package surf + +import ( + "fmt" + "gopkg.in/guregu/null.v3" + "io" + "strconv" + "strings" + "time" +) + +var ( + loggingEnabled = false + loggingWriter io.Writer +) + +func SetLogging(enabled bool, writer io.Writer) { + loggingEnabled = enabled + loggingWriter = writer +} + +func printQuery(query string, args ...interface{}) { + if loggingEnabled { + for i, arg := range args { + query = strings.Replace(query, "$"+strconv.Itoa(i+1), pointerToLogString(arg), 1) + } + fmt.Fprintln(loggingWriter, "[Surf Query]: "+query) + } +} + +func pointerToLogString(pointer interface{}) string { + switch v := pointer.(type) { + case *null.Int: + if v.Valid { + return strconv.FormatInt(v.Int64, 10) + } + break + case *null.String: + if v.Valid { + return "'" + v.String + "'" + } + break + case *null.Bool: + if v.Valid { + return strconv.FormatBool(v.Bool) + } + break + case *null.Float: + if v.Valid { + return strconv.FormatFloat(v.Float64, 'f', -1, 64) + } + break + case *null.Time: + if v.Valid { + return "'" + v.Time.Format(time.RFC3339) + "'" + } + break + case *string: + return "'" + *v + "'" + case *float32: + case *float64: + return strconv.FormatFloat(float64(*v), 'f', -1, 64) + case *bool: + return strconv.FormatBool(*v) + case *int: + case *int8: + case *int16: + case *int32: + case *int64: + return strconv.FormatInt(int64(*v), 10) + case *uint: + case *uint8: + case *uint16: + case *uint32: + case *uint64: + return strconv.FormatUint(uint64(*v), 10) + case *time.Time: + return "'" + (*v).Format(time.RFC3339) + "'" + default: + return fmt.Sprintf("%v", v) + } + return "null" +} diff --git a/model.go b/model.go index 01e7bc1..d0fc028 100644 --- a/model.go +++ b/model.go @@ -25,6 +25,8 @@ type Field struct { Updatable bool UniqueIdentifier bool SkipValidation bool + GetReference func() (BuildModel, string) + SetReference func(Model) error IsSet func(interface{}) bool } diff --git a/pq_model.go b/pq_model.go index 8739ae7..85d42b9 100644 --- a/pq_model.go +++ b/pq_model.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "gopkg.in/guregu/null.v3" "strconv" ) @@ -23,8 +24,7 @@ func (w *PqModel) GetConfiguration() *Configuration { func (w *PqModel) Insert() error { // Get Insertable Fields var insertableFields []Field - allFields := w.Config.Fields - for _, field := range allFields { + for _, field := range w.Config.Fields { if field.Insertable { insertableFields = append(insertableFields, field) } @@ -57,9 +57,19 @@ func (w *PqModel) Insert() error { valueFields = append(valueFields, value.Pointer) } + // Log Query + query := queryBuffer.String() + printQuery(query, valueFields...) + // Execute Query - row := w.Database.QueryRow(queryBuffer.String(), valueFields...) - return w.ConsumeRow(row) + row := w.Database.QueryRow(query, valueFields...) + err := w.ConsumeRow(row) + if err != nil { + return err + } + + // Expand foreign references + return w.expandForeign() } // Load loads the model from the database from its unique identifier @@ -86,9 +96,19 @@ func (w *PqModel) Load() error { queryBuffer.WriteString(uniqueIdentifierField.Name) queryBuffer.WriteString("=$1;") + // Log Query + query := queryBuffer.String() + printQuery(query, uniqueIdentifierField.Pointer) + // Execute Query - row := w.Database.QueryRow(queryBuffer.String(), uniqueIdentifierField.Pointer) - return w.ConsumeRow(row) + row := w.Database.QueryRow(query, uniqueIdentifierField.Pointer) + err = w.ConsumeRow(row) + if err != nil { + return err + } + + // Expand foreign references + return w.expandForeign() } // Update updates the model with the current values in the struct @@ -133,9 +153,19 @@ func (w *PqModel) Update() error { } valueFields = append(valueFields, uniqueIdentifierField.Pointer) + // Log Query + query := queryBuffer.String() + printQuery(query, valueFields...) + // Execute Query - row := w.Database.QueryRow(queryBuffer.String(), valueFields...) - return w.ConsumeRow(row) + row := w.Database.QueryRow(query, valueFields...) + err = w.ConsumeRow(row) + if err != nil { + return err + } + + // Expand foreign references + return w.expandForeign() } // Delete deletes the model @@ -154,6 +184,10 @@ func (w PqModel) Delete() error { queryBuffer.WriteString(uniqueIdentifierField.Name) queryBuffer.WriteString("=$1;") + // Log Query + query := queryBuffer.String() + printQuery(query, uniqueIdentifierField.Pointer) + // Execute Query res, err := w.Database.Exec(queryBuffer.String(), uniqueIdentifierField.Pointer) if err != nil { @@ -204,7 +238,10 @@ func (w *PqModel) getUniqueIdentifier() (Field, error) { // BulkFetch gets an array of models func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) ([]Model, error) { - // Generate Query + // Set up values + values := make([]interface{}, 0) + + // Generate query var queryBuffer bytes.Buffer queryBuffer.WriteString("SELECT ") for i, field := range w.Config.Fields { @@ -215,7 +252,16 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) } queryBuffer.WriteString(" FROM ") queryBuffer.WriteString(buildModel().GetConfiguration().TableName) - queryBuffer.WriteString(" ORDER BY ") + if len(fetchConfig.Predicates) > 0 { + // WHERE + queryBuffer.WriteString(" ") + predicatesStr, predicateValues := PredicatesToString(1, fetchConfig.Predicates) + values = append(values, predicateValues...) + queryBuffer.WriteString(predicatesStr) + } + if len(fetchConfig.OrderBys) > 0 { + queryBuffer.WriteString(" ORDER BY ") + } for i, orderBy := range fetchConfig.OrderBys { // Validate that the orderBy.Field is a field valid := false @@ -241,8 +287,12 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) queryBuffer.WriteString(strconv.Itoa(fetchConfig.Offset)) queryBuffer.WriteString(";") + // Log Query + query := queryBuffer.String() + printQuery(query, values...) + // Execute Query - rows, err := w.Database.Query(queryBuffer.String()) + rows, err := w.Database.Query(query, values...) if err != nil { return nil, err } @@ -266,10 +316,151 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) models = append(models, model.(Model)) } + // Expand all foreign references + for _, field := range buildModel().GetConfiguration().Fields { + // If the field is a foreign key + if field.GetReference != nil && field.SetReference != nil { + builder, foreignField := field.GetReference() + err = expandForeigns(field.Name, builder, foreignField, models) + if err != nil { + return nil, err + } + } + } + // OK return models, nil } +// expandForeigns expands a single foreign key for an array of Model +func expandForeigns(fieldName string, foreignBuilder BuildModel, foreignField string, models []Model) error { + // Get Foreign IDs + ids := make([]interface{}, 0) + for _, model := range models { + for _, field := range model.GetConfiguration().Fields { + if field.Name == fieldName { + switch tv := field.Pointer.(type) { + case *int64: + ids = appendIfMissing(ids, *tv) + break + case *null.Int: + if tv.Valid { + ids = appendIfMissing(ids, tv.Int64) + } + break + } + } + } + } + + // Load Foreign models + foreignModels, err := foreignBuilder().BulkFetch( + BulkFetchConfig{ + Limit: len(ids), + Predicates: []Predicate{{ + Field: foreignField, + PredicateType: WHERE_IN, + Values: ids, + }}, + }, + foreignBuilder, + ) + if err != nil { + return err + } + + // Stuff foreign models into models + for _, model := range models { + for _, field := range model.GetConfiguration().Fields { + if field.Name == fieldName { + var toMatch int64 + switch tv := field.Pointer.(type) { + case *int64: + toMatch = *tv + break + case *null.Int: + toMatch = tv.Int64 + break + } + + MatchForeignModel: + for _, foreignModel := range foreignModels { + FindField: + for _, foreignModelField := range foreignModel.GetConfiguration().Fields { + if foreignModelField.Name == foreignField { + + if *(foreignModelField.Pointer.(*int64)) == toMatch { + field.SetReference(foreignModel) + break MatchForeignModel + } + break FindField + } + } + } + break + } + } + } + return nil +} + +// appendIfMissing functions like append(), but will only add the +// int64 to the slice if it doesn't exist in the slice already +func appendIfMissing(slice []interface{}, i int64) []interface{} { + for _, ele := range slice { + if ele == i { + return slice + } + } + return append(slice, i) +} + +// expandForeign expands all foreign references for a single model +func (w *PqModel) expandForeign() error { + // Load all foreign references + for _, field := range w.Config.Fields { + + // If it's a set foreign reference + if field.GetReference != nil && field.SetReference != nil && field.IsSet(field.Pointer) { + + // Get the reference type + modelBuilder, identifier := field.GetReference() + model := modelBuilder() + + // Set the identifier on the foreign reference + // The foreign reference value may only be a `null.Int` or an `int64` + // The identifier on the foreign model may only be of type `int64` + for _, modelField := range model.GetConfiguration().Fields { + if modelField.Name == identifier { + switch tv := field.Pointer.(type) { + case *int64: + *(modelField.Pointer.(*int64)) = *tv + break + case *null.Int: + *(modelField.Pointer.(*int64)) = tv.Int64 + break + } + break + } + } + + // Load + err := model.Load() + if err != nil { + return err + } + + // Set reference + err = field.SetReference(model) + if err != nil { + return err + } + } + } + + return nil +} + // ConsumeRow Scans a *sql.Row into our struct // that is using this model func (w *PqModel) ConsumeRow(row *sql.Row) error { From 425fabbc7d64a4d9d7ec28d0dfc3ec3a3aaf805b Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Mon, 10 Jul 2017 18:00:52 -0400 Subject: [PATCH 02/10] File re-org, better comments --- bulk_fetch.go | 161 --------------------------------- bulk_fetch_config.go | 33 +++++++ logger.go | 58 +++++++----- model_helpers.go | 199 ++++++++++++++++++++++++++++++++++++++++ order_by.go | 27 ++++++ pq_model.go | 210 +++++-------------------------------------- predicate.go | 162 +++++++++++++++++++++++++++++++++ 7 files changed, 479 insertions(+), 371 deletions(-) delete mode 100644 bulk_fetch.go create mode 100644 bulk_fetch_config.go create mode 100644 model_helpers.go create mode 100644 order_by.go create mode 100644 predicate.go diff --git a/bulk_fetch.go b/bulk_fetch.go deleted file mode 100644 index 135b187..0000000 --- a/bulk_fetch.go +++ /dev/null @@ -1,161 +0,0 @@ -package surf - -import ( - "strconv" - "strings" -) - -// OrderByType is an enumeration of the SQL standard order by -type OrderByType int - -const ( - ORDER_BY_ASC OrderByType = iota - ORDER_BY_DESC -) - -type PredicateType int - -const ( - WHERE_IS_NOT_NULL PredicateType = iota - WHERE_IS_NULL - WHERE_IN - WHERE_LIKE - WHERE_EQUAL - WHERE_NOT_EQUAL - WHERE_GREATER_THAN - WHERE_GREATER_THAN_OR_EQUAL_TO - WHERE_LESS_THAN - WHERE_LESS_THAN_OR_EQUAL_TO -) - -// OrderBy is the definition of a single order by clause -type OrderBy struct { - Field string - Type OrderByType -} - -// Predicate is the definition of a single where predicate -type Predicate struct { - Field string - PredicateType PredicateType - Values []interface{} -} - -func (p *Predicate) ToString(valueIndex int) (string, []interface{}) { - // Field - predicate := p.Field - - // Type - switch p.PredicateType { - case WHERE_IS_NOT_NULL: - predicate += " IS NOT NULL" - break - case WHERE_IS_NULL: - predicate += " IS NULL" - break - case WHERE_IN: - predicate += " IN " - break - case WHERE_LIKE: - predicate += " LIKE " - break - case WHERE_EQUAL: - predicate += " = " - break - case WHERE_NOT_EQUAL: - predicate += " != " - break - case WHERE_GREATER_THAN: - predicate += " > " - break - case WHERE_GREATER_THAN_OR_EQUAL_TO: - predicate += " >= " - break - case WHERE_LESS_THAN: - predicate += " < " - break - case WHERE_LESS_THAN_OR_EQUAL_TO: - predicate += " <= " - break - } - - // Values - values := make([]interface{}, 0) - if p.PredicateType != WHERE_IS_NOT_NULL && p.PredicateType != WHERE_IS_NULL { - if len(p.Values) > 1 { - predicate += "(" - } - for i, value := range p.Values { - values = append(values, value) - predicate += "$" + strconv.Itoa(valueIndex) - valueIndex++ - if i < len(p.Values)-1 { - predicate += ", " - } - } - if len(p.Values) > 1 { - predicate += ")" - } - } - - return predicate, values -} - -// ToString converts an OrderBy to SQL -func (ob *OrderBy) ToString() string { - obType := "" - switch ob.Type { - case ORDER_BY_ASC: - obType = " ASC" - case ORDER_BY_DESC: - obType = " DESC" - } - return ob.Field + obType -} - -func PredicatesToString(valueIndex int, predicates []Predicate) (string, []interface{}) { - values := make([]interface{}, 0) - - predicateStr := "" - if len(predicates) > 0 { - predicateStr += "WHERE " - } - for i, predicate := range predicates { - iPredicateStr, iValues := predicate.ToString(valueIndex) - valueIndex += len(iValues) - values = append(values, iValues...) - predicateStr += iPredicateStr - if i < (len(predicates) - 1) { - predicateStr += " AND " - } - } - return predicateStr, values -} - -// BulkFetchConfig is the configuration of a Model.BulkFetch() -type BulkFetchConfig struct { - Limit int - Offset int - OrderBys []OrderBy - Predicates []Predicate -} - -// ConsumeSortQuery consumes a `sort` query parameter -// and stuffs them into the OrderBys field -func (c *BulkFetchConfig) ConsumeSortQuery(sortQuery string) { - var orderBys []OrderBy - for _, sort := range strings.Split(sortQuery, ",") { - if string(sort[0]) == "-" { - orderBys = append(orderBys, OrderBy{ - Field: sort[1:], - Type: ORDER_BY_DESC, - }) - } else { - orderBys = append(orderBys, OrderBy{ - Field: sort, - Type: ORDER_BY_ASC, - }) - } - } - c.OrderBys = orderBys -} diff --git a/bulk_fetch_config.go b/bulk_fetch_config.go new file mode 100644 index 0000000..c3220ad --- /dev/null +++ b/bulk_fetch_config.go @@ -0,0 +1,33 @@ +package surf + +import ( + "strings" +) + +// BulkFetchConfig is the configuration of a Model.BulkFetch() +type BulkFetchConfig struct { + Limit int + Offset int + OrderBys []OrderBy + Predicates []Predicate +} + +// ConsumeSortQuery consumes a `sort` query parameter +// and stuffs them into the OrderBys field +func (c *BulkFetchConfig) ConsumeSortQuery(sortQuery string) { + var orderBys []OrderBy + for _, sort := range strings.Split(sortQuery, ",") { + if string(sort[0]) == "-" { + orderBys = append(orderBys, OrderBy{ + Field: sort[1:], + Type: ORDER_BY_DESC, + }) + } else { + orderBys = append(orderBys, OrderBy{ + Field: sort, + Type: ORDER_BY_ASC, + }) + } + } + c.OrderBys = orderBys +} diff --git a/logger.go b/logger.go index f2435cd..b026644 100644 --- a/logger.go +++ b/logger.go @@ -14,11 +14,16 @@ var ( loggingWriter io.Writer ) +// SetLogging adjusts the configuration for logging. You can enable +// and disable the logging here. By default, logging is disabled. +// +// Most calls to this function will be called like SetLogging(true, os.Stdout) func SetLogging(enabled bool, writer io.Writer) { loggingEnabled = enabled loggingWriter = writer } +// printQuery prints a query if the user has enabled logging func printQuery(query string, args ...interface{}) { if loggingEnabled { for i, arg := range args { @@ -28,8 +33,40 @@ func printQuery(query string, args ...interface{}) { } } +// pointerToLogString converts a value pointer to the string +// that should be logged for it func pointerToLogString(pointer interface{}) string { switch v := pointer.(type) { + case *string: + return "'" + *v + "'" + case *float32: + return strconv.FormatFloat(float64(*v), 'f', -1, 32) + case *float64: + return strconv.FormatFloat(*v, 'f', -1, 64) + case *bool: + return strconv.FormatBool(*v) + case *int: + return strconv.Itoa(*v) + case *int8: + return strconv.FormatInt(int64(*v), 10) + case *int16: + return strconv.FormatInt(int64(*v), 10) + case *int32: + return strconv.FormatInt(int64(*v), 10) + case *int64: + return strconv.FormatInt(*v, 10) + case *uint: + return strconv.FormatUint(uint64(*v), 10) + case *uint8: + return strconv.FormatUint(uint64(*v), 10) + case *uint16: + return strconv.FormatUint(uint64(*v), 10) + case *uint32: + return strconv.FormatUint(uint64(*v), 10) + case *uint64: + return strconv.FormatUint(*v, 10) + case *time.Time: + return "'" + (*v).Format(time.RFC3339) + "'" case *null.Int: if v.Valid { return strconv.FormatInt(v.Int64, 10) @@ -55,27 +92,6 @@ func pointerToLogString(pointer interface{}) string { return "'" + v.Time.Format(time.RFC3339) + "'" } break - case *string: - return "'" + *v + "'" - case *float32: - case *float64: - return strconv.FormatFloat(float64(*v), 'f', -1, 64) - case *bool: - return strconv.FormatBool(*v) - case *int: - case *int8: - case *int16: - case *int32: - case *int64: - return strconv.FormatInt(int64(*v), 10) - case *uint: - case *uint8: - case *uint16: - case *uint32: - case *uint64: - return strconv.FormatUint(uint64(*v), 10) - case *time.Time: - return "'" + (*v).Format(time.RFC3339) + "'" default: return fmt.Sprintf("%v", v) } diff --git a/model_helpers.go b/model_helpers.go new file mode 100644 index 0000000..d7ce459 --- /dev/null +++ b/model_helpers.go @@ -0,0 +1,199 @@ +package surf + +import ( + "database/sql" + "errors" + "fmt" + "gopkg.in/guregu/null.v3" +) + +// getUniqueIdentifier Returns the unique identifier that this model will +// query against. +// +// This will return the first field in Configuration.Fields that: +// - Has `UniqueIdentifier` set to true +// - Returns true from `IsSet` +// +// This function will panic in the event that it encounters a field that is a +// `UniqueIdentifier`, and doesn't have `IsSet` implemented. +func getUniqueIdentifier(w Model) (Field, error) { + // Get all unique identifier fields + var uniqueIdentifierFields []Field + for _, field := range w.GetConfiguration().Fields { + if field.UniqueIdentifier { + uniqueIdentifierFields = append(uniqueIdentifierFields, field) + } + } + + // Determine which unique identifier we will be querying with + var uniqueIdentifierField Field + for _, field := range uniqueIdentifierFields { + if field.IsSet == nil { + panic(fmt.Sprintf("Field `%v` must implement IsSet, as it is a `UniqueIdentifier`", field.Name)) + } else if field.IsSet(field.Pointer) { + uniqueIdentifierField = field + break + } + } + + // Return + if uniqueIdentifierField.Pointer == nil { + return uniqueIdentifierField, errors.New("There is no UniqueIdentifier Field that is set") + } + return uniqueIdentifierField, nil +} + +// expandForeign expands all foreign references for a single model +func expandForeign(model Model) error { + // Load all foreign references + for _, field := range model.GetConfiguration().Fields { + + // If it's a set foreign reference + if field.GetReference != nil && field.SetReference != nil && field.IsSet(field.Pointer) { + + // Get the reference type + modelBuilder, identifier := field.GetReference() + model := modelBuilder() + + // Set the identifier on the foreign reference + // The foreign reference value may only be a `null.Int` or an `int64` + // The identifier on the foreign model may only be of type `int64` + for _, modelField := range model.GetConfiguration().Fields { + if modelField.Name == identifier { + switch tv := field.Pointer.(type) { + case *int64: + *(modelField.Pointer.(*int64)) = *tv + break + case *null.Int: + *(modelField.Pointer.(*int64)) = tv.Int64 + break + } + break + } + } + + // Load + err := model.Load() + if err != nil { + return err + } + + // Set reference + err = field.SetReference(model) + if err != nil { + return err + } + } + } + + return nil +} + +func expandAllForeigns(modelBuilder BuildModel, models []Model) error { + // Expand all foreign references + for _, field := range modelBuilder().GetConfiguration().Fields { + // If the field is a foreign key + if field.GetReference != nil && field.SetReference != nil { + builder, foreignField := field.GetReference() + err := expandForeigns(field.Name, builder, foreignField, models) + if err != nil { + return err + } + } + } + return nil +} + +// expandForeigns expands a single foreign key for an array of Model +func expandForeigns(fieldName string, foreignBuilder BuildModel, foreignField string, models []Model) error { + // Get Foreign IDs + ids := make([]interface{}, 0) + for _, model := range models { + for _, field := range model.GetConfiguration().Fields { + if field.Name == fieldName { + switch tv := field.Pointer.(type) { + case *int64: + ids = appendIfMissing(ids, *tv) + break + case *null.Int: + if tv.Valid { + ids = appendIfMissing(ids, tv.Int64) + } + break + } + } + } + } + + // Load Foreign models + foreignModels, err := foreignBuilder().BulkFetch( + BulkFetchConfig{ + Limit: len(ids), + Predicates: []Predicate{{ + Field: foreignField, + PredicateType: WHERE_IN, + Values: ids, + }}, + }, + foreignBuilder, + ) + if err != nil { + return err + } + + // Stuff foreign models into models + for _, model := range models { + for _, field := range model.GetConfiguration().Fields { + if field.Name == fieldName { + var toMatch int64 + switch tv := field.Pointer.(type) { + case *int64: + toMatch = *tv + break + case *null.Int: + toMatch = tv.Int64 + break + } + + MatchForeignModel: + for _, foreignModel := range foreignModels { + FindField: + for _, foreignModelField := range foreignModel.GetConfiguration().Fields { + if foreignModelField.Name == foreignField { + + if *(foreignModelField.Pointer.(*int64)) == toMatch { + field.SetReference(foreignModel) + break MatchForeignModel + } + break FindField + } + } + } + break + } + } + } + return nil +} + +// appendIfMissing functions like append(), but will only add the +// int64 to the slice if it doesn't exist in the slice already +func appendIfMissing(slice []interface{}, i int64) []interface{} { + for _, ele := range slice { + if ele == i { + return slice + } + } + return append(slice, i) +} + +// consumeRow Scans a *sql.Row into our struct +// that is using this model +func consumeRow(w Model, row *sql.Row) error { + fields := w.GetConfiguration().Fields + var s []interface{} + for _, value := range fields { + s = append(s, value.Pointer) + } + return row.Scan(s...) +} diff --git a/order_by.go b/order_by.go new file mode 100644 index 0000000..059166e --- /dev/null +++ b/order_by.go @@ -0,0 +1,27 @@ +package surf + +// OrderByType is an enumeration of the SQL standard order by +type OrderByType int + +const ( + ORDER_BY_ASC OrderByType = iota + ORDER_BY_DESC +) + +// OrderBy is the definition of a single order by clause +type OrderBy struct { + Field string + Type OrderByType +} + +// ToString converts an OrderBy to SQL +func (ob *OrderBy) toString() string { + obType := "" + switch ob.Type { + case ORDER_BY_ASC: + obType = " ASC" + case ORDER_BY_DESC: + obType = " DESC" + } + return ob.Field + obType +} diff --git a/pq_model.go b/pq_model.go index 85d42b9..7a5e681 100644 --- a/pq_model.go +++ b/pq_model.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "fmt" - "gopkg.in/guregu/null.v3" "strconv" ) @@ -63,20 +62,20 @@ func (w *PqModel) Insert() error { // Execute Query row := w.Database.QueryRow(query, valueFields...) - err := w.ConsumeRow(row) + err := consumeRow(w, row) if err != nil { return err } // Expand foreign references - return w.expandForeign() + return expandForeign(w) } // Load loads the model from the database from its unique identifier // and then loads those values into the struct func (w *PqModel) Load() error { // Get Unique Identifier - uniqueIdentifierField, err := w.getUniqueIdentifier() + uniqueIdentifierField, err := getUniqueIdentifier(w) if err != nil { return err } @@ -102,19 +101,19 @@ func (w *PqModel) Load() error { // Execute Query row := w.Database.QueryRow(query, uniqueIdentifierField.Pointer) - err = w.ConsumeRow(row) + err = consumeRow(w, row) if err != nil { return err } // Expand foreign references - return w.expandForeign() + return expandForeign(w) } // Update updates the model with the current values in the struct func (w *PqModel) Update() error { // Get Unique Identifier - uniqueIdentifierField, err := w.getUniqueIdentifier() + uniqueIdentifierField, err := getUniqueIdentifier(w) if err != nil { return err } @@ -159,19 +158,19 @@ func (w *PqModel) Update() error { // Execute Query row := w.Database.QueryRow(query, valueFields...) - err = w.ConsumeRow(row) + err = consumeRow(w, row) if err != nil { return err } // Expand foreign references - return w.expandForeign() + return expandForeign(w) } // Delete deletes the model -func (w PqModel) Delete() error { +func (w *PqModel) Delete() error { // Get Unique Identifier - uniqueIdentifierField, err := w.getUniqueIdentifier() + uniqueIdentifierField, err := getUniqueIdentifier(w) if err != nil { return err } @@ -200,42 +199,6 @@ func (w PqModel) Delete() error { return nil } -// getUniqueIdentifier Returns the unique identifier that this model will -// query against. -// -// This will return the first field in Configuration.Fields that: -// - Has `UniqueIdentifier` set to true -// - Returns true from `IsSet` -// -// This function will panic in the event that it encounters a field that is a -// `UniqueIdentifier`, and doesn't have `IsSet` implemented. -func (w *PqModel) getUniqueIdentifier() (Field, error) { - // Get all unique identifier fields - var uniqueIdentifierFields []Field - for _, field := range w.Config.Fields { - if field.UniqueIdentifier { - uniqueIdentifierFields = append(uniqueIdentifierFields, field) - } - } - - // Determine which unique identifier we will be querying with - var uniqueIdentifierField Field - for _, field := range uniqueIdentifierFields { - if field.IsSet == nil { - panic(fmt.Sprintf("Field `%v` must implement IsSet, as it is a `UniqueIdentifier`", field.Name)) - } else if field.IsSet(field.Pointer) { - uniqueIdentifierField = field - break - } - } - - // Return - if uniqueIdentifierField.Pointer == nil { - return uniqueIdentifierField, errors.New("There is no UniqueIdentifier Field that is set") - } - return uniqueIdentifierField, nil -} - // BulkFetch gets an array of models func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) ([]Model, error) { // Set up values @@ -255,7 +218,8 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) if len(fetchConfig.Predicates) > 0 { // WHERE queryBuffer.WriteString(" ") - predicatesStr, predicateValues := PredicatesToString(1, fetchConfig.Predicates) + predicatesStr, predicateValues := predicatesToString(1, fetchConfig.Predicates) + values = append(values, predicateValues...) queryBuffer.WriteString(predicatesStr) } @@ -276,7 +240,7 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) w.Config.TableName, orderBy.Field) } // Write to query - queryBuffer.WriteString(orderBy.ToString()) + queryBuffer.WriteString(orderBy.toString()) if (i + 1) < len(fetchConfig.OrderBys) { queryBuffer.WriteString(", ") } @@ -316,6 +280,14 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) models = append(models, model.(Model)) } + // Expand + /* + err = expandAllForeigns(buildModel, models) + if err != nil { + return nil, err + } + */ + // Expand all foreign references for _, field := range buildModel().GetConfiguration().Fields { // If the field is a foreign key @@ -331,143 +303,3 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) // OK return models, nil } - -// expandForeigns expands a single foreign key for an array of Model -func expandForeigns(fieldName string, foreignBuilder BuildModel, foreignField string, models []Model) error { - // Get Foreign IDs - ids := make([]interface{}, 0) - for _, model := range models { - for _, field := range model.GetConfiguration().Fields { - if field.Name == fieldName { - switch tv := field.Pointer.(type) { - case *int64: - ids = appendIfMissing(ids, *tv) - break - case *null.Int: - if tv.Valid { - ids = appendIfMissing(ids, tv.Int64) - } - break - } - } - } - } - - // Load Foreign models - foreignModels, err := foreignBuilder().BulkFetch( - BulkFetchConfig{ - Limit: len(ids), - Predicates: []Predicate{{ - Field: foreignField, - PredicateType: WHERE_IN, - Values: ids, - }}, - }, - foreignBuilder, - ) - if err != nil { - return err - } - - // Stuff foreign models into models - for _, model := range models { - for _, field := range model.GetConfiguration().Fields { - if field.Name == fieldName { - var toMatch int64 - switch tv := field.Pointer.(type) { - case *int64: - toMatch = *tv - break - case *null.Int: - toMatch = tv.Int64 - break - } - - MatchForeignModel: - for _, foreignModel := range foreignModels { - FindField: - for _, foreignModelField := range foreignModel.GetConfiguration().Fields { - if foreignModelField.Name == foreignField { - - if *(foreignModelField.Pointer.(*int64)) == toMatch { - field.SetReference(foreignModel) - break MatchForeignModel - } - break FindField - } - } - } - break - } - } - } - return nil -} - -// appendIfMissing functions like append(), but will only add the -// int64 to the slice if it doesn't exist in the slice already -func appendIfMissing(slice []interface{}, i int64) []interface{} { - for _, ele := range slice { - if ele == i { - return slice - } - } - return append(slice, i) -} - -// expandForeign expands all foreign references for a single model -func (w *PqModel) expandForeign() error { - // Load all foreign references - for _, field := range w.Config.Fields { - - // If it's a set foreign reference - if field.GetReference != nil && field.SetReference != nil && field.IsSet(field.Pointer) { - - // Get the reference type - modelBuilder, identifier := field.GetReference() - model := modelBuilder() - - // Set the identifier on the foreign reference - // The foreign reference value may only be a `null.Int` or an `int64` - // The identifier on the foreign model may only be of type `int64` - for _, modelField := range model.GetConfiguration().Fields { - if modelField.Name == identifier { - switch tv := field.Pointer.(type) { - case *int64: - *(modelField.Pointer.(*int64)) = *tv - break - case *null.Int: - *(modelField.Pointer.(*int64)) = tv.Int64 - break - } - break - } - } - - // Load - err := model.Load() - if err != nil { - return err - } - - // Set reference - err = field.SetReference(model) - if err != nil { - return err - } - } - } - - return nil -} - -// ConsumeRow Scans a *sql.Row into our struct -// that is using this model -func (w *PqModel) ConsumeRow(row *sql.Row) error { - fields := w.Config.Fields - var s []interface{} - for _, value := range fields { - s = append(s, value.Pointer) - } - return row.Scan(s...) -} diff --git a/predicate.go b/predicate.go new file mode 100644 index 0000000..4340704 --- /dev/null +++ b/predicate.go @@ -0,0 +1,162 @@ +package surf + +import ( + "strconv" +) + +type PredicateType int + +const ( + WHERE_IS_NOT_NULL PredicateType = iota + WHERE_IS_NULL + WHERE_IN + WHERE_NOT_IN + WHERE_LIKE + WHERE_EQUAL + WHERE_NOT_EQUAL + WHERE_GREATER_THAN + WHERE_GREATER_THAN_OR_EQUAL_TO + WHERE_LESS_THAN + WHERE_LESS_THAN_OR_EQUAL_TO +) + +// getPredicateTypeString returns the predicate type string from it's value +func getPredicateTypeString(predicateType PredicateType) string { + switch predicateType { + case WHERE_IS_NOT_NULL: + return "WHERE_IS_NOT_NULL" + case WHERE_IS_NULL: + return "WHERE_IS_NULL" + case WHERE_IN: + return "WHERE_IN" + case WHERE_NOT_IN: + return "WHERE_NOT_IN" + case WHERE_LIKE: + return "WHERE_LIKE" + case WHERE_EQUAL: + return "WHERE_EQUAL" + case WHERE_NOT_EQUAL: + return "WHERE_NOT_EQUAL" + case WHERE_GREATER_THAN: + return "WHERE_GREATER_THAN" + case WHERE_GREATER_THAN_OR_EQUAL_TO: + return "WHERE_GREATER_THAN_OR_EQUAL_TO" + case WHERE_LESS_THAN: + return "WHERE_LESS_THAN" + case WHERE_LESS_THAN_OR_EQUAL_TO: + return "WHERE_LESS_THAN_OR_EQUAL_TO" + } + return "" +} + +// Predicate is the definition of a single where SQL predicate +type Predicate struct { + Field string + PredicateType PredicateType + Values []interface{} +} + +// toString will convert a predicate to it's query string, along with its values +// to be passed along with the query +// +// This function will panic in the event that this is called on a malformed predicate +func (p *Predicate) toString(valueIndex int) (string, []interface{}) { + // Field + predicate := p.Field + + // Type + switch p.PredicateType { + case WHERE_IS_NOT_NULL: + predicate += " IS NOT NULL" + break + case WHERE_IS_NULL: + predicate += " IS NULL" + break + case WHERE_IN: + predicate += " IN " + break + case WHERE_NOT_IN: + predicate += " NOT IN " + break + case WHERE_LIKE: + predicate += " LIKE " + break + case WHERE_EQUAL: + predicate += " = " + break + case WHERE_NOT_EQUAL: + predicate += " != " + break + case WHERE_GREATER_THAN: + predicate += " > " + break + case WHERE_GREATER_THAN_OR_EQUAL_TO: + predicate += " >= " + break + case WHERE_LESS_THAN: + predicate += " < " + break + case WHERE_LESS_THAN_OR_EQUAL_TO: + predicate += " <= " + break + } + + // Values + values := make([]interface{}, 0) + switch p.PredicateType { + case WHERE_IN, + WHERE_NOT_IN: + if len(p.Values) == 0 { + panic("`" + getPredicateTypeString(p.PredicateType) + "` predicates require at least one value.") + } + predicate += "(" + for i, value := range p.Values { + values = append(values, value) + predicate += "$" + strconv.Itoa(valueIndex) + valueIndex++ + if i < len(p.Values)-1 { + predicate += ", " + } + } + predicate += ")" + break + case WHERE_LIKE, + WHERE_EQUAL, + WHERE_NOT_EQUAL, + WHERE_GREATER_THAN, + WHERE_GREATER_THAN_OR_EQUAL_TO, + WHERE_LESS_THAN, + WHERE_LESS_THAN_OR_EQUAL_TO: + if len(p.Values) != 1 { + panic("`" + getPredicateTypeString(p.PredicateType) + "` predicates require exactly one value.") + } + values = append(values, p.Values[0]) + predicate += "$" + strconv.Itoa(valueIndex) + break + } + + return predicate, values +} + +// predicatesToString converts an array of predicates to a query string, along with its values +// to be passed along with the query +// +// This function will panic in the event that it encounters a malformed predicate +func predicatesToString(valueIndex int, predicates []Predicate) (string, []interface{}) { + values := make([]interface{}, 0) + + predicateStr := "" + if len(predicates) > 0 { + predicateStr += "WHERE " + } + for i, predicate := range predicates { + iPredicateStr, iValues := predicate.toString(valueIndex) + valueIndex += len(iValues) + values = append(values, iValues...) + predicateStr += iPredicateStr + if i < (len(predicates) - 1) { + predicateStr += " AND " + } + } + return predicateStr, values +} From 8fb0e5ba74e5c4bcbb5f1a996e350e81e7c97388 Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Mon, 10 Jul 2017 18:06:19 -0400 Subject: [PATCH 03/10] Add golang.org/x/tools/cmd/cover For some reason, travis's go 1.4 doesn't have this... --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index bb952ec..ce85277 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,7 @@ go: services: - postgresql before_script: + - go get golang.org/x/tools/cmd/cover - | psql \ --command='create database travis_ci_test;' \ From c8f3927feb092873ea7288408740fae71fd9f93f Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Mon, 10 Jul 2017 18:13:38 -0400 Subject: [PATCH 04/10] Update expandForeign to expand all keys Effectively just removing the loop from pq_model --- model_helpers.go | 11 ++++++----- pq_model.go | 22 ++++------------------ 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/model_helpers.go b/model_helpers.go index d7ce459..0462a62 100644 --- a/model_helpers.go +++ b/model_helpers.go @@ -43,7 +43,7 @@ func getUniqueIdentifier(w Model) (Field, error) { return uniqueIdentifierField, nil } -// expandForeign expands all foreign references for a single model +// expandForeign expands all foreign references for a single Model func expandForeign(model Model) error { // Load all foreign references for _, field := range model.GetConfiguration().Fields { @@ -89,13 +89,14 @@ func expandForeign(model Model) error { return nil } -func expandAllForeigns(modelBuilder BuildModel, models []Model) error { +// expandForeigns expands all foreign references for an array of Model +func expandForeigns(modelBuilder BuildModel, models []Model) error { // Expand all foreign references for _, field := range modelBuilder().GetConfiguration().Fields { // If the field is a foreign key if field.GetReference != nil && field.SetReference != nil { builder, foreignField := field.GetReference() - err := expandForeigns(field.Name, builder, foreignField, models) + err := expandForeignsByField(field.Name, builder, foreignField, models) if err != nil { return err } @@ -104,8 +105,8 @@ func expandAllForeigns(modelBuilder BuildModel, models []Model) error { return nil } -// expandForeigns expands a single foreign key for an array of Model -func expandForeigns(fieldName string, foreignBuilder BuildModel, foreignField string, models []Model) error { +// expandForeignsByField expands a single foreign key for an array of Model +func expandForeignsByField(fieldName string, foreignBuilder BuildModel, foreignField string, models []Model) error { // Get Foreign IDs ids := make([]interface{}, 0) for _, model := range models { diff --git a/pq_model.go b/pq_model.go index 7a5e681..c669518 100644 --- a/pq_model.go +++ b/pq_model.go @@ -280,24 +280,10 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) models = append(models, model.(Model)) } - // Expand - /* - err = expandAllForeigns(buildModel, models) - if err != nil { - return nil, err - } - */ - - // Expand all foreign references - for _, field := range buildModel().GetConfiguration().Fields { - // If the field is a foreign key - if field.GetReference != nil && field.SetReference != nil { - builder, foreignField := field.GetReference() - err = expandForeigns(field.Name, builder, foreignField, models) - if err != nil { - return nil, err - } - } + // Expand foreign references + err = expandForeigns(buildModel, models) + if err != nil { + return nil, err } // OK From 748ee31dcca9aad31eb32d159e897abe782bd192 Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Mon, 10 Jul 2017 18:47:37 -0400 Subject: [PATCH 05/10] Prep tests --- .travis.yml | 5 ++++ pq_model_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/.travis.yml b/.travis.yml index ce85277..3317e62 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,6 +23,11 @@ before_script: slug text NOT NULL UNIQUE, name text NOT NULL, age int NOT NULL + ); + CREATE TABLE toys( + id serial PRIMARY KEY, + name text NOT NULL, + owner bigint NOT NULL REFERENCES animals(id) ON DELETE CASCADE );' \ --username='postgres' \ --dbname='travis_ci_test' diff --git a/pq_model_test.go b/pq_model_test.go index 4c56aeb..2e5283b 100644 --- a/pq_model_test.go +++ b/pq_model_test.go @@ -170,6 +170,65 @@ func (a *Animal) Prep(dbConnection *sql.DB) *Animal { return a } +// =============================== +// ========== Toy Model ========== +// =============================== + +/** +Represents: + +CREATE TABLE toys( + id serial PRIMARY KEY, + name text NOT NULL, + owner bigint NOT NULL REFERENCES animals(id) ON DELETE CASCADE +); +*/ +type Toy struct { + surf.Model + Id int `json:"id"` + Name string `json:"name"` + OwnerId int `json:"-"` + Owner *Animal `json:"animal"` +} + +func NewToy(dbConnection *sql.DB) *Toy { + toy := new(Toy) + return toy.Prep(dbConnection) +} + +func (t *Toy) Prep(dbConnection *sql.DB) *Toy { + t.Model = &surf.PqModel{ + Database: dbConnection, + Config: surf.Configuration{ + TableName: "toys", + Fields: []surf.Field{ + {Pointer: &t.Id, Name: "id", UniqueIdentifier: true, + IsSet: func(pointer interface{}) bool { + pointerInt := *pointer.(*int) + return pointerInt != 0 + }}, + {Pointer: &t.Name, Name: "name", Insertable: true, Updatable: true}, + {Pointer: &t.OwnerId, Name: "owner", Insertable: true, Updatable: true, + GetReference: func() (surf.BuildModel, string) { + return func() surf.Model { + return NewAnimal(dbConnection) + }, "id" + }, + SetReference: func(model surf.Model) error { + t.Owner = model.(*Animal) + return nil + }, + IsSet: func(pointer interface{}) bool { + pointerInt := *pointer.(*int) + return pointerInt != 0 + }, + }, + }, + }, + } + return t +} + // ================================================== // ========== Animal Consume Failure Model ========== // ================================================== From 0f45958fd4a9e3f0aeabc5d99ccb22327e073f2c Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Thu, 13 Jul 2017 18:36:00 -0400 Subject: [PATCH 06/10] Only return Surf fields in Create + Update This will allow new columns to be added to a database without causing any scan errors to occur with queries. This was already handled in Load and BulkFetch, and isn't needed in Delete as we don't return anything. --- pq_model.go | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pq_model.go b/pq_model.go index c669518..85f7525 100644 --- a/pq_model.go +++ b/pq_model.go @@ -48,7 +48,14 @@ func (w *PqModel) Insert() error { queryBuffer.WriteString(", ") } } - queryBuffer.WriteString(") RETURNING *;") + queryBuffer.WriteString(") RETURNING ") + for i, field := range w.Config.Fields { + queryBuffer.WriteString(field.Name) + if(i + 1) < len(w.Config.Fields) { + queryBuffer.WriteString(", ") + } + } + queryBuffer.WriteString(";") // Get Value Fields var valueFields []interface{} @@ -143,7 +150,14 @@ func (w *PqModel) Update() error { queryBuffer.WriteString(uniqueIdentifierField.Name) queryBuffer.WriteString("=$") queryBuffer.WriteString(strconv.Itoa(len(updatableFields) + 1)) - queryBuffer.WriteString(" RETURNING *;") + queryBuffer.WriteString(" RETURNING ") + for i, field := range w.Config.Fields { + queryBuffer.WriteString(field.Name) + if(i + 1) < len(w.Config.Fields) { + queryBuffer.WriteString(", ") + } + } + queryBuffer.WriteString(";") // Get Value Fields var valueFields []interface{} From 300ca410f62df05d488af137cd32c9a7660d6c40 Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Fri, 14 Jul 2017 11:17:11 -0400 Subject: [PATCH 07/10] Enhance tests --- model_helpers.go | 8 ++ pq_model_test.go | 233 ++++++++++++++++++++++++++++++++++++++++++++--- predicate.go | 20 ++-- 3 files changed, 241 insertions(+), 20 deletions(-) diff --git a/model_helpers.go b/model_helpers.go index 0462a62..5352053 100644 --- a/model_helpers.go +++ b/model_helpers.go @@ -121,11 +121,19 @@ func expandForeignsByField(fieldName string, foreignBuilder BuildModel, foreignF ids = appendIfMissing(ids, tv.Int64) } break + default: + panic(fmt.Sprintf("Foreign Key for %v.%v may only be of type `null.Int` or `int64`", + model.GetConfiguration().TableName, field.Name)) } } } } + // If there's nothing to load, exit early + if len(ids) == 0 { + return nil + } + // Load Foreign models foreignModels, err := foreignBuilder().BulkFetch( BulkFetchConfig{ diff --git a/pq_model_test.go b/pq_model_test.go index 2e5283b..8275d54 100644 --- a/pq_model_test.go +++ b/pq_model_test.go @@ -6,10 +6,20 @@ import ( _ "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gopkg.in/guregu/null.v3" "os" "testing" ) +type StackWriter struct { + Stack []string +} + +func (sw *StackWriter) Write(p []byte) (n int, err error) { + sw.Stack = append(sw.Stack, string(p)) + return 0, nil +} + // ================================= // ========== Place Model ========== // ================================= @@ -19,7 +29,7 @@ import ( // in the database. type Place struct { surf.Model - Id int `json:"id"` + Id int64 `json:"id"` Name string `json:"name"` } @@ -39,7 +49,7 @@ func (p *Place) Prep(dbConnection *sql.DB) *Place { Name: "id", UniqueIdentifier: true, IsSet: func(pointer interface{}) bool { - pointerInt := *pointer.(*int) + pointerInt := *pointer.(*int64) return pointerInt != 0 }, }, @@ -63,7 +73,7 @@ func (p *Place) Prep(dbConnection *sql.DB) *Place { // as we are using this to test a failure before we even hit the db type Person struct { surf.Model - Id int `json:"id"` + Id int64 `json:"id"` Name string `json:"name"` } @@ -115,7 +125,7 @@ CREATE TABLE animals( */ type Animal struct { surf.Model - Id int `json:"id"` + Id int64 `json:"id"` Slug string `json:"slug"` Name string `json:"name"` Age int `json:"age"` @@ -137,7 +147,7 @@ func (a *Animal) Prep(dbConnection *sql.DB) *Animal { Name: "id", UniqueIdentifier: true, IsSet: func(pointer interface{}) bool { - pointerInt := *pointer.(*int) + pointerInt := *pointer.(*int64) return pointerInt != 0 }, }, @@ -185,10 +195,12 @@ CREATE TABLE toys( */ type Toy struct { surf.Model - Id int `json:"id"` - Name string `json:"name"` - OwnerId int `json:"-"` - Owner *Animal `json:"animal"` + Id int64 `json:"id"` + Name string `json:"name"` + OwnerId int64 `json:"-"` + Owner *Animal `json:"owner"` + SecondOwnerId null.Int `json:"-"` + SecondOwner *Animal `json:"second_owner"` } func NewToy(dbConnection *sql.DB) *Toy { @@ -204,9 +216,10 @@ func (t *Toy) Prep(dbConnection *sql.DB) *Toy { Fields: []surf.Field{ {Pointer: &t.Id, Name: "id", UniqueIdentifier: true, IsSet: func(pointer interface{}) bool { - pointerInt := *pointer.(*int) + pointerInt := *pointer.(*int64) return pointerInt != 0 - }}, + }, + }, {Pointer: &t.Name, Name: "name", Insertable: true, Updatable: true}, {Pointer: &t.OwnerId, Name: "owner", Insertable: true, Updatable: true, GetReference: func() (surf.BuildModel, string) { @@ -219,10 +232,25 @@ func (t *Toy) Prep(dbConnection *sql.DB) *Toy { return nil }, IsSet: func(pointer interface{}) bool { - pointerInt := *pointer.(*int) + pointerInt := *pointer.(*int64) return pointerInt != 0 }, }, + {Pointer: &t.SecondOwnerId, Name: "second_owner", Insertable: true, Updatable: true, + GetReference: func() (surf.BuildModel, string) { + return func() surf.Model { + return NewAnimal(dbConnection) + }, "id" + }, + SetReference: func(model surf.Model) error { + t.SecondOwner = model.(*Animal) + return nil + }, + IsSet: func(pointer interface{}) bool { + pointerInt := *pointer.(*null.Int) + return pointerInt.Valid + }, + }, }, }, } @@ -304,9 +332,12 @@ type PqWorkerTestSuite struct { } func (suite *PqWorkerTestSuite) SetupTest() { - databaseUrl := os.Getenv("SERF_TEST_DATABASE_URL") + // Enable logging + stackWriter := &StackWriter{} + surf.SetLogging(true, stackWriter) // Opening + storing the connection + databaseUrl := os.Getenv("SERF_TEST_DATABASE_URL") db, err := sql.Open("postgres", databaseUrl) if err != nil { suite.Fail("Failed to open database connection") @@ -616,6 +647,182 @@ func (suite *PqWorkerTestSuite) TestGetConfiguration() { assert.True(suite.T(), (hasId && hasSlug && hasName && hasAge)) } +func (suite *PqWorkerTestSuite) TestNestedModel() { + // Create an Animal + cat := NewAnimal(suite.db) + cat.Name = "Luna" + cat.Slug = "luna" + cat.Age = 2 + cat.Insert() + assert.NotEqual(suite.T(), 0, cat.Id) + + // Create a toy + tennisBall := NewToy(suite.db) + tennisBall.Name = "tennis ball" + tennisBall.OwnerId = cat.Id + tennisBall.Insert() + assert.NotEqual(suite.T(), int64(0), tennisBall.Id) + + // Create a second toy + sock := NewToy(suite.db) + sock.Name = "sock" + sock.OwnerId = cat.Id + sock.SecondOwnerId = null.IntFrom(cat.Id) + sock.Insert() + assert.NotEqual(suite.T(), int64(0), sock.Id) + + // Load all toys + toys, err := NewToy(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + }, func() surf.Model { + return NewToy(suite.db) + }) + + // Verify everything loaded properly + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), 2, len(toys)) + + loadedTennisBall := toys[0] + assert.Equal(suite.T(), tennisBall.Id, loadedTennisBall.(*Toy).Id) + assert.Equal(suite.T(), "tennis ball", loadedTennisBall.(*Toy).Name) + assert.Equal(suite.T(), cat.Id, loadedTennisBall.(*Toy).Owner.Id) + + loadedSock := toys[1] + assert.Equal(suite.T(), loadedSock.(*Toy).Id, sock.Id) + assert.Equal(suite.T(), loadedSock.(*Toy).Name, "sock") + assert.Equal(suite.T(), loadedSock.(*Toy).Owner.Id, cat.Id) + + // Clean up + cat.Delete() + tennisBall.Delete() + sock.Delete() +} + +func (suite *PqWorkerTestSuite) TestPredicates() { + // Create some Animals + luna := NewAnimal(suite.db) + luna.Name = "Luna" + luna.Slug = "luna" + luna.Age = 2 + luna.Insert() + assert.NotEqual(suite.T(), 0, luna.Id) + + rae := NewAnimal(suite.db) + rae.Name = "Rae" + rae.Slug = "rae" + rae.Age = 2 + rae.Insert() + assert.NotEqual(suite.T(), 0, rae.Id) + + rigby := NewAnimal(suite.db) + rigby.Name = "Rigby" + rigby.Slug = "rigby" + rigby.Age = 3 + rigby.Insert() + assert.NotEqual(suite.T(), 0, rigby.Id) + + // Test multiple predicates + animals, err := NewAnimal(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + Predicates: []surf.Predicate{ + {Field: "name", PredicateType: surf.WHERE_NOT_EQUAL, Values: []interface{}{"Luna"}}, + {Field: "name", PredicateType: surf.WHERE_NOT_EQUAL, Values: []interface{}{"Rae"}}, + }, + }, func() surf.Model { + return NewAnimal(suite.db) + }) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), 1, len(animals)) + loadedRigby := animals[0] + assert.Equal(suite.T(), loadedRigby.(*Animal).Name, "Rigby") + + // Test WHERE_IN with multiple elements + animals, err = NewAnimal(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + Predicates: []surf.Predicate{ + {Field: "name", PredicateType: surf.WHERE_IN, Values: []interface{}{"Luna", "Rae"}}, + }, + }, func() surf.Model { + return NewAnimal(suite.db) + }) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), 2, len(animals)) + loadedLuna := animals[0] + loadedRae := animals[1] + assert.Equal(suite.T(), "Luna", loadedLuna.(*Animal).Name) + assert.Equal(suite.T(), "Rae", loadedRae.(*Animal).Name) + + // Test WHERE_NOT_NULL + animals, err = NewAnimal(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + Predicates: []surf.Predicate{ + {Field: "name", PredicateType: surf.WHERE_IS_NOT_NULL}, + }, + }, func() surf.Model { + return NewAnimal(suite.db) + }) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), 3, len(animals)) + + // Cleanup + luna.Delete() + rae.Delete() + rigby.Delete() +} + +func (suite *PqWorkerTestSuite) TestWhereEqualPanic() { + suite.PredicatePanic(surf.WHERE_EQUAL) + suite.PredicatePanic(surf.WHERE_IN) + suite.PredicatePanic(surf.WHERE_NOT_IN) + suite.PredicatePanic(surf.WHERE_LIKE) + suite.PredicatePanic(surf.WHERE_EQUAL) + suite.PredicatePanic(surf.WHERE_NOT_EQUAL) + suite.PredicatePanic(surf.WHERE_GREATER_THAN) + suite.PredicatePanic(surf.WHERE_GREATER_THAN_OR_EQUAL_TO) + suite.PredicatePanic(surf.WHERE_LESS_THAN) + suite.PredicatePanic(surf.WHERE_LESS_THAN_OR_EQUAL_TO) + suite.PredicatePanic(surf.WHERE_IS_NOT_NULL, 1) + suite.PredicatePanic(surf.WHERE_IS_NULL, 1) + suite.PredicatePanic(9999, 1) +} + +func (suite *PqWorkerTestSuite) PredicatePanic(predType surf.PredicateType, values ...interface{}) { + defer func() { + recover() + }() + + // Test WHERE_NOT_IN without any elements (panics) + NewAnimal(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + Predicates: []surf.Predicate{ + {Field: "name", PredicateType: predType, Values: values}, + }, + }, func() surf.Model { + return NewAnimal(suite.db) + }) + assert.Fail(suite.T(), "Unreachable statement, last call should have panicked") +} + // In order for 'go test' to run this suite, we need to create // a normal test function and pass our suite to suite.Run func TestPqWorkerTestSuite(t *testing.T) { diff --git a/predicate.go b/predicate.go index 4340704..945b0c6 100644 --- a/predicate.go +++ b/predicate.go @@ -7,7 +7,7 @@ import ( type PredicateType int const ( - WHERE_IS_NOT_NULL PredicateType = iota + WHERE_IS_NOT_NULL PredicateType = iota // Default WHERE_IS_NULL WHERE_IN WHERE_NOT_IN @@ -23,8 +23,6 @@ const ( // getPredicateTypeString returns the predicate type string from it's value func getPredicateTypeString(predicateType PredicateType) string { switch predicateType { - case WHERE_IS_NOT_NULL: - return "WHERE_IS_NOT_NULL" case WHERE_IS_NULL: return "WHERE_IS_NULL" case WHERE_IN: @@ -46,7 +44,7 @@ func getPredicateTypeString(predicateType PredicateType) string { case WHERE_LESS_THAN_OR_EQUAL_TO: return "WHERE_LESS_THAN_OR_EQUAL_TO" } - return "" + return "WHERE_IS_NOT_NULL" } // Predicate is the definition of a single where SQL predicate @@ -66,9 +64,6 @@ func (p *Predicate) toString(valueIndex int) (string, []interface{}) { // Type switch p.PredicateType { - case WHERE_IS_NOT_NULL: - predicate += " IS NOT NULL" - break case WHERE_IS_NULL: predicate += " IS NULL" break @@ -99,6 +94,9 @@ func (p *Predicate) toString(valueIndex int) (string, []interface{}) { case WHERE_LESS_THAN_OR_EQUAL_TO: predicate += " <= " break + default: + predicate += " IS NOT NULL" + break } // Values @@ -133,6 +131,14 @@ func (p *Predicate) toString(valueIndex int) (string, []interface{}) { values = append(values, p.Values[0]) predicate += "$" + strconv.Itoa(valueIndex) break + case WHERE_IS_NOT_NULL, + WHERE_IS_NULL: + if len(p.Values) != 0 { + panic("`" + getPredicateTypeString(p.PredicateType) + "` predicates cannot have any values.") + } + break + default: + panic("Unknown predicate type.") } return predicate, values From adb5c19420dc943d571f0d6390056ef8aa278410 Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Fri, 14 Jul 2017 12:07:03 -0400 Subject: [PATCH 08/10] Add logger tests --- logger.go | 4 +- logger_test.go | 147 +++++++++++++++++++++++++++++++++++++++++++++++ pq_model.go | 10 ++-- pq_model_test.go | 9 --- 4 files changed, 154 insertions(+), 16 deletions(-) create mode 100644 logger_test.go diff --git a/logger.go b/logger.go index b026644..a6ff0ee 100644 --- a/logger.go +++ b/logger.go @@ -24,12 +24,12 @@ func SetLogging(enabled bool, writer io.Writer) { } // printQuery prints a query if the user has enabled logging -func printQuery(query string, args ...interface{}) { +func PrintSqlQuery(query string, args ...interface{}) { if loggingEnabled { for i, arg := range args { query = strings.Replace(query, "$"+strconv.Itoa(i+1), pointerToLogString(arg), 1) } - fmt.Fprintln(loggingWriter, "[Surf Query]: "+query) + fmt.Fprint(loggingWriter, query) } } diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..f648a35 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,147 @@ +package surf_test + +import ( + "github.com/go-carrot/surf" + "github.com/stretchr/testify/assert" + "gopkg.in/guregu/null.v3" + "testing" + "time" +) + +type StackWriter struct { + Stack []string +} + +func (sw *StackWriter) Write(p []byte) (n int, err error) { + sw.Stack = append(sw.Stack, string(p)) + return 0, nil +} + +func (sw *StackWriter) Peek() string { + if len(sw.Stack) > 0 { + return sw.Stack[len(sw.Stack)-1] + } + return "" +} + +func TestLogger(t *testing.T) { + // Enable logging + stackWriter := &StackWriter{} + surf.SetLogging(true, stackWriter) + + // float32 + var idFloat32 float32 = 8.8 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idFloat32) + assert.Equal(t, "SELECT * FROM table WHERE id = 8.8", stackWriter.Peek()) + + // float64 + var idFloat64 float64 = 8.9 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idFloat64) + assert.Equal(t, "SELECT * FROM table WHERE id = 8.9", stackWriter.Peek()) + + // bool + var idBool bool = false + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idBool) + assert.Equal(t, "SELECT * FROM table WHERE id = false", stackWriter.Peek()) + + // int + var idInt int = 190 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt) + assert.Equal(t, "SELECT * FROM table WHERE id = 190", stackWriter.Peek()) + + // int8 + var idInt8 int8 = 8 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt8) + assert.Equal(t, "SELECT * FROM table WHERE id = 8", stackWriter.Peek()) + + // int16 + var idInt16 int16 = 111 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt16) + assert.Equal(t, "SELECT * FROM table WHERE id = 111", stackWriter.Peek()) + + // int32 + var idInt32 int32 = 1110 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt32) + assert.Equal(t, "SELECT * FROM table WHERE id = 1110", stackWriter.Peek()) + + // int64 + var idInt64 int64 = 11100 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt64) + assert.Equal(t, "SELECT * FROM table WHERE id = 11100", stackWriter.Peek()) + + // uint + var idUInt uint = 200 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt) + assert.Equal(t, "SELECT * FROM table WHERE id = 200", stackWriter.Peek()) + + // uint8 + var idUInt8 uint8 = 127 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt8) + assert.Equal(t, "SELECT * FROM table WHERE id = 127", stackWriter.Peek()) + + // uint16 + var idUInt16 uint16 = 1278 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt16) + assert.Equal(t, "SELECT * FROM table WHERE id = 1278", stackWriter.Peek()) + + // uint32 + var idUInt32 uint32 = 12788 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt32) + assert.Equal(t, "SELECT * FROM table WHERE id = 12788", stackWriter.Peek()) + + // uint64 + var idUInt64 uint64 = 127888 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt64) + assert.Equal(t, "SELECT * FROM table WHERE id = 127888", stackWriter.Peek()) + + // time.Time + const layout = "Jan 2, 2006 at 3:04pm (MST)" + idTime, _ := time.Parse(layout, "Feb 3, 2013 at 7:54pm (PST)") + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idTime) + assert.Equal(t, "SELECT * FROM table WHERE id = '2013-02-03T19:54:00Z'", stackWriter.Peek()) + + // null.Int + idNullInt := null.IntFrom(100) + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullInt) + assert.Equal(t, "SELECT * FROM table WHERE id = 100", stackWriter.Peek()) + + idNullIntNull := null.Int{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullIntNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) + + // null.String + idNullString := null.StringFrom("Hello") + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullString) + assert.Equal(t, "SELECT * FROM table WHERE id = 'Hello'", stackWriter.Peek()) + + idNullStringNull := null.String{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullStringNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) + + // null.Bool + idNullBool := null.BoolFrom(false) + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullBool) + assert.Equal(t, "SELECT * FROM table WHERE id = false", stackWriter.Peek()) + + idNullBoolNull := null.Bool{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullBoolNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) + + // null.Float + idNullFloat := null.FloatFrom(1.2) + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullFloat) + assert.Equal(t, "SELECT * FROM table WHERE id = 1.2", stackWriter.Peek()) + + idNullFloatNull := null.Float{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullFloatNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) + + // null.Time + idNullTime := null.TimeFrom(idTime) + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullTime) + assert.Equal(t, "SELECT * FROM table WHERE id = '2013-02-03T19:54:00Z'", stackWriter.Peek()) + + idNullTimeNull := null.Time{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullTimeNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) +} diff --git a/pq_model.go b/pq_model.go index 85f7525..7830d75 100644 --- a/pq_model.go +++ b/pq_model.go @@ -65,7 +65,7 @@ func (w *PqModel) Insert() error { // Log Query query := queryBuffer.String() - printQuery(query, valueFields...) + PrintSqlQuery(query, valueFields...) // Execute Query row := w.Database.QueryRow(query, valueFields...) @@ -104,7 +104,7 @@ func (w *PqModel) Load() error { // Log Query query := queryBuffer.String() - printQuery(query, uniqueIdentifierField.Pointer) + PrintSqlQuery(query, uniqueIdentifierField.Pointer) // Execute Query row := w.Database.QueryRow(query, uniqueIdentifierField.Pointer) @@ -168,7 +168,7 @@ func (w *PqModel) Update() error { // Log Query query := queryBuffer.String() - printQuery(query, valueFields...) + PrintSqlQuery(query, valueFields...) // Execute Query row := w.Database.QueryRow(query, valueFields...) @@ -199,7 +199,7 @@ func (w *PqModel) Delete() error { // Log Query query := queryBuffer.String() - printQuery(query, uniqueIdentifierField.Pointer) + PrintSqlQuery(query, uniqueIdentifierField.Pointer) // Execute Query res, err := w.Database.Exec(queryBuffer.String(), uniqueIdentifierField.Pointer) @@ -267,7 +267,7 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) // Log Query query := queryBuffer.String() - printQuery(query, values...) + PrintSqlQuery(query, values...) // Execute Query rows, err := w.Database.Query(query, values...) diff --git a/pq_model_test.go b/pq_model_test.go index 8275d54..14748c6 100644 --- a/pq_model_test.go +++ b/pq_model_test.go @@ -11,15 +11,6 @@ import ( "testing" ) -type StackWriter struct { - Stack []string -} - -func (sw *StackWriter) Write(p []byte) (n int, err error) { - sw.Stack = append(sw.Stack, string(p)) - return 0, nil -} - // ================================= // ========== Place Model ========== // ================================= From ee7594d559d14a141d1fe92779cc1fd15916052c Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Fri, 14 Jul 2017 12:10:44 -0400 Subject: [PATCH 09/10] Update DB schema --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 3317e62..64d424f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -27,7 +27,8 @@ before_script: CREATE TABLE toys( id serial PRIMARY KEY, name text NOT NULL, - owner bigint NOT NULL REFERENCES animals(id) ON DELETE CASCADE + owner bigint NOT NULL REFERENCES animals(id) ON DELETE CASCADE, + second_owner bigint REFERENCES animals(id) );' \ --username='postgres' \ --dbname='travis_ci_test' From f401ffd7cf6da4f0737bd580fbf5c1ac3b98ea7b Mon Sep 17 00:00:00 2001 From: BrandonRomano Date: Fri, 21 Jul 2017 15:13:36 -0400 Subject: [PATCH 10/10] Instantiate models in bulk fetch This will make sure the models get marshalled as empty arrays, and not null --- pq_model.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pq_model.go b/pq_model.go index 7830d75..5c34cc6 100644 --- a/pq_model.go +++ b/pq_model.go @@ -51,7 +51,7 @@ func (w *PqModel) Insert() error { queryBuffer.WriteString(") RETURNING ") for i, field := range w.Config.Fields { queryBuffer.WriteString(field.Name) - if(i + 1) < len(w.Config.Fields) { + if (i + 1) < len(w.Config.Fields) { queryBuffer.WriteString(", ") } } @@ -153,7 +153,7 @@ func (w *PqModel) Update() error { queryBuffer.WriteString(" RETURNING ") for i, field := range w.Config.Fields { queryBuffer.WriteString(field.Name) - if(i + 1) < len(w.Config.Fields) { + if (i + 1) < len(w.Config.Fields) { queryBuffer.WriteString(", ") } } @@ -276,7 +276,7 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) } // Stuff into []Model - var models []Model + models := make([]Model, 0) for rows.Next() { model := buildModel()