Skip to content

Commit

Permalink
Task/test expansion (yoyo-project#7)
Browse files Browse the repository at this point in the history
* Add tests for base dialect implementation

* Expand testing and some light decoupling

* adding more tests for mysql reverser

* Expanded tests

- Light refactoring
- added validation
  • Loading branch information
dotvezz authored Dec 6, 2020
1 parent 0ef4193 commit dd766c2
Show file tree
Hide file tree
Showing 14 changed files with 1,485 additions and 97 deletions.
5 changes: 4 additions & 1 deletion cmd/yoyo/main.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package main

import (
"database/sql"
"fmt"
"github.com/dotvezz/lime"
"github.com/dotvezz/lime/cli"
"github.com/dotvezz/lime/options"
"github.com/dotvezz/yoyo/internal/dbms/mysql"
"github.com/dotvezz/yoyo/internal/dbms/postgres"
"os"
)

Expand All @@ -23,7 +26,7 @@ func main() {
//},
lime.Command{
Keyword: "reverse",
Func: initReverser(),
Func: initReverser(mysql.InitNewReverser(sql.Open), postgres.InitNewReverser(sql.Open)),
},
)
err := c.Run()
Expand Down
17 changes: 13 additions & 4 deletions cmd/yoyo/reverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"fmt"
"github.com/dotvezz/lime"
"github.com/dotvezz/yoyo/env"
"github.com/dotvezz/yoyo/internal/dbms/dialect"
"github.com/dotvezz/yoyo/internal/reverse"
"github.com/dotvezz/yoyo/internal/yoyo"
)

