diff --git a/internal/dumper/dumper.go b/internal/dumper/dumper.go index d651d61e..ad04c378 100644 --- a/internal/dumper/dumper.go +++ b/internal/dumper/dumper.go @@ -157,7 +157,7 @@ func (d *Dumper) Run(ctx context.Context) error { zap.Int("thread_conn_id", conn.ID), ) - err := d.dumpTable(conn, database, table, d.cfg.UseReplica, d.cfg.UseRdonly) + err := d.dumpTable(ctx, conn, database, table, d.cfg.UseReplica, d.cfg.UseRdonly) if err != nil { d.log.Error("error dumping table", zap.Error(err)) } @@ -219,7 +219,7 @@ func (d *Dumper) dumpTableSchema(conn *Connection, database string, table string } // Dump a table in "MySQL" (multi-inserts) format -func (d *Dumper) dumpTable(conn *Connection, database string, table string, useReplica, useRdonly bool) error { +func (d *Dumper) dumpTable(ctx context.Context, conn *Connection, database string, table string, useReplica, useRdonly bool) error { var allBytes uint64 var allRows uint64 var where string @@ -278,6 +278,11 @@ func (d *Dumper) dumpTable(conn *Connection, database string, table string, useR return err } + // Allows for quicker exit when using Ctrl+C at the Terminal: + if ctx.Err() != nil { + return ctx.Err() + } + values := make([]string, 0, 16) for _, v := range row { if v.Raw() == nil { diff --git a/internal/dumper/loader.go b/internal/dumper/loader.go index 3298dcc1..66939c8c 100644 --- a/internal/dumper/loader.go +++ b/internal/dumper/loader.go @@ -85,7 +85,7 @@ func (l *Loader) Run(ctx context.Context) error { eg.Go(func() error { defer pool.Put(conn) - r, err := l.restoreTable(table, conn) + r, err := l.restoreTable(ctx, table, conn) if err != nil { return err } @@ -234,7 +234,7 @@ func (l *Loader) restoreTableSchema(overwrite bool, tables []string, conn *Conne return nil } -func (l *Loader) restoreTable(table string, conn *Connection) (int, error) { +func (l *Loader) restoreTable(ctx context.Context, table string, conn *Connection) (int, error) { bytes := 0 part := "0" base := filepath.Base(table) @@ -277,6 +277,11 @@ func (l *Loader) restoreTable(table string, conn *Connection) (int, error) { querys := strings.Split(query1, ";\n") bytes = len(query1) for _, query := range querys { + // Allows for quicker exit when using Ctrl+C at the Terminal: + if ctx.Err() != nil { + return 0, ctx.Err() + } + if !strings.HasPrefix(query, "/*") && query != "" { err = conn.Execute(query) if err != nil {