Skip to content

Commit

Permalink
Merge pull request #167 from uptrace/feature/support-mariadb
Browse files Browse the repository at this point in the history
Add support for MariaDB
  • Loading branch information
vmihailenco authored Sep 8, 2021
2 parents b36fdb7 + 4198224 commit 74477c4
Show file tree
Hide file tree
Showing 16 changed files with 122 additions and 53 deletions.
15 changes: 14 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,24 @@ jobs:
options: >-
--health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s
--health-retries=3
mariadb:
image: mariadb:10.6
env:
MARIADB_DATABASE: test
MARIADB_USER: user
MARIADB_PASSWORD: pass
MARIADB_ROOT_PASSWORD: pass
ports:
- 13306:3306
options: >-
--health-cmd="mariadb-admin ping" --health-interval=10s --health-timeout=5s
--health-retries=3
steps:
- name: Set up ${{ matrix.go-version }}
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
mysql-version: ${{ matrix.mysql-version }}

- name: Checkout code
uses: actions/checkout@v2
Expand All @@ -63,3 +74,5 @@ jobs:
env:
PG: postgres://postgres:postgres@localhost/postgres?sslmode=disable
MYSQL: user:pass@/test
MYSQL5: user:pass@tcp(localhost:53306)/test
MARIADB: user:pass@tcp(localhost:13306)/test
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ packages that share one repo with the core.
Main features are:

- Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql),
[MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql),
[MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql) (including MariaDB),
[SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite).
- [Selecting](/example/basic/) into a map, struct, slice of maps/structs/vars.
- [Bulk inserts](https://bun.uptrace.dev/guide/queries.html#insert).
Expand Down
9 changes: 3 additions & 6 deletions dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ func (n Name) String() string {
return "pg"
case SQLite:
return "sqlite"
case MySQL5:
return "mysql5"
case MySQL8:
return "mysql8"
case MySQL:
return "mysql"
default:
return "invalid"
}
Expand All @@ -21,6 +19,5 @@ const (
Invalid Name = iota
PG
SQLite
MySQL5
MySQL8
MySQL
)
5 changes: 2 additions & 3 deletions dialect/feature/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ import "github.com/uptrace/bun/internal"

type Feature = internal.Flag

const DefaultFeatures = Returning | TableCascade

const (
Returning Feature = 1 << iota
CTE Feature = 1 << iota
Returning
DefaultPlaceholder
DoubleColonCast
ValuesRow
Expand Down
12 changes: 6 additions & 6 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import (
const datetimeType = "DATETIME"

type Dialect struct {
name dialect.Name

tables *schema.Tables
features feature.Feature

Expand All @@ -30,7 +28,6 @@ type Dialect struct {

func New() *Dialect {
d := new(Dialect)
d.name = dialect.MySQL5
d.tables = schema.NewTables(d)
d.features = feature.AutoIncrement |
feature.DefaultPlaceholder |
Expand All @@ -48,10 +45,13 @@ func (d *Dialect) Init(db *sql.DB) {
return
}

if strings.Contains(version, "MariaDB") {
return
}

version = semver.MajorMinor("v" + cleanupVersion(version))
if semver.Compare(version, "v8.0") >= 0 {
d.name = dialect.MySQL8
d.features |= feature.DeleteTableAlias
d.features |= feature.CTE | feature.DeleteTableAlias
}
}

Expand All @@ -63,7 +63,7 @@ func cleanupVersion(s string) string {
}

func (d *Dialect) Name() dialect.Name {
return d.name
return dialect.MySQL
}

func (d *Dialect) Features() feature.Feature {
Expand Down
3 changes: 2 additions & 1 deletion dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ type Dialect struct {
func New() *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.Returning |
d.features = feature.CTE |
feature.Returning |
feature.DefaultPlaceholder |
feature.DoubleColonCast |
feature.InsertTableAlias |
Expand Down
5 changes: 4 additions & 1 deletion dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ type Dialect struct {
func New() *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.Returning | feature.InsertTableAlias | feature.DeleteTableAlias
d.features = feature.CTE |
feature.Returning |
feature.InsertTableAlias |
feature.DeleteTableAlias
return d
}

Expand Down
2 changes: 1 addition & 1 deletion extra/bunotel/otel.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func dbSystem(db *bun.DB) string {
switch db.Dialect().Name() {
case dialect.PG:
return "postgresql"
case dialect.MySQL5, dialect.MySQL8:
case dialect.MySQL:
return "mysql"
case dialect.SQLite:
return "sqlite"
Expand Down
87 changes: 62 additions & 25 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/mysqldialect"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/dialect/sqlitedialect"
Expand All @@ -29,12 +30,22 @@ import (

var ctx = context.TODO()

const (
pgName = "pg"
pgxName = "pgx"
mysql5Name = "mysql5"
mysql8Name = "mysql8"
mariadbName = "mariadb"
sqliteName = "sqlite"
)

var allDBs = map[string]func(tb testing.TB) *bun.DB{
"pg": pg,
"pgx": pgx,
"mysql8": mysql8,
"mysql5": mysql5,
"sqlite": sqlite,
pgName: pg,
pgxName: pgx,
mysql5Name: mysql5,
mysql8Name: mysql8,
mariadbName: mariadb,
sqliteName: sqlite,
}

func pg(tb testing.TB) *bun.DB {
Expand Down Expand Up @@ -93,7 +104,7 @@ func mysql8(tb testing.TB) *bun.DB {
})

db := bun.NewDB(sqldb, mysqldialect.New())
require.Equal(tb, "DB<dialect=mysql8>", db.String())
require.Equal(tb, "DB<dialect=mysql>", db.String())

if _, ok := os.LookupEnv("DEBUG"); ok {
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose()))
Expand All @@ -115,7 +126,29 @@ func mysql5(tb testing.TB) *bun.DB {
})

db := bun.NewDB(sqldb, mysqldialect.New())
require.Equal(tb, "DB<dialect=mysql5>", db.String())
require.Equal(tb, "DB<dialect=mysql>", db.String())

if _, ok := os.LookupEnv("DEBUG"); ok {
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose()))
}

return db
}

func mariadb(tb testing.TB) *bun.DB {
dsn := os.Getenv("MYSQL5")
if dsn == "" {
dsn = "user:pass@tcp(localhost:13306)/test"
}

sqldb, err := sql.Open("mysql", dsn)
require.NoError(tb, err)
tb.Cleanup(func() {
assert.NoError(tb, sqldb.Close())
})

db := bun.NewDB(sqldb, mysqldialect.New())
require.Equal(tb, "DB<dialect=mysql>", db.String())

if _, ok := os.LookupEnv("DEBUG"); ok {
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose()))
Expand All @@ -141,11 +174,10 @@ func sqlite(tb testing.TB) *bun.DB {
return db
}

func testEachDB(t *testing.T, f func(t *testing.T, db *bun.DB)) {
for name, newDB := range allDBs {
t.Run(name, func(t *testing.T) {
db := newDB(t)
f(t, db)
func testEachDB(t *testing.T, f func(t *testing.T, dbName string, db *bun.DB)) {
for dbName, newDB := range allDBs {
t.Run(dbName, func(t *testing.T) {
f(t, dbName, newDB(t))
})
}
}
Expand Down Expand Up @@ -192,7 +224,7 @@ func TestDB(t *testing.T) {
{testInterfaceJSON},
}

testEachDB(t, func(t *testing.T, db *bun.DB) {
testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
for _, test := range tests {
t.Run(funcName(test.run), func(t *testing.T) {
test.run(t, db)
Expand Down Expand Up @@ -223,7 +255,7 @@ func testSelectScan(t *testing.T, db *bun.DB) {
}

func testSelectCount(t *testing.T, db *bun.DB) {
if db.Dialect().Name() == dialect.MySQL5 {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}

Expand All @@ -249,13 +281,18 @@ func testSelectMap(t *testing.T, db *bun.DB) {
ColumnExpr("10 AS num").
Scan(ctx, &m)
require.NoError(t, err)
require.Equal(t, map[string]interface{}{
"num": int64(10),
}, m)
switch v := m["num"]; v.(type) {
case int32:
require.Equal(t, int32(10), v)
case int64:
require.Equal(t, int64(10), v)
default:
t.Fail()
}
}

func testSelectMapSlice(t *testing.T, db *bun.DB) {
if db.Dialect().Name() == dialect.MySQL5 {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}

Expand Down Expand Up @@ -354,7 +391,7 @@ func testSelectNestedStructPtr(t *testing.T, db *bun.DB) {
}

func testSelectStructSlice(t *testing.T, db *bun.DB) {
if db.Dialect().Name() == dialect.MySQL5 {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}

Expand All @@ -381,7 +418,7 @@ func testSelectStructSlice(t *testing.T, db *bun.DB) {
}

func testSelectSingleSlice(t *testing.T, db *bun.DB) {
if db.Dialect().Name() == dialect.MySQL5 {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}

Expand All @@ -401,7 +438,7 @@ func testSelectSingleSlice(t *testing.T, db *bun.DB) {
}

func testSelectMultiSlice(t *testing.T, db *bun.DB) {
if db.Dialect().Name() == dialect.MySQL5 {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}

Expand Down Expand Up @@ -506,7 +543,7 @@ func testScanSingleRow(t *testing.T, db *bun.DB) {
}

func testScanSingleRowByRow(t *testing.T, db *bun.DB) {
if db.Dialect().Name() == dialect.MySQL5 {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}

Expand Down Expand Up @@ -540,7 +577,7 @@ func testScanSingleRowByRow(t *testing.T, db *bun.DB) {
}

func testScanRows(t *testing.T, db *bun.DB) {
if db.Dialect().Name() == dialect.MySQL5 {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}

Expand Down Expand Up @@ -628,7 +665,7 @@ func testJSONSpecialChars(t *testing.T, db *bun.DB) {
err = db.NewSelect().Model(model).Scan(ctx)
require.NoError(t, err)
switch db.Dialect().Name() {
case dialect.MySQL5, dialect.MySQL8:
case dialect.MySQL:
require.Equal(t, map[string]interface{}{
"hello": "\x00world\nworld\x00",
}, model.Attrs)
Expand Down Expand Up @@ -726,7 +763,7 @@ func testFKViolation(t *testing.T, db *bun.DB) {

func testInterfaceAny(t *testing.T, db *bun.DB) {
switch db.Dialect().Name() {
case dialect.MySQL5, dialect.MySQL8:
case dialect.MySQL:
t.Skip()
}

Expand Down
14 changes: 14 additions & 0 deletions internal/dbtest/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,17 @@ services:
retries: 5
ports:
- 5432:5432
mariadb:
image: mariadb:10.6
environment:
- MARIADB_DATABASE=test
- MARIADB_USER=user
- MARIADB_PASSWORD=pass
- MARIADB_ROOT_PASSWORD=pass
ports:
- 13306:3306
healthcheck:
test: ['CMD', 'mariadb-admin', 'ping']
timeout: 5s
interval: 10s
retries: 3
2 changes: 1 addition & 1 deletion internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestMigrate(t *testing.T) {
{run: testMigrateUpError},
}

testEachDB(t, func(t *testing.T, db *bun.DB) {
testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
for _, test := range tests {
t.Run(funcName(test.run), func(t *testing.T) {
test.run(t, db)
Expand Down
2 changes: 1 addition & 1 deletion internal/dbtest/model_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestModelHook(t *testing.T) {
testEachDB(t, testModelHook)
}

func testModelHook(t *testing.T, db *bun.DB) {
func testModelHook(t *testing.T, dbName string, db *bun.DB) {
_, err := db.NewDropTable().Model((*ModelHookTest)(nil)).IfExists().Exec(ctx)
require.NoError(t, err)

Expand Down
Loading

0 comments on commit 74477c4

Please sign in to comment.