diff --git a/.travis.yml b/.travis.yml index bb952ec..64d424f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,7 @@ go: services: - postgresql before_script: + - go get golang.org/x/tools/cmd/cover - | psql \ --command='create database travis_ci_test;' \ @@ -22,6 +23,12 @@ before_script: slug text NOT NULL UNIQUE, name text NOT NULL, age int NOT NULL + ); + CREATE TABLE toys( + id serial PRIMARY KEY, + name text NOT NULL, + owner bigint NOT NULL REFERENCES animals(id) ON DELETE CASCADE, + second_owner bigint REFERENCES animals(id) );' \ --username='postgres' \ --dbname='travis_ci_test' diff --git a/bulk_fetch.go b/bulk_fetch_config.go similarity index 54% rename from bulk_fetch.go rename to bulk_fetch_config.go index 6367576..c3220ad 100644 --- a/bulk_fetch.go +++ b/bulk_fetch_config.go @@ -4,37 +4,12 @@ import ( "strings" ) -// OrderByType is an enumeration of the SQL standard order by -type OrderByType int - -const ( - ORDER_BY_ASC OrderByType = iota - ORDER_BY_DESC -) - -// OrderBy is the definition of a single order by clause -type OrderBy struct { - Field string - Type OrderByType -} - -// ToString converts an OrderBy to SQL -func (ob *OrderBy) ToString() string { - obType := "" - switch ob.Type { - case ORDER_BY_ASC: - obType = " ASC" - case ORDER_BY_DESC: - obType = " DESC" - } - return ob.Field + obType -} - // BulkFetchConfig is the configuration of a Model.BulkFetch() type BulkFetchConfig struct { - Limit int - Offset int - OrderBys []OrderBy + Limit int + Offset int + OrderBys []OrderBy + Predicates []Predicate } // ConsumeSortQuery consumes a `sort` query parameter diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..a6ff0ee --- /dev/null +++ b/logger.go @@ -0,0 +1,99 @@ +package surf + +import ( + "fmt" + "gopkg.in/guregu/null.v3" + "io" + "strconv" + "strings" + "time" +) + +var ( + loggingEnabled = false + loggingWriter io.Writer +) + +// SetLogging adjusts the configuration for logging. You can enable +// and disable the logging here. By default, logging is disabled. +// +// Most calls to this function will be called like SetLogging(true, os.Stdout) +func SetLogging(enabled bool, writer io.Writer) { + loggingEnabled = enabled + loggingWriter = writer +} + +// printQuery prints a query if the user has enabled logging +func PrintSqlQuery(query string, args ...interface{}) { + if loggingEnabled { + for i, arg := range args { + query = strings.Replace(query, "$"+strconv.Itoa(i+1), pointerToLogString(arg), 1) + } + fmt.Fprint(loggingWriter, query) + } +} + +// pointerToLogString converts a value pointer to the string +// that should be logged for it +func pointerToLogString(pointer interface{}) string { + switch v := pointer.(type) { + case *string: + return "'" + *v + "'" + case *float32: + return strconv.FormatFloat(float64(*v), 'f', -1, 32) + case *float64: + return strconv.FormatFloat(*v, 'f', -1, 64) + case *bool: + return strconv.FormatBool(*v) + case *int: + return strconv.Itoa(*v) + case *int8: + return strconv.FormatInt(int64(*v), 10) + case *int16: + return strconv.FormatInt(int64(*v), 10) + case *int32: + return strconv.FormatInt(int64(*v), 10) + case *int64: + return strconv.FormatInt(*v, 10) + case *uint: + return strconv.FormatUint(uint64(*v), 10) + case *uint8: + return strconv.FormatUint(uint64(*v), 10) + case *uint16: + return strconv.FormatUint(uint64(*v), 10) + case *uint32: + return strconv.FormatUint(uint64(*v), 10) + case *uint64: + return strconv.FormatUint(*v, 10) + case *time.Time: + return "'" + (*v).Format(time.RFC3339) + "'" + case *null.Int: + if v.Valid { + return strconv.FormatInt(v.Int64, 10) + } + break + case *null.String: + if v.Valid { + return "'" + v.String + "'" + } + break + case *null.Bool: + if v.Valid { + return strconv.FormatBool(v.Bool) + } + break + case *null.Float: + if v.Valid { + return strconv.FormatFloat(v.Float64, 'f', -1, 64) + } + break + case *null.Time: + if v.Valid { + return "'" + v.Time.Format(time.RFC3339) + "'" + } + break + default: + return fmt.Sprintf("%v", v) + } + return "null" +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..f648a35 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,147 @@ +package surf_test + +import ( + "github.com/go-carrot/surf" + "github.com/stretchr/testify/assert" + "gopkg.in/guregu/null.v3" + "testing" + "time" +) + +type StackWriter struct { + Stack []string +} + +func (sw *StackWriter) Write(p []byte) (n int, err error) { + sw.Stack = append(sw.Stack, string(p)) + return 0, nil +} + +func (sw *StackWriter) Peek() string { + if len(sw.Stack) > 0 { + return sw.Stack[len(sw.Stack)-1] + } + return "" +} + +func TestLogger(t *testing.T) { + // Enable logging + stackWriter := &StackWriter{} + surf.SetLogging(true, stackWriter) + + // float32 + var idFloat32 float32 = 8.8 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idFloat32) + assert.Equal(t, "SELECT * FROM table WHERE id = 8.8", stackWriter.Peek()) + + // float64 + var idFloat64 float64 = 8.9 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idFloat64) + assert.Equal(t, "SELECT * FROM table WHERE id = 8.9", stackWriter.Peek()) + + // bool + var idBool bool = false + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idBool) + assert.Equal(t, "SELECT * FROM table WHERE id = false", stackWriter.Peek()) + + // int + var idInt int = 190 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt) + assert.Equal(t, "SELECT * FROM table WHERE id = 190", stackWriter.Peek()) + + // int8 + var idInt8 int8 = 8 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt8) + assert.Equal(t, "SELECT * FROM table WHERE id = 8", stackWriter.Peek()) + + // int16 + var idInt16 int16 = 111 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt16) + assert.Equal(t, "SELECT * FROM table WHERE id = 111", stackWriter.Peek()) + + // int32 + var idInt32 int32 = 1110 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt32) + assert.Equal(t, "SELECT * FROM table WHERE id = 1110", stackWriter.Peek()) + + // int64 + var idInt64 int64 = 11100 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idInt64) + assert.Equal(t, "SELECT * FROM table WHERE id = 11100", stackWriter.Peek()) + + // uint + var idUInt uint = 200 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt) + assert.Equal(t, "SELECT * FROM table WHERE id = 200", stackWriter.Peek()) + + // uint8 + var idUInt8 uint8 = 127 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt8) + assert.Equal(t, "SELECT * FROM table WHERE id = 127", stackWriter.Peek()) + + // uint16 + var idUInt16 uint16 = 1278 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt16) + assert.Equal(t, "SELECT * FROM table WHERE id = 1278", stackWriter.Peek()) + + // uint32 + var idUInt32 uint32 = 12788 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt32) + assert.Equal(t, "SELECT * FROM table WHERE id = 12788", stackWriter.Peek()) + + // uint64 + var idUInt64 uint64 = 127888 + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idUInt64) + assert.Equal(t, "SELECT * FROM table WHERE id = 127888", stackWriter.Peek()) + + // time.Time + const layout = "Jan 2, 2006 at 3:04pm (MST)" + idTime, _ := time.Parse(layout, "Feb 3, 2013 at 7:54pm (PST)") + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idTime) + assert.Equal(t, "SELECT * FROM table WHERE id = '2013-02-03T19:54:00Z'", stackWriter.Peek()) + + // null.Int + idNullInt := null.IntFrom(100) + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullInt) + assert.Equal(t, "SELECT * FROM table WHERE id = 100", stackWriter.Peek()) + + idNullIntNull := null.Int{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullIntNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) + + // null.String + idNullString := null.StringFrom("Hello") + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullString) + assert.Equal(t, "SELECT * FROM table WHERE id = 'Hello'", stackWriter.Peek()) + + idNullStringNull := null.String{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullStringNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) + + // null.Bool + idNullBool := null.BoolFrom(false) + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullBool) + assert.Equal(t, "SELECT * FROM table WHERE id = false", stackWriter.Peek()) + + idNullBoolNull := null.Bool{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullBoolNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) + + // null.Float + idNullFloat := null.FloatFrom(1.2) + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullFloat) + assert.Equal(t, "SELECT * FROM table WHERE id = 1.2", stackWriter.Peek()) + + idNullFloatNull := null.Float{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullFloatNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) + + // null.Time + idNullTime := null.TimeFrom(idTime) + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullTime) + assert.Equal(t, "SELECT * FROM table WHERE id = '2013-02-03T19:54:00Z'", stackWriter.Peek()) + + idNullTimeNull := null.Time{} + surf.PrintSqlQuery("SELECT * FROM table WHERE id = $1", &idNullTimeNull) + assert.Equal(t, "SELECT * FROM table WHERE id = null", stackWriter.Peek()) +} diff --git a/model.go b/model.go index 01e7bc1..d0fc028 100644 --- a/model.go +++ b/model.go @@ -25,6 +25,8 @@ type Field struct { Updatable bool UniqueIdentifier bool SkipValidation bool + GetReference func() (BuildModel, string) + SetReference func(Model) error IsSet func(interface{}) bool } diff --git a/model_helpers.go b/model_helpers.go new file mode 100644 index 0000000..5352053 --- /dev/null +++ b/model_helpers.go @@ -0,0 +1,208 @@ +package surf + +import ( + "database/sql" + "errors" + "fmt" + "gopkg.in/guregu/null.v3" +) + +// getUniqueIdentifier Returns the unique identifier that this model will +// query against. +// +// This will return the first field in Configuration.Fields that: +// - Has `UniqueIdentifier` set to true +// - Returns true from `IsSet` +// +// This function will panic in the event that it encounters a field that is a +// `UniqueIdentifier`, and doesn't have `IsSet` implemented. +func getUniqueIdentifier(w Model) (Field, error) { + // Get all unique identifier fields + var uniqueIdentifierFields []Field + for _, field := range w.GetConfiguration().Fields { + if field.UniqueIdentifier { + uniqueIdentifierFields = append(uniqueIdentifierFields, field) + } + } + + // Determine which unique identifier we will be querying with + var uniqueIdentifierField Field + for _, field := range uniqueIdentifierFields { + if field.IsSet == nil { + panic(fmt.Sprintf("Field `%v` must implement IsSet, as it is a `UniqueIdentifier`", field.Name)) + } else if field.IsSet(field.Pointer) { + uniqueIdentifierField = field + break + } + } + + // Return + if uniqueIdentifierField.Pointer == nil { + return uniqueIdentifierField, errors.New("There is no UniqueIdentifier Field that is set") + } + return uniqueIdentifierField, nil +} + +// expandForeign expands all foreign references for a single Model +func expandForeign(model Model) error { + // Load all foreign references + for _, field := range model.GetConfiguration().Fields { + + // If it's a set foreign reference + if field.GetReference != nil && field.SetReference != nil && field.IsSet(field.Pointer) { + + // Get the reference type + modelBuilder, identifier := field.GetReference() + model := modelBuilder() + + // Set the identifier on the foreign reference + // The foreign reference value may only be a `null.Int` or an `int64` + // The identifier on the foreign model may only be of type `int64` + for _, modelField := range model.GetConfiguration().Fields { + if modelField.Name == identifier { + switch tv := field.Pointer.(type) { + case *int64: + *(modelField.Pointer.(*int64)) = *tv + break + case *null.Int: + *(modelField.Pointer.(*int64)) = tv.Int64 + break + } + break + } + } + + // Load + err := model.Load() + if err != nil { + return err + } + + // Set reference + err = field.SetReference(model) + if err != nil { + return err + } + } + } + + return nil +} + +// expandForeigns expands all foreign references for an array of Model +func expandForeigns(modelBuilder BuildModel, models []Model) error { + // Expand all foreign references + for _, field := range modelBuilder().GetConfiguration().Fields { + // If the field is a foreign key + if field.GetReference != nil && field.SetReference != nil { + builder, foreignField := field.GetReference() + err := expandForeignsByField(field.Name, builder, foreignField, models) + if err != nil { + return err + } + } + } + return nil +} + +// expandForeignsByField expands a single foreign key for an array of Model +func expandForeignsByField(fieldName string, foreignBuilder BuildModel, foreignField string, models []Model) error { + // Get Foreign IDs + ids := make([]interface{}, 0) + for _, model := range models { + for _, field := range model.GetConfiguration().Fields { + if field.Name == fieldName { + switch tv := field.Pointer.(type) { + case *int64: + ids = appendIfMissing(ids, *tv) + break + case *null.Int: + if tv.Valid { + ids = appendIfMissing(ids, tv.Int64) + } + break + default: + panic(fmt.Sprintf("Foreign Key for %v.%v may only be of type `null.Int` or `int64`", + model.GetConfiguration().TableName, field.Name)) + } + } + } + } + + // If there's nothing to load, exit early + if len(ids) == 0 { + return nil + } + + // Load Foreign models + foreignModels, err := foreignBuilder().BulkFetch( + BulkFetchConfig{ + Limit: len(ids), + Predicates: []Predicate{{ + Field: foreignField, + PredicateType: WHERE_IN, + Values: ids, + }}, + }, + foreignBuilder, + ) + if err != nil { + return err + } + + // Stuff foreign models into models + for _, model := range models { + for _, field := range model.GetConfiguration().Fields { + if field.Name == fieldName { + var toMatch int64 + switch tv := field.Pointer.(type) { + case *int64: + toMatch = *tv + break + case *null.Int: + toMatch = tv.Int64 + break + } + + MatchForeignModel: + for _, foreignModel := range foreignModels { + FindField: + for _, foreignModelField := range foreignModel.GetConfiguration().Fields { + if foreignModelField.Name == foreignField { + + if *(foreignModelField.Pointer.(*int64)) == toMatch { + field.SetReference(foreignModel) + break MatchForeignModel + } + break FindField + } + } + } + break + } + } + } + return nil +} + +// appendIfMissing functions like append(), but will only add the +// int64 to the slice if it doesn't exist in the slice already +func appendIfMissing(slice []interface{}, i int64) []interface{} { + for _, ele := range slice { + if ele == i { + return slice + } + } + return append(slice, i) +} + +// consumeRow Scans a *sql.Row into our struct +// that is using this model +func consumeRow(w Model, row *sql.Row) error { + fields := w.GetConfiguration().Fields + var s []interface{} + for _, value := range fields { + s = append(s, value.Pointer) + } + return row.Scan(s...) +} diff --git a/order_by.go b/order_by.go new file mode 100644 index 0000000..059166e --- /dev/null +++ b/order_by.go @@ -0,0 +1,27 @@ +package surf + +// OrderByType is an enumeration of the SQL standard order by +type OrderByType int + +const ( + ORDER_BY_ASC OrderByType = iota + ORDER_BY_DESC +) + +// OrderBy is the definition of a single order by clause +type OrderBy struct { + Field string + Type OrderByType +} + +// ToString converts an OrderBy to SQL +func (ob *OrderBy) toString() string { + obType := "" + switch ob.Type { + case ORDER_BY_ASC: + obType = " ASC" + case ORDER_BY_DESC: + obType = " DESC" + } + return ob.Field + obType +} diff --git a/pq_model.go b/pq_model.go index 8739ae7..5c34cc6 100644 --- a/pq_model.go +++ b/pq_model.go @@ -23,8 +23,7 @@ func (w *PqModel) GetConfiguration() *Configuration { func (w *PqModel) Insert() error { // Get Insertable Fields var insertableFields []Field - allFields := w.Config.Fields - for _, field := range allFields { + for _, field := range w.Config.Fields { if field.Insertable { insertableFields = append(insertableFields, field) } @@ -49,7 +48,14 @@ func (w *PqModel) Insert() error { queryBuffer.WriteString(", ") } } - queryBuffer.WriteString(") RETURNING *;") + queryBuffer.WriteString(") RETURNING ") + for i, field := range w.Config.Fields { + queryBuffer.WriteString(field.Name) + if (i + 1) < len(w.Config.Fields) { + queryBuffer.WriteString(", ") + } + } + queryBuffer.WriteString(";") // Get Value Fields var valueFields []interface{} @@ -57,16 +63,26 @@ func (w *PqModel) Insert() error { valueFields = append(valueFields, value.Pointer) } + // Log Query + query := queryBuffer.String() + PrintSqlQuery(query, valueFields...) + // Execute Query - row := w.Database.QueryRow(queryBuffer.String(), valueFields...) - return w.ConsumeRow(row) + row := w.Database.QueryRow(query, valueFields...) + err := consumeRow(w, row) + if err != nil { + return err + } + + // Expand foreign references + return expandForeign(w) } // Load loads the model from the database from its unique identifier // and then loads those values into the struct func (w *PqModel) Load() error { // Get Unique Identifier - uniqueIdentifierField, err := w.getUniqueIdentifier() + uniqueIdentifierField, err := getUniqueIdentifier(w) if err != nil { return err } @@ -86,15 +102,25 @@ func (w *PqModel) Load() error { queryBuffer.WriteString(uniqueIdentifierField.Name) queryBuffer.WriteString("=$1;") + // Log Query + query := queryBuffer.String() + PrintSqlQuery(query, uniqueIdentifierField.Pointer) + // Execute Query - row := w.Database.QueryRow(queryBuffer.String(), uniqueIdentifierField.Pointer) - return w.ConsumeRow(row) + row := w.Database.QueryRow(query, uniqueIdentifierField.Pointer) + err = consumeRow(w, row) + if err != nil { + return err + } + + // Expand foreign references + return expandForeign(w) } // Update updates the model with the current values in the struct func (w *PqModel) Update() error { // Get Unique Identifier - uniqueIdentifierField, err := w.getUniqueIdentifier() + uniqueIdentifierField, err := getUniqueIdentifier(w) if err != nil { return err } @@ -124,7 +150,14 @@ func (w *PqModel) Update() error { queryBuffer.WriteString(uniqueIdentifierField.Name) queryBuffer.WriteString("=$") queryBuffer.WriteString(strconv.Itoa(len(updatableFields) + 1)) - queryBuffer.WriteString(" RETURNING *;") + queryBuffer.WriteString(" RETURNING ") + for i, field := range w.Config.Fields { + queryBuffer.WriteString(field.Name) + if (i + 1) < len(w.Config.Fields) { + queryBuffer.WriteString(", ") + } + } + queryBuffer.WriteString(";") // Get Value Fields var valueFields []interface{} @@ -133,15 +166,25 @@ func (w *PqModel) Update() error { } valueFields = append(valueFields, uniqueIdentifierField.Pointer) + // Log Query + query := queryBuffer.String() + PrintSqlQuery(query, valueFields...) + // Execute Query - row := w.Database.QueryRow(queryBuffer.String(), valueFields...) - return w.ConsumeRow(row) + row := w.Database.QueryRow(query, valueFields...) + err = consumeRow(w, row) + if err != nil { + return err + } + + // Expand foreign references + return expandForeign(w) } // Delete deletes the model -func (w PqModel) Delete() error { +func (w *PqModel) Delete() error { // Get Unique Identifier - uniqueIdentifierField, err := w.getUniqueIdentifier() + uniqueIdentifierField, err := getUniqueIdentifier(w) if err != nil { return err } @@ -154,6 +197,10 @@ func (w PqModel) Delete() error { queryBuffer.WriteString(uniqueIdentifierField.Name) queryBuffer.WriteString("=$1;") + // Log Query + query := queryBuffer.String() + PrintSqlQuery(query, uniqueIdentifierField.Pointer) + // Execute Query res, err := w.Database.Exec(queryBuffer.String(), uniqueIdentifierField.Pointer) if err != nil { @@ -166,45 +213,12 @@ func (w PqModel) Delete() error { return nil } -// getUniqueIdentifier Returns the unique identifier that this model will -// query against. -// -// This will return the first field in Configuration.Fields that: -// - Has `UniqueIdentifier` set to true -// - Returns true from `IsSet` -// -// This function will panic in the event that it encounters a field that is a -// `UniqueIdentifier`, and doesn't have `IsSet` implemented. -func (w *PqModel) getUniqueIdentifier() (Field, error) { - // Get all unique identifier fields - var uniqueIdentifierFields []Field - for _, field := range w.Config.Fields { - if field.UniqueIdentifier { - uniqueIdentifierFields = append(uniqueIdentifierFields, field) - } - } - - // Determine which unique identifier we will be querying with - var uniqueIdentifierField Field - for _, field := range uniqueIdentifierFields { - if field.IsSet == nil { - panic(fmt.Sprintf("Field `%v` must implement IsSet, as it is a `UniqueIdentifier`", field.Name)) - } else if field.IsSet(field.Pointer) { - uniqueIdentifierField = field - break - } - } - - // Return - if uniqueIdentifierField.Pointer == nil { - return uniqueIdentifierField, errors.New("There is no UniqueIdentifier Field that is set") - } - return uniqueIdentifierField, nil -} - // BulkFetch gets an array of models func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) ([]Model, error) { - // Generate Query + // Set up values + values := make([]interface{}, 0) + + // Generate query var queryBuffer bytes.Buffer queryBuffer.WriteString("SELECT ") for i, field := range w.Config.Fields { @@ -215,7 +229,17 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) } queryBuffer.WriteString(" FROM ") queryBuffer.WriteString(buildModel().GetConfiguration().TableName) - queryBuffer.WriteString(" ORDER BY ") + if len(fetchConfig.Predicates) > 0 { + // WHERE + queryBuffer.WriteString(" ") + predicatesStr, predicateValues := predicatesToString(1, fetchConfig.Predicates) + + values = append(values, predicateValues...) + queryBuffer.WriteString(predicatesStr) + } + if len(fetchConfig.OrderBys) > 0 { + queryBuffer.WriteString(" ORDER BY ") + } for i, orderBy := range fetchConfig.OrderBys { // Validate that the orderBy.Field is a field valid := false @@ -230,7 +254,7 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) w.Config.TableName, orderBy.Field) } // Write to query - queryBuffer.WriteString(orderBy.ToString()) + queryBuffer.WriteString(orderBy.toString()) if (i + 1) < len(fetchConfig.OrderBys) { queryBuffer.WriteString(", ") } @@ -241,14 +265,18 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) queryBuffer.WriteString(strconv.Itoa(fetchConfig.Offset)) queryBuffer.WriteString(";") + // Log Query + query := queryBuffer.String() + PrintSqlQuery(query, values...) + // Execute Query - rows, err := w.Database.Query(queryBuffer.String()) + rows, err := w.Database.Query(query, values...) if err != nil { return nil, err } // Stuff into []Model - var models []Model + models := make([]Model, 0) for rows.Next() { model := buildModel() @@ -266,17 +294,12 @@ func (w *PqModel) BulkFetch(fetchConfig BulkFetchConfig, buildModel BuildModel) models = append(models, model.(Model)) } + // Expand foreign references + err = expandForeigns(buildModel, models) + if err != nil { + return nil, err + } + // OK return models, nil } - -// ConsumeRow Scans a *sql.Row into our struct -// that is using this model -func (w *PqModel) ConsumeRow(row *sql.Row) error { - fields := w.Config.Fields - var s []interface{} - for _, value := range fields { - s = append(s, value.Pointer) - } - return row.Scan(s...) -} diff --git a/pq_model_test.go b/pq_model_test.go index 4c56aeb..14748c6 100644 --- a/pq_model_test.go +++ b/pq_model_test.go @@ -6,6 +6,7 @@ import ( _ "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "gopkg.in/guregu/null.v3" "os" "testing" ) @@ -19,7 +20,7 @@ import ( // in the database. type Place struct { surf.Model - Id int `json:"id"` + Id int64 `json:"id"` Name string `json:"name"` } @@ -39,7 +40,7 @@ func (p *Place) Prep(dbConnection *sql.DB) *Place { Name: "id", UniqueIdentifier: true, IsSet: func(pointer interface{}) bool { - pointerInt := *pointer.(*int) + pointerInt := *pointer.(*int64) return pointerInt != 0 }, }, @@ -63,7 +64,7 @@ func (p *Place) Prep(dbConnection *sql.DB) *Place { // as we are using this to test a failure before we even hit the db type Person struct { surf.Model - Id int `json:"id"` + Id int64 `json:"id"` Name string `json:"name"` } @@ -115,7 +116,7 @@ CREATE TABLE animals( */ type Animal struct { surf.Model - Id int `json:"id"` + Id int64 `json:"id"` Slug string `json:"slug"` Name string `json:"name"` Age int `json:"age"` @@ -137,7 +138,7 @@ func (a *Animal) Prep(dbConnection *sql.DB) *Animal { Name: "id", UniqueIdentifier: true, IsSet: func(pointer interface{}) bool { - pointerInt := *pointer.(*int) + pointerInt := *pointer.(*int64) return pointerInt != 0 }, }, @@ -170,6 +171,83 @@ func (a *Animal) Prep(dbConnection *sql.DB) *Animal { return a } +// =============================== +// ========== Toy Model ========== +// =============================== + +/** +Represents: + +CREATE TABLE toys( + id serial PRIMARY KEY, + name text NOT NULL, + owner bigint NOT NULL REFERENCES animals(id) ON DELETE CASCADE +); +*/ +type Toy struct { + surf.Model + Id int64 `json:"id"` + Name string `json:"name"` + OwnerId int64 `json:"-"` + Owner *Animal `json:"owner"` + SecondOwnerId null.Int `json:"-"` + SecondOwner *Animal `json:"second_owner"` +} + +func NewToy(dbConnection *sql.DB) *Toy { + toy := new(Toy) + return toy.Prep(dbConnection) +} + +func (t *Toy) Prep(dbConnection *sql.DB) *Toy { + t.Model = &surf.PqModel{ + Database: dbConnection, + Config: surf.Configuration{ + TableName: "toys", + Fields: []surf.Field{ + {Pointer: &t.Id, Name: "id", UniqueIdentifier: true, + IsSet: func(pointer interface{}) bool { + pointerInt := *pointer.(*int64) + return pointerInt != 0 + }, + }, + {Pointer: &t.Name, Name: "name", Insertable: true, Updatable: true}, + {Pointer: &t.OwnerId, Name: "owner", Insertable: true, Updatable: true, + GetReference: func() (surf.BuildModel, string) { + return func() surf.Model { + return NewAnimal(dbConnection) + }, "id" + }, + SetReference: func(model surf.Model) error { + t.Owner = model.(*Animal) + return nil + }, + IsSet: func(pointer interface{}) bool { + pointerInt := *pointer.(*int64) + return pointerInt != 0 + }, + }, + {Pointer: &t.SecondOwnerId, Name: "second_owner", Insertable: true, Updatable: true, + GetReference: func() (surf.BuildModel, string) { + return func() surf.Model { + return NewAnimal(dbConnection) + }, "id" + }, + SetReference: func(model surf.Model) error { + t.SecondOwner = model.(*Animal) + return nil + }, + IsSet: func(pointer interface{}) bool { + pointerInt := *pointer.(*null.Int) + return pointerInt.Valid + }, + }, + }, + }, + } + return t +} + // ================================================== // ========== Animal Consume Failure Model ========== // ================================================== @@ -245,9 +323,12 @@ type PqWorkerTestSuite struct { } func (suite *PqWorkerTestSuite) SetupTest() { - databaseUrl := os.Getenv("SERF_TEST_DATABASE_URL") + // Enable logging + stackWriter := &StackWriter{} + surf.SetLogging(true, stackWriter) // Opening + storing the connection + databaseUrl := os.Getenv("SERF_TEST_DATABASE_URL") db, err := sql.Open("postgres", databaseUrl) if err != nil { suite.Fail("Failed to open database connection") @@ -557,6 +638,182 @@ func (suite *PqWorkerTestSuite) TestGetConfiguration() { assert.True(suite.T(), (hasId && hasSlug && hasName && hasAge)) } +func (suite *PqWorkerTestSuite) TestNestedModel() { + // Create an Animal + cat := NewAnimal(suite.db) + cat.Name = "Luna" + cat.Slug = "luna" + cat.Age = 2 + cat.Insert() + assert.NotEqual(suite.T(), 0, cat.Id) + + // Create a toy + tennisBall := NewToy(suite.db) + tennisBall.Name = "tennis ball" + tennisBall.OwnerId = cat.Id + tennisBall.Insert() + assert.NotEqual(suite.T(), int64(0), tennisBall.Id) + + // Create a second toy + sock := NewToy(suite.db) + sock.Name = "sock" + sock.OwnerId = cat.Id + sock.SecondOwnerId = null.IntFrom(cat.Id) + sock.Insert() + assert.NotEqual(suite.T(), int64(0), sock.Id) + + // Load all toys + toys, err := NewToy(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + }, func() surf.Model { + return NewToy(suite.db) + }) + + // Verify everything loaded properly + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), 2, len(toys)) + + loadedTennisBall := toys[0] + assert.Equal(suite.T(), tennisBall.Id, loadedTennisBall.(*Toy).Id) + assert.Equal(suite.T(), "tennis ball", loadedTennisBall.(*Toy).Name) + assert.Equal(suite.T(), cat.Id, loadedTennisBall.(*Toy).Owner.Id) + + loadedSock := toys[1] + assert.Equal(suite.T(), loadedSock.(*Toy).Id, sock.Id) + assert.Equal(suite.T(), loadedSock.(*Toy).Name, "sock") + assert.Equal(suite.T(), loadedSock.(*Toy).Owner.Id, cat.Id) + + // Clean up + cat.Delete() + tennisBall.Delete() + sock.Delete() +} + +func (suite *PqWorkerTestSuite) TestPredicates() { + // Create some Animals + luna := NewAnimal(suite.db) + luna.Name = "Luna" + luna.Slug = "luna" + luna.Age = 2 + luna.Insert() + assert.NotEqual(suite.T(), 0, luna.Id) + + rae := NewAnimal(suite.db) + rae.Name = "Rae" + rae.Slug = "rae" + rae.Age = 2 + rae.Insert() + assert.NotEqual(suite.T(), 0, rae.Id) + + rigby := NewAnimal(suite.db) + rigby.Name = "Rigby" + rigby.Slug = "rigby" + rigby.Age = 3 + rigby.Insert() + assert.NotEqual(suite.T(), 0, rigby.Id) + + // Test multiple predicates + animals, err := NewAnimal(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + Predicates: []surf.Predicate{ + {Field: "name", PredicateType: surf.WHERE_NOT_EQUAL, Values: []interface{}{"Luna"}}, + {Field: "name", PredicateType: surf.WHERE_NOT_EQUAL, Values: []interface{}{"Rae"}}, + }, + }, func() surf.Model { + return NewAnimal(suite.db) + }) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), 1, len(animals)) + loadedRigby := animals[0] + assert.Equal(suite.T(), loadedRigby.(*Animal).Name, "Rigby") + + // Test WHERE_IN with multiple elements + animals, err = NewAnimal(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + Predicates: []surf.Predicate{ + {Field: "name", PredicateType: surf.WHERE_IN, Values: []interface{}{"Luna", "Rae"}}, + }, + }, func() surf.Model { + return NewAnimal(suite.db) + }) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), 2, len(animals)) + loadedLuna := animals[0] + loadedRae := animals[1] + assert.Equal(suite.T(), "Luna", loadedLuna.(*Animal).Name) + assert.Equal(suite.T(), "Rae", loadedRae.(*Animal).Name) + + // Test WHERE_NOT_NULL + animals, err = NewAnimal(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + Predicates: []surf.Predicate{ + {Field: "name", PredicateType: surf.WHERE_IS_NOT_NULL}, + }, + }, func() surf.Model { + return NewAnimal(suite.db) + }) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), 3, len(animals)) + + // Cleanup + luna.Delete() + rae.Delete() + rigby.Delete() +} + +func (suite *PqWorkerTestSuite) TestWhereEqualPanic() { + suite.PredicatePanic(surf.WHERE_EQUAL) + suite.PredicatePanic(surf.WHERE_IN) + suite.PredicatePanic(surf.WHERE_NOT_IN) + suite.PredicatePanic(surf.WHERE_LIKE) + suite.PredicatePanic(surf.WHERE_EQUAL) + suite.PredicatePanic(surf.WHERE_NOT_EQUAL) + suite.PredicatePanic(surf.WHERE_GREATER_THAN) + suite.PredicatePanic(surf.WHERE_GREATER_THAN_OR_EQUAL_TO) + suite.PredicatePanic(surf.WHERE_LESS_THAN) + suite.PredicatePanic(surf.WHERE_LESS_THAN_OR_EQUAL_TO) + suite.PredicatePanic(surf.WHERE_IS_NOT_NULL, 1) + suite.PredicatePanic(surf.WHERE_IS_NULL, 1) + suite.PredicatePanic(9999, 1) +} + +func (suite *PqWorkerTestSuite) PredicatePanic(predType surf.PredicateType, values ...interface{}) { + defer func() { + recover() + }() + + // Test WHERE_NOT_IN without any elements (panics) + NewAnimal(suite.db).BulkFetch(surf.BulkFetchConfig{ + Limit: 10, + Offset: 0, + OrderBys: []surf.OrderBy{ + {Field: "id", Type: surf.ORDER_BY_ASC}, + }, + Predicates: []surf.Predicate{ + {Field: "name", PredicateType: predType, Values: values}, + }, + }, func() surf.Model { + return NewAnimal(suite.db) + }) + assert.Fail(suite.T(), "Unreachable statement, last call should have panicked") +} + // In order for 'go test' to run this suite, we need to create // a normal test function and pass our suite to suite.Run func TestPqWorkerTestSuite(t *testing.T) { diff --git a/predicate.go b/predicate.go new file mode 100644 index 0000000..945b0c6 --- /dev/null +++ b/predicate.go @@ -0,0 +1,168 @@ +package surf + +import ( + "strconv" +) + +type PredicateType int + +const ( + WHERE_IS_NOT_NULL PredicateType = iota // Default + WHERE_IS_NULL + WHERE_IN + WHERE_NOT_IN + WHERE_LIKE + WHERE_EQUAL + WHERE_NOT_EQUAL + WHERE_GREATER_THAN + WHERE_GREATER_THAN_OR_EQUAL_TO + WHERE_LESS_THAN + WHERE_LESS_THAN_OR_EQUAL_TO +) + +// getPredicateTypeString returns the predicate type string from it's value +func getPredicateTypeString(predicateType PredicateType) string { + switch predicateType { + case WHERE_IS_NULL: + return "WHERE_IS_NULL" + case WHERE_IN: + return "WHERE_IN" + case WHERE_NOT_IN: + return "WHERE_NOT_IN" + case WHERE_LIKE: + return "WHERE_LIKE" + case WHERE_EQUAL: + return "WHERE_EQUAL" + case WHERE_NOT_EQUAL: + return "WHERE_NOT_EQUAL" + case WHERE_GREATER_THAN: + return "WHERE_GREATER_THAN" + case WHERE_GREATER_THAN_OR_EQUAL_TO: + return "WHERE_GREATER_THAN_OR_EQUAL_TO" + case WHERE_LESS_THAN: + return "WHERE_LESS_THAN" + case WHERE_LESS_THAN_OR_EQUAL_TO: + return "WHERE_LESS_THAN_OR_EQUAL_TO" + } + return "WHERE_IS_NOT_NULL" +} + +// Predicate is the definition of a single where SQL predicate +type Predicate struct { + Field string + PredicateType PredicateType + Values []interface{} +} + +// toString will convert a predicate to it's query string, along with its values +// to be passed along with the query +// +// This function will panic in the event that this is called on a malformed predicate +func (p *Predicate) toString(valueIndex int) (string, []interface{}) { + // Field + predicate := p.Field + + // Type + switch p.PredicateType { + case WHERE_IS_NULL: + predicate += " IS NULL" + break + case WHERE_IN: + predicate += " IN " + break + case WHERE_NOT_IN: + predicate += " NOT IN " + break + case WHERE_LIKE: + predicate += " LIKE " + break + case WHERE_EQUAL: + predicate += " = " + break + case WHERE_NOT_EQUAL: + predicate += " != " + break + case WHERE_GREATER_THAN: + predicate += " > " + break + case WHERE_GREATER_THAN_OR_EQUAL_TO: + predicate += " >= " + break + case WHERE_LESS_THAN: + predicate += " < " + break + case WHERE_LESS_THAN_OR_EQUAL_TO: + predicate += " <= " + break + default: + predicate += " IS NOT NULL" + break + } + + // Values + values := make([]interface{}, 0) + switch p.PredicateType { + case WHERE_IN, + WHERE_NOT_IN: + if len(p.Values) == 0 { + panic("`" + getPredicateTypeString(p.PredicateType) + "` predicates require at least one value.") + } + predicate += "(" + for i, value := range p.Values { + values = append(values, value) + predicate += "$" + strconv.Itoa(valueIndex) + valueIndex++ + if i < len(p.Values)-1 { + predicate += ", " + } + } + predicate += ")" + break + case WHERE_LIKE, + WHERE_EQUAL, + WHERE_NOT_EQUAL, + WHERE_GREATER_THAN, + WHERE_GREATER_THAN_OR_EQUAL_TO, + WHERE_LESS_THAN, + WHERE_LESS_THAN_OR_EQUAL_TO: + if len(p.Values) != 1 { + panic("`" + getPredicateTypeString(p.PredicateType) + "` predicates require exactly one value.") + } + values = append(values, p.Values[0]) + predicate += "$" + strconv.Itoa(valueIndex) + break + case WHERE_IS_NOT_NULL, + WHERE_IS_NULL: + if len(p.Values) != 0 { + panic("`" + getPredicateTypeString(p.PredicateType) + "` predicates cannot have any values.") + } + break + default: + panic("Unknown predicate type.") + } + + return predicate, values +} + +// predicatesToString converts an array of predicates to a query string, along with its values +// to be passed along with the query +// +// This function will panic in the event that it encounters a malformed predicate +func predicatesToString(valueIndex int, predicates []Predicate) (string, []interface{}) { + values := make([]interface{}, 0) + + predicateStr := "" + if len(predicates) > 0 { + predicateStr += "WHERE " + } + for i, predicate := range predicates { + iPredicateStr, iValues := predicate.toString(valueIndex) + valueIndex += len(iValues) + values = append(values, iValues...) + predicateStr += iPredicateStr + if i < (len(predicates) - 1) { + predicateStr += " AND " + } + } + return predicateStr, values +}