func initReverser() lime.Func {
func initReverser(newMysqlReverser, newPostgresReverser func(host, userName, dbName, password, port string) (reverse.Reverser, error)) lime.Func {
return func(args []string) error {
var (
config yoyo.Config
Expand All @@ -23,11 +24,19 @@ func initReverser() lime.Func {
dia = config.Schema.Dialect
}

if err != nil {
return fmt.Errorf("unable to initialize: %w", err)
if dia == "" {
return fmt.Errorf("no dialect given")
}

switch dia {
case dialect.MySQL:
reverser, err = newMysqlReverser(env.DBHost(), env.DBUser(), env.DBName(), env.DBPassword(), env.DBPort())
case dialect.PostgreSQL:
reverser, err = newPostgresReverser(env.DBHost(), env.DBUser(), env.DBName(), env.DBPassword(), env.DBPort())
default:
err = fmt.Errorf("unknown dialect `%s`", dia)
}

reverser, err = reverse.LoadReverser(dia, env.DBHost(), env.DBUser(), env.DBName(), env.DBPassword(), env.DBPort())
if err != nil {
return fmt.Errorf("unable to initialize: %w", err)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/dotvezz/yoyo
go 1.12

require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/dotvezz/lime v0.0.0-20190701000217-4127a8765ba8
github.com/go-sql-driver/mysql v1.5.0
github.com/lib/pq v1.8.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/dotvezz/lime v0.0.0-20190701000217-4127a8765ba8 h1:md23XMhF9u2fj3IRaYSvhlNfMTLCBBzstSQy17uZ5Tw=
github.com/dotvezz/lime v0.0.0-20190701000217-4127a8765ba8/go.mod h1:TyGfUraSwOyY4aLBtu0dgjg4DMutLx0hM9w8aRpGd6I=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
Expand Down
2 changes: 1 addition & 1 deletion internal/datatype/datatypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Datatype uint64
// These are the actual Datatype constants with all the metadata and unique identifiers encoded into them
const (
Integer = idInteger | metaNumeric | metaInteger | metaSignable
TinyInt = idTinyInt | metaNumeric | metaSignable
TinyInt = idTinyInt | metaNumeric | metaInteger | metaSignable
SmallInt = idSmallInt | metaNumeric | metaInteger | metaSignable
MediumInt = idMediumInt | metaNumeric | metaInteger | metaSignable
BigInt = idBigInt | metaNumeric | metaInteger | metaSignable
Expand Down
58 changes: 32 additions & 26 deletions internal/dbms/mysql/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,24 @@ type migrator struct {
validator
}

func (d *migrator) TypeString(dt datatype.Datatype) (s string, err error) {
if !d.SupportsDatatype(dt) {
return "", errors.New("unsupported datatype")
}
func (m *migrator) TypeString(dt datatype.Datatype) (s string, err error) {
switch dt {
case datatype.Integer:
s = "INT"
default:
s, err = d.Base.TypeString(dt)
s, err = m.Base.TypeString(dt)
}
if err == nil && !m.validator.SupportsDatatype(dt) {
err = errors.New("unsupported datatype")
}
return s, err
}

// CreateTable generates a query to create a given table.
func (d *migrator) CreateTable(table string, t schema.Table) string {
func (m *migrator) CreateTable(tName string, t schema.Table) string {
sb := strings.Builder{}

sb.WriteString(fmt.Sprintf("CREATE TABLE `%s` (\n", table))
sb.WriteString(fmt.Sprintf("CREATE TABLE `%s` (\n", tName))

var (
first = true
Expand All @@ -53,50 +53,51 @@ func (d *migrator) CreateTable(table string, t schema.Table) string {
} else {
first = false
}
sb.WriteString(d.generateColumn(colName, c))
sb.WriteString(" ")
sb.WriteString(m.generateColumn(colName, c))
if c.PrimaryKey {
pks = append(pks, colName)
}
}

if len(pks) > 0 {
sb.WriteString(fmt.Sprintf("\nPRIMARY KEY (%s)", strings.Join(pks, ",")))
sb.WriteString(fmt.Sprintf("\n PRIMARY KEY (`%s`)", strings.Join(pks, ",")))
}

sb.WriteString("\n);")

return sb.String()
}

func (d *migrator) AddColumn(table, column string, c schema.Column) string {
return fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN %s;", table, d.generateColumn(column, c))
func (m *migrator) AddColumn(tName, cName string, c schema.Column) string {
return fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN %s;", tName, m.generateColumn(cName, c))
}

// AddIndex returns a string query which adds the specified index to a table
func (d *migrator) AddIndex(table, index string, i schema.Index) string {
func (m *migrator) AddIndex(tName, iName string, i schema.Index) string {
var indexType string

switch {
case i.Unique:
indexType = " UNIQUE INDEX"
indexType = "UNIQUE INDEX"
default:
indexType = " INDEX"
indexType = "INDEX"
}

cols := strings.Builder{}
firstCol := true
for _, col := range i.Columns {
if !firstCol {
cols.WriteRune(',')
cols.WriteString(", ")
}
firstCol = false
cols.WriteString(fmt.Sprintf("%s", col))
cols.WriteString(fmt.Sprintf("`%s`", col))
}

return fmt.Sprintf("ALTER TABLE `%s` ADD %s `%s` (%s);\n", table, indexType, index, cols.String())
return fmt.Sprintf("ALTER TABLE `%s` ADD %s `%s` (%s);", tName, indexType, iName, cols.String())
}

func (d *migrator) AddReference(table, referencedTable string, db schema.Database, r schema.Reference) (string, error) {
func (m *migrator) AddReference(table, referencedTable string, db schema.Database, r schema.Reference) (string, error) {
sw := strings.Builder{}

if r.HasMany { // swap the tables if it's a HasMany
Expand Down Expand Up @@ -135,7 +136,7 @@ func (d *migrator) AddReference(table, referencedTable string, db schema.Databas
fknames = append(fknames, fkname)
fcols = append(fcols, cname)

sw.WriteString(d.AddColumn(table, fkname, col))
sw.WriteString(m.AddColumn(table, fkname, col))
sw.WriteRune('\n')
}

Expand All @@ -162,20 +163,21 @@ func (d *migrator) AddReference(table, referencedTable string, db schema.Databas
return sw.String(), nil
}

func (d *migrator) generateColumn(name string, c schema.Column) string {
func (m *migrator) generateColumn(cName string, c schema.Column) string {
sb := strings.Builder{}
ts, _ := m.TypeString(c.Datatype)

if c.Datatype.RequiresScale() {
sb.WriteString(fmt.Sprintf("`%s` %s(%d, %d)", name, c.Datatype, c.Scale, c.Precision))
sb.WriteString(fmt.Sprintf("`%s` %s(%d, %d)", cName, ts, c.Scale, c.Precision))
} else {
if c.Scale > 0 {
if c.Precision > 0 {
sb.WriteString(fmt.Sprintf("`%s` %s(%d, %d)", name, c.Datatype, c.Scale, c.Precision))
sb.WriteString(fmt.Sprintf("`%s` %s(%d, %d)", cName, ts, c.Scale, c.Precision))
} else {
sb.WriteString(fmt.Sprintf("`%s` %s(%d)", name, c.Datatype, c.Scale))
sb.WriteString(fmt.Sprintf("`%s` %s(%d)", cName, ts, c.Scale))
}
} else {
sb.WriteString(fmt.Sprintf("`%s` %s", name, c.Datatype))
sb.WriteString(fmt.Sprintf("`%s` %s", cName, ts))
}
}

Expand All @@ -190,9 +192,9 @@ func (d *migrator) generateColumn(name string, c schema.Column) string {
if c.Default != nil {
sb.WriteString(` DEFAULT `)
if c.Datatype.IsString() {
sb.WriteString(fmt.Sprintf(`"%s" `, *c.Default))
sb.WriteString(fmt.Sprintf(`"%s"`, *c.Default))
} else {
sb.WriteString(fmt.Sprintf("%s ", *c.Default))
sb.WriteString(fmt.Sprintf("%s", *c.Default))
}
} else if c.Nullable {
sb.WriteString(` DEFAULT NULL`)
Expand All @@ -203,5 +205,9 @@ func (d *migrator) generateColumn(name string, c schema.Column) string {
}
sb.WriteString(" NULL")

if c.AutoIncrement {
sb.WriteString(" AUTO_INCREMENT")
}

return sb.String()
}
Loading

0 comments on commit dd766c2

Please sign in to comment.