diff --git a/create.go b/create.go index 85031d4..390a71e 100644 --- a/create.go +++ b/create.go @@ -8,19 +8,11 @@ import ( ) func createAndGetResult(exec ExecFn, record interface{}) (stdsql.Result, error) { - row, err := NewRow(record) + row, columns, values, err := valuesForRecord(record) if err != nil { return nil, err } - columns := []string{} - values := []interface{}{} - - for c, v := range row.SQLValues() { - columns = append(columns, c) - values = append(values, v) - } - return exec(sql.InsertQuery(row.SQLTableName, columns), values...) } @@ -35,6 +27,50 @@ func createAndRead(exec ExecFn, query QueryFn, record interface{}) error { return err } + return readLastInsert(query, record, result) +} + +func replaceAndGetResult(exec ExecFn, record interface{}) (stdsql.Result, error) { + row, columns, values, err := valuesForRecord(record) + if err != nil { + return nil, err + } + + return exec(sql.ReplaceQuery(row.SQLTableName, columns), values...) +} + +func replace(exec ExecFn, record interface{}) error { + _, err := replaceAndGetResult(exec, record) + return err +} + +func replaceAndRead(exec ExecFn, query QueryFn, record interface{}) error { + result, err := replaceAndGetResult(exec, record) + if err != nil { + return err + } + + return readLastInsert(query, record, result) +} + +func valuesForRecord(record interface{}) (*Row, []string, []interface{}, error) { + row, err := NewRow(record) + if err != nil { + return nil, nil, nil, err + } + + columns := []string{} + values := []interface{}{} + + for c, v := range row.SQLValues() { + columns = append(columns, c) + values = append(values, v) + } + + return row, columns, values, nil +} + +func readLastInsert(query QueryFn, record interface{}, result stdsql.Result) error { id, err := result.LastInsertId() if err != nil { return err diff --git a/db.go b/db.go index 6c8cabc..6e0caeb 100644 --- a/db.go +++ b/db.go @@ -118,6 +118,20 @@ func (db *DB) CreateAndRead(record interface{}) error { return createAndRead(db.Exec, db.Query, record) } +// Replaces given record into the database, generating a replace query for it. +func (db *DB) Replace(record interface{}) error { + return replace(db.Exec, record) +} + +func (db *DB) ReplaceAndGetResult(record interface{}) (stdsql.Result, error) { + return replaceAndGetResult(db.Exec, record) +} + +// Replaces given record and scans the replaceed row back to the given row. +func (db *DB) ReplaceAndRead(record interface{}) error { + return replaceAndRead(db.Exec, db.Query, record) +} + // Runs given SQL query and scans the result rows into the given target interface. The target // interface could be both a single record or a slice of records. // diff --git a/sql/table.go b/sql/table.go index 2f9de5a..28993dc 100644 --- a/sql/table.go +++ b/sql/table.go @@ -109,17 +109,19 @@ func ShowTablesLikeQuery(name string) string { } func InsertQuery(tableName string, columnNames []string) string { - var questionMarks string - - if len(columnNames) > 0 { - questionMarks = strings.Repeat("?,", len(columnNames)) - questionMarks = questionMarks[:len(questionMarks)-1] - } + questionMarks := repeatComma(len(columnNames), "?") return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(quoteColumnNames(columnNames), ","), questionMarks) } +func ReplaceQuery(tableName string, columnNames []string) string { + questionMarks := repeatComma(len(columnNames), "?") + + return fmt.Sprintf("REPLACE INTO %s (%s) VALUES (%s)", + tableName, strings.Join(quoteColumnNames(columnNames), ","), questionMarks) +} + func SelectQuery(tableName string, columnNames []string) string { columns := strings.Join(columnNames, ",") if columns == "" { @@ -150,3 +152,14 @@ func quoteColumnNames(columns []string) []string { return quoted } + +func repeatComma(num int, char string) string { + var out string + + if num > 0 { + out = strings.Repeat(char+",", num) + out = out[:len(out)-1] + } + + return out +} diff --git a/transaction.go b/transaction.go index 5109fe5..daf0883 100644 --- a/transaction.go +++ b/transaction.go @@ -51,6 +51,16 @@ func (tx *Tx) CreateAndRead(record interface{}) error { return createAndRead(tx.Exec, tx.Query, record) } +// Replace given record to the database. +func (tx *Tx) Replace(record interface{}) error { + return replace(tx.Exec, record) +} + +// Replaces given record and scans the replaceed row back to the given row. +func (tx *Tx) ReplaceAndRead(record interface{}) error { + return replaceAndRead(tx.Exec, tx.Query, record) +} + // Run a select query on the databaase (w/ given parameters optionally) and scan the result(s) to the // target interface specified as the first parameter. //