diff --git a/command.go b/command.go index 693bd43..6098505 100644 --- a/command.go +++ b/command.go @@ -25,7 +25,11 @@ func Command(owner, repository string, opts ...option) *cobra.Command { Args: cobra.MinimumNArgs(3), // TODO implicit update to latest RunE: func(cmd *cobra.Command, args []string) error { if len(args) > 2 { // TODO test - c := New(owner, repo[args[0]], append([]option{WithProgress(cmd.ErrOrStderr())}, opts...)...) + opts := append([]option{WithProgress(cmd.ErrOrStderr())}, opts...) + if cmd.Flag("force").Changed { + opts = append(opts, WithForce(true)) + } + c := New(owner, repo[args[0]], opts...) return c.Install(args[1], args[2]) } return nil @@ -33,7 +37,7 @@ func Command(owner, repository string, opts ...option) *cobra.Command { } cmd.Flags().BoolP("all", "a", false, "show all tags/assets") - // cmd.Flags().BoolP("force", "f", false, "force") // TODO force + cmd.Flags().BoolP("force", "f", false, "force") cmd.Flags().BoolP("help", "h", false, "help for selfupdate") // TODO use cobras help flag // cmd.Flags().Bool("no-verify", false, "disable checksum verification") // TODO disable verification diff --git a/selfupdate.go b/selfupdate.go index 1460517..5d2fbb2 100644 --- a/selfupdate.go +++ b/selfupdate.go @@ -23,6 +23,7 @@ type config struct { repo string binary string filter func(asset string) bool + force bool progress io.Writer t transport.Transport } @@ -57,6 +58,12 @@ func WithAssetFilter(f func(s string) bool) func(c *config) { } } +func WithForce(b bool) func(c *config) { + return func(c *config) { + c.force = b + } +} + func WithTransport(t transport.Transport) func(c *config) { return func(c *config) { c.t = t @@ -122,6 +129,21 @@ func (c config) Printf(format string, any ...any) { fmt.Fprintf(c.progress, "\x1b[1;2m"+format+"\x1b[0m", any...) } +func (c config) confirm(format string, any ...any) error { + if c.force { + return nil + } + fmt.Fprintf(os.Stderr, "\x1b[1;2m"+format+" [y/n]: \x1b[0m", any...) + var input string + if _, err := fmt.Scanln(&input); err != nil { + return err + } + if strings.ToLower(input) != "y" { + return errors.New("aborted") + } + return nil +} + func (c config) Install(tag, asset string) error { if !strings.HasSuffix(asset, ".tar.gz") && !strings.HasSuffix(asset, ".zip") { return errors.New("invalid extension [expected: .tar.gz|.zip]") // fail early @@ -140,6 +162,7 @@ func (c config) Install(tag, asset string) error { } defer f.Close() + c.Printf("downloading to %#v\n", tmpArchive.Name()) if err := c.Download(tag, asset, f); err != nil { return err } @@ -157,13 +180,24 @@ func (c config) Install(tag, asset string) error { } if _, err := os.Stat(binDir); errors.Is(err, os.ErrNotExist) { + if err := c.confirm("create directory %#v?", binDir); err != nil { + return err + } + c.Printf("creating directory %#v\n", binDir) if err := os.MkdirAll(binDir, os.ModePerm); err != nil { return err } } - fExecutable, err := os.Create(filepath.Join(binDir, c.binary+".selfupdate")) + tmpBinary := filepath.Join(binDir, c.binary+".selfupdate") + if _, err := os.Stat(tmpBinary); err == nil { + if err := c.confirm("overwrite %#v?", tmpBinary); err != nil { + return err + } + } + + fExecutable, err := os.Create(tmpBinary) if err != nil { return err } @@ -173,6 +207,7 @@ func (c config) Install(tag, asset string) error { if err := c.extract(tmpArchive.Name(), fExecutable); err != nil { return err } + defer os.Remove(fExecutable.Name()) if err := os.Chmod(fExecutable.Name(), 0755); err != nil { return err @@ -182,12 +217,19 @@ func (c config) Install(tag, asset string) error { return err } - c.Println("verifying format") + c.Printf("executing %#v\n", fExecutable.Name()+" --version") if err := c.verify(fExecutable.Name()); err != nil { return err } target := filepath.Join(binDir, c.binary) + + if _, err := os.Stat(target); err == nil { + if err := c.confirm("overwrite %#v?", target); err != nil { + return err + } + } + c.Printf("moving to %#v\n", target) if err = os.Rename(fExecutable.Name(), target); err != nil { return err @@ -221,7 +263,6 @@ func (c config) extract(source string, out io.Writer) error { } func (c config) Download(tag, asset string, out io.Writer) error { - c.Printf("downloading %#v\n", asset) return c.t.Download(c.repo, tag, asset, out, c.progress) }