Skip to content

Commit

Permalink
feat(inserter): Adding a dynamic insert package (#32)
Browse files Browse the repository at this point in the history
* Adding a dynamic insert package

* Updating tests

* Adding tests
  • Loading branch information
Jacobbrewer1 authored Oct 12, 2024
1 parent ed28cc3 commit 53df004
Show file tree
Hide file tree
Showing 5 changed files with 553 additions and 1 deletion.
90 changes: 90 additions & 0 deletions inserter/batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package inserter

import (
"database/sql"
"errors"
)

var (
// ErrNoDatabaseConnection is returned when no database connection is set
ErrNoDatabaseConnection = errors.New("no database connection set")

// ErrNoTable is returned when no table is set
ErrNoTable = errors.New("no table set")

// ErrNoFields is returned when no fields are set
ErrNoFields = errors.New("no fields set")

// ErrNoArgs is returned when no arguments are set
ErrNoArgs = errors.New("no arguments set")

// ErrNoResources is returned when no resources are set
ErrNoResources = errors.New("no resources set")
)

type SQLBatch struct {
// resources is the resources to use in the SQL statement
resources []any

// fields is the fields to update in the SQL statement
fields []string

// args is the arguments to use in the SQL statement
args []any

// db is the database connection to use
db *sql.DB

// tagName is the tag name to look for in the struct. This is an override from the default tag "db"
tagName string

// table is the table name to use in the SQL statement
table string
}

func (b *SQLBatch) AddResources(resources ...any) {
b.resources = append(b.resources, resources...)
}

func (b *SQLBatch) Fields() []string {
return b.fields
}

func (b *SQLBatch) Args() []any {
return b.args
}

func (b *SQLBatch) validateSQLGen() error {
if b.table == "" {
return ErrNoTable
}
if len(b.resources) == 0 {
return ErrNoResources
}
if len(b.fields) == 0 {
return ErrNoFields
}
if len(b.args) == 0 {
return ErrNoArgs
}
return nil
}

func (b *SQLBatch) validateSQLInsert() error {
if b.db == nil {
return ErrNoDatabaseConnection
}
if b.table == "" {
return ErrNoTable
}
if len(b.resources) == 0 {
return ErrNoResources
}
if len(b.fields) == 0 {
return ErrNoFields
}
if len(b.args) == 0 {
return ErrNoArgs
}
return nil
}
32 changes: 32 additions & 0 deletions inserter/batch_opts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package inserter

import "database/sql"

type BatchOpt func(*SQLBatch)

// WithTagName sets the tag name to look for in the struct. This is an override from the default tag "db"
func WithTagName(tagName string) BatchOpt {
return func(b *SQLBatch) {
b.tagName = tagName
}
}

// WithTable sets the table name to use in the SQL statement
func WithTable(table string) BatchOpt {
return func(b *SQLBatch) {
b.table = table
}
}

// WithDB sets the database connection to use
func WithDB(db *sql.DB) BatchOpt {
return func(b *SQLBatch) {
b.db = db
}
}

func WithResources(resources []any) BatchOpt {
return func(b *SQLBatch) {
b.resources = resources
}
}
115 changes: 115 additions & 0 deletions inserter/sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package inserter

import (
"database/sql"
"fmt"
"reflect"
"strings"
)

const (
// defaultTagName is the default tag name to look for in the struct
defaultTagName = "db"
)

func NewBatch(opts ...BatchOpt) *SQLBatch {
b := new(SQLBatch)
b.tagName = defaultTagName
for _, opt := range opts {
opt(b)
}

b.genBatch()

return b
}

func (b *SQLBatch) genBatch() {
uniqueFields := make(map[string]struct{})

for _, r := range b.resources {
// get the type of the resource
t := reflect.TypeOf(r)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}

// Is the type a struct?
if t.Kind() != reflect.Struct {
continue
}

// get the value of the resource
v := reflect.ValueOf(r)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

// get the fields
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
tag := f.Tag.Get(b.tagName)
if tag == "-" {
continue
}

if !f.IsExported() {
continue
}

// if no tag is set, use the field name
if tag == "" {
tag = strings.ToLower(f.Name)
}

b.args = append(b.args, v.Field(i).Interface())

// if the field is not unique, skip it
if _, ok := uniqueFields[tag]; ok {
continue
}

// add the field to the list
b.fields = append(b.fields, tag)
uniqueFields[tag] = struct{}{}
}
}
}

func (b *SQLBatch) sqlGen() (string, []any, error) {
if err := b.validateSQLGen(); err != nil {
return "", nil, err
}

sqlBuilder := new(strings.Builder)

sqlBuilder.WriteString("INSERT INTO ")
sqlBuilder.WriteString(b.table)
sqlBuilder.WriteString(" (")
sqlBuilder.WriteString(strings.Join(b.fields, ", "))
sqlBuilder.WriteString(") VALUES ")

// We need to have the same number of "?" as fields and then repeat that for the number of resources
placeholder := strings.Repeat("?, ", len(b.fields))
placeholder = placeholder[:len(placeholder)-2] // Remove the trailing ", "
placeholder = "(" + placeholder + "), "

// Repeat the placeholder for the number of resources
placeholders := strings.Repeat(placeholder, len(b.resources))
sqlBuilder.WriteString(placeholders[:len(placeholders)-2]) // Remove the trailing ", " and add the closing ")"

return sqlBuilder.String(), b.args, nil
}

func (b *SQLBatch) Perform() (sql.Result, error) {
if err := b.validateSQLInsert(); err != nil {
return nil, fmt.Errorf("validate SQL generation: %w", err)
}

sqlStr, args, err := b.sqlGen()
if err != nil {
return nil, fmt.Errorf("generate SQL: %w", err)
}

return b.db.Exec(sqlStr, args...)
}
Loading

0 comments on commit 53df004

Please sign in to comment.