diff --git a/dialect/mysqldialect/dialect.go b/dialect/mysqldialect/dialect.go index 959747d94..881aa7ebf 100644 --- a/dialect/mysqldialect/dialect.go +++ b/dialect/mysqldialect/dialect.go @@ -27,14 +27,17 @@ func init() { } } +type DialectOption func(d *Dialect) + type Dialect struct { schema.BaseDialect tables *schema.Tables features feature.Feature + loc *time.Location } -func New() *Dialect { +func New(opts ...DialectOption) *Dialect { d := new(Dialect) d.tables = schema.NewTables(d) d.features = feature.AutoIncrement | @@ -47,9 +50,24 @@ func New() *Dialect { feature.InsertOnDuplicateKey | feature.SelectExists | feature.CompositeIn + + for _, opt := range opts { + opt(d) + } + return d } +func WithTimeLocation(loc string) DialectOption { + return func(d *Dialect) { + location, err := time.LoadLocation(loc) + if err != nil { + panic(fmt.Errorf("mysqldialect can't load provided location %s: %s", loc, err)) + } + d.loc = location + } +} + func (d *Dialect) Init(db *sql.DB) { var version string if err := db.QueryRow("SELECT version()").Scan(&version); err != nil { @@ -103,9 +121,13 @@ func (d *Dialect) IdentQuote() byte { return '`' } -func (*Dialect) AppendTime(b []byte, tm time.Time) []byte { +func (d *Dialect) AppendTime(b []byte, tm time.Time) []byte { b = append(b, '\'') - b = tm.AppendFormat(b, "2006-01-02 15:04:05.999999") + if d.loc != nil { + b = tm.In(d.loc).AppendFormat(b, "2006-01-02 15:04:05.999999") + } else { + b = tm.AppendFormat(b, "2006-01-02 15:04:05.999999") + } b = append(b, '\'') return b }