Skip to content

Commit

Permalink
added confirmation
Browse files Browse the repository at this point in the history
  • Loading branch information
rsteube committed Nov 30, 2024
1 parent c9416d1 commit 2923c31
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
8 changes: 6 additions & 2 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@ func Command(owner, repository string, opts ...option) *cobra.Command {
Args: cobra.MinimumNArgs(3), // TODO implicit update to latest
RunE: func(cmd *cobra.Command, args []string) error {
if len(args) > 2 { // TODO test
c := New(owner, repo[args[0]], append([]option{WithProgress(cmd.ErrOrStderr())}, opts...)...)
opts := append([]option{WithProgress(cmd.ErrOrStderr())}, opts...)
if cmd.Flag("force").Changed {
opts = append(opts, WithForce(true))
}
c := New(owner, repo[args[0]], opts...)
return c.Install(args[1], args[2])
}
return nil
},
}

cmd.Flags().BoolP("all", "a", false, "show all tags/assets")
// cmd.Flags().BoolP("force", "f", false, "force") // TODO force
cmd.Flags().BoolP("force", "f", false, "force")
cmd.Flags().BoolP("help", "h", false, "help for selfupdate") // TODO use cobras help flag
// cmd.Flags().Bool("no-verify", false, "disable checksum verification") // TODO disable verification

Expand Down
47 changes: 44 additions & 3 deletions selfupdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type config struct {
repo string
binary string
filter func(asset string) bool
force bool
progress io.Writer
t transport.Transport
}
Expand Down Expand Up @@ -57,6 +58,12 @@ func WithAssetFilter(f func(s string) bool) func(c *config) {
}
}

func WithForce(b bool) func(c *config) {
return func(c *config) {
c.force = b
}
}

func WithTransport(t transport.Transport) func(c *config) {
return func(c *config) {
c.t = t
Expand Down Expand Up @@ -122,6 +129,21 @@ func (c config) Printf(format string, any ...any) {
fmt.Fprintf(c.progress, "\x1b[1;2m"+format+"\x1b[0m", any...)
}

func (c config) confirm(format string, any ...any) error {
if c.force {
return nil
}
fmt.Fprintf(os.Stderr, "\x1b[1;2m"+format+" [y/n]: \x1b[0m", any...)
var input string
if _, err := fmt.Scanln(&input); err != nil {
return err
}
if strings.ToLower(input) != "y" {
return errors.New("aborted")
}
return nil
}

func (c config) Install(tag, asset string) error {
if !strings.HasSuffix(asset, ".tar.gz") && !strings.HasSuffix(asset, ".zip") {
return errors.New("invalid extension [expected: .tar.gz|.zip]") // fail early
Expand All @@ -140,6 +162,7 @@ func (c config) Install(tag, asset string) error {
}
defer f.Close()

c.Printf("downloading to %#v\n", tmpArchive.Name())
if err := c.Download(tag, asset, f); err != nil {
return err
}
Expand All @@ -157,13 +180,24 @@ func (c config) Install(tag, asset string) error {
}

if _, err := os.Stat(binDir); errors.Is(err, os.ErrNotExist) {
if err := c.confirm("create directory %#v?", binDir); err != nil {
return err
}

c.Printf("creating directory %#v\n", binDir)
if err := os.MkdirAll(binDir, os.ModePerm); err != nil {
return err
}
}

fExecutable, err := os.Create(filepath.Join(binDir, c.binary+".selfupdate"))
tmpBinary := filepath.Join(binDir, c.binary+".selfupdate")
if _, err := os.Stat(tmpBinary); err == nil {
if err := c.confirm("overwrite %#v?", tmpBinary); err != nil {
return err
}
}

fExecutable, err := os.Create(tmpBinary)
if err != nil {
return err
}
Expand All @@ -173,6 +207,7 @@ func (c config) Install(tag, asset string) error {
if err := c.extract(tmpArchive.Name(), fExecutable); err != nil {
return err
}
defer os.Remove(fExecutable.Name())

if err := os.Chmod(fExecutable.Name(), 0755); err != nil {
return err
Expand All @@ -182,12 +217,19 @@ func (c config) Install(tag, asset string) error {
return err
}

c.Println("verifying format")
c.Printf("executing %#v\n", fExecutable.Name()+" --version")
if err := c.verify(fExecutable.Name()); err != nil {
return err
}

target := filepath.Join(binDir, c.binary)

if _, err := os.Stat(target); err == nil {
if err := c.confirm("overwrite %#v?", target); err != nil {
return err
}
}

c.Printf("moving to %#v\n", target)
if err = os.Rename(fExecutable.Name(), target); err != nil {
return err
Expand Down Expand Up @@ -221,7 +263,6 @@ func (c config) extract(source string, out io.Writer) error {
}

func (c config) Download(tag, asset string, out io.Writer) error {
c.Printf("downloading %#v\n", asset)
return c.t.Download(c.repo, tag, asset, out, c.progress)
}

Expand Down

0 comments on commit 2923c31

Please sign in to comment.