-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(inserter): Adding a dynamic insert package (#32)
* Adding a dynamic insert package * Updating tests * Adding tests
- Loading branch information
1 parent
ed28cc3
commit 53df004
Showing
5 changed files
with
553 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package inserter | ||
|
||
import ( | ||
"database/sql" | ||
"errors" | ||
) | ||
|
||
var ( | ||
// ErrNoDatabaseConnection is returned when no database connection is set | ||
ErrNoDatabaseConnection = errors.New("no database connection set") | ||
|
||
// ErrNoTable is returned when no table is set | ||
ErrNoTable = errors.New("no table set") | ||
|
||
// ErrNoFields is returned when no fields are set | ||
ErrNoFields = errors.New("no fields set") | ||
|
||
// ErrNoArgs is returned when no arguments are set | ||
ErrNoArgs = errors.New("no arguments set") | ||
|
||
// ErrNoResources is returned when no resources are set | ||
ErrNoResources = errors.New("no resources set") | ||
) | ||
|
||
type SQLBatch struct { | ||
// resources is the resources to use in the SQL statement | ||
resources []any | ||
|
||
// fields is the fields to update in the SQL statement | ||
fields []string | ||
|
||
// args is the arguments to use in the SQL statement | ||
args []any | ||
|
||
// db is the database connection to use | ||
db *sql.DB | ||
|
||
// tagName is the tag name to look for in the struct. This is an override from the default tag "db" | ||
tagName string | ||
|
||
// table is the table name to use in the SQL statement | ||
table string | ||
} | ||
|
||
func (b *SQLBatch) AddResources(resources ...any) { | ||
b.resources = append(b.resources, resources...) | ||
} | ||
|
||
func (b *SQLBatch) Fields() []string { | ||
return b.fields | ||
} | ||
|
||
func (b *SQLBatch) Args() []any { | ||
return b.args | ||
} | ||
|
||
func (b *SQLBatch) validateSQLGen() error { | ||
if b.table == "" { | ||
return ErrNoTable | ||
} | ||
if len(b.resources) == 0 { | ||
return ErrNoResources | ||
} | ||
if len(b.fields) == 0 { | ||
return ErrNoFields | ||
} | ||
if len(b.args) == 0 { | ||
return ErrNoArgs | ||
} | ||
return nil | ||
} | ||
|
||
func (b *SQLBatch) validateSQLInsert() error { | ||
if b.db == nil { | ||
return ErrNoDatabaseConnection | ||
} | ||
if b.table == "" { | ||
return ErrNoTable | ||
} | ||
if len(b.resources) == 0 { | ||
return ErrNoResources | ||
} | ||
if len(b.fields) == 0 { | ||
return ErrNoFields | ||
} | ||
if len(b.args) == 0 { | ||
return ErrNoArgs | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package inserter | ||
|
||
import "database/sql" | ||
|
||
type BatchOpt func(*SQLBatch) | ||
|
||
// WithTagName sets the tag name to look for in the struct. This is an override from the default tag "db" | ||
func WithTagName(tagName string) BatchOpt { | ||
return func(b *SQLBatch) { | ||
b.tagName = tagName | ||
} | ||
} | ||
|
||
// WithTable sets the table name to use in the SQL statement | ||
func WithTable(table string) BatchOpt { | ||
return func(b *SQLBatch) { | ||
b.table = table | ||
} | ||
} | ||
|
||
// WithDB sets the database connection to use | ||
func WithDB(db *sql.DB) BatchOpt { | ||
return func(b *SQLBatch) { | ||
b.db = db | ||
} | ||
} | ||
|
||
func WithResources(resources []any) BatchOpt { | ||
return func(b *SQLBatch) { | ||
b.resources = resources | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
package inserter | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
) | ||
|
||
const ( | ||
// defaultTagName is the default tag name to look for in the struct | ||
defaultTagName = "db" | ||
) | ||
|
||
func NewBatch(opts ...BatchOpt) *SQLBatch { | ||
b := new(SQLBatch) | ||
b.tagName = defaultTagName | ||
for _, opt := range opts { | ||
opt(b) | ||
} | ||
|
||
b.genBatch() | ||
|
||
return b | ||
} | ||
|
||
func (b *SQLBatch) genBatch() { | ||
uniqueFields := make(map[string]struct{}) | ||
|
||
for _, r := range b.resources { | ||
// get the type of the resource | ||
t := reflect.TypeOf(r) | ||
if t.Kind() == reflect.Ptr { | ||
t = t.Elem() | ||
} | ||
|
||
// Is the type a struct? | ||
if t.Kind() != reflect.Struct { | ||
continue | ||
} | ||
|
||
// get the value of the resource | ||
v := reflect.ValueOf(r) | ||
if v.Kind() == reflect.Ptr { | ||
v = v.Elem() | ||
} | ||
|
||
// get the fields | ||
for i := 0; i < t.NumField(); i++ { | ||
f := t.Field(i) | ||
tag := f.Tag.Get(b.tagName) | ||
if tag == "-" { | ||
continue | ||
} | ||
|
||
if !f.IsExported() { | ||
continue | ||
} | ||
|
||
// if no tag is set, use the field name | ||
if tag == "" { | ||
tag = strings.ToLower(f.Name) | ||
} | ||
|
||
b.args = append(b.args, v.Field(i).Interface()) | ||
|
||
// if the field is not unique, skip it | ||
if _, ok := uniqueFields[tag]; ok { | ||
continue | ||
} | ||
|
||
// add the field to the list | ||
b.fields = append(b.fields, tag) | ||
uniqueFields[tag] = struct{}{} | ||
} | ||
} | ||
} | ||
|
||
func (b *SQLBatch) sqlGen() (string, []any, error) { | ||
if err := b.validateSQLGen(); err != nil { | ||
return "", nil, err | ||
} | ||
|
||
sqlBuilder := new(strings.Builder) | ||
|
||
sqlBuilder.WriteString("INSERT INTO ") | ||
sqlBuilder.WriteString(b.table) | ||
sqlBuilder.WriteString(" (") | ||
sqlBuilder.WriteString(strings.Join(b.fields, ", ")) | ||
sqlBuilder.WriteString(") VALUES ") | ||
|
||
// We need to have the same number of "?" as fields and then repeat that for the number of resources | ||
placeholder := strings.Repeat("?, ", len(b.fields)) | ||
placeholder = placeholder[:len(placeholder)-2] // Remove the trailing ", " | ||
placeholder = "(" + placeholder + "), " | ||
|
||
// Repeat the placeholder for the number of resources | ||
placeholders := strings.Repeat(placeholder, len(b.resources)) | ||
sqlBuilder.WriteString(placeholders[:len(placeholders)-2]) // Remove the trailing ", " and add the closing ")" | ||
|
||
return sqlBuilder.String(), b.args, nil | ||
} | ||
|
||
func (b *SQLBatch) Perform() (sql.Result, error) { | ||
if err := b.validateSQLInsert(); err != nil { | ||
return nil, fmt.Errorf("validate SQL generation: %w", err) | ||
} | ||
|
||
sqlStr, args, err := b.sqlGen() | ||
if err != nil { | ||
return nil, fmt.Errorf("generate SQL: %w", err) | ||
} | ||
|
||
return b.db.Exec(sqlStr, args...) | ||
} |
Oops, something went wrong.