diff --git a/pkg/app/create.go b/pkg/app/create.go index 35b580016..3110acba2 100644 --- a/pkg/app/create.go +++ b/pkg/app/create.go @@ -27,6 +27,7 @@ import ( "github.com/tensorchord/envd/pkg/envd" "github.com/tensorchord/envd/pkg/home" + "github.com/tensorchord/envd/pkg/lang/ir" "github.com/tensorchord/envd/pkg/ssh" sshconfig "github.com/tensorchord/envd/pkg/ssh/config" "github.com/tensorchord/envd/pkg/types" @@ -176,7 +177,8 @@ func create(clicontext *cli.Context) error { } go func() { - if err := sshClient.Attach(); err != nil { + if err := sshClient.Attach(ir.DefaultGraph.Shell, + ir.DefaultGraph.EnvironmentName); err != nil { outputChannel <- errors.Wrap(err, "failed to attach to the container") } outputChannel <- nil diff --git a/pkg/envd/docker.go b/pkg/envd/docker.go index d7d9d3d48..a83bee116 100644 --- a/pkg/envd/docker.go +++ b/pkg/envd/docker.go @@ -269,7 +269,8 @@ func (e dockerEngine) Attach(name, iface, privateKeyPath string, } opt.Server = iface - if err := sshClient.Attach(); err != nil { + if err := sshClient.Attach( + ir.DefaultGraph.Shell, ir.DefaultGraph.EnvironmentName); err != nil { return errors.Wrap(err, "failed to attach to the container") } return nil diff --git a/pkg/envd/envdserver.go b/pkg/envd/envdserver.go index d34ca385d..eca7ece31 100644 --- a/pkg/envd/envdserver.go +++ b/pkg/envd/envdserver.go @@ -27,6 +27,7 @@ import ( "github.com/tensorchord/envd-server/errdefs" "github.com/tensorchord/envd-server/sshname" + "github.com/tensorchord/envd/pkg/lang/ir" "github.com/tensorchord/envd/pkg/ssh" sshconfig "github.com/tensorchord/envd/pkg/ssh/config" "github.com/tensorchord/envd/pkg/types" @@ -180,7 +181,8 @@ func (e envdServerEngine) Attach(name, iface, privateKeyPath string, startResult } go func() { - if err := sshClient.Attach(); err != nil { + if err := sshClient.Attach(ir.DefaultGraph.Shell, + ir.DefaultGraph.EnvironmentName); err != nil { outputChannel <- errors.Wrap(err, "failed to attach to the container") } outputChannel <- nil diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index d1edc402a..5c53e353f 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -35,12 +35,11 @@ import ( "golang.org/x/crypto/ssh/agent" "golang.org/x/term" - "github.com/tensorchord/envd/pkg/lang/ir" "github.com/tensorchord/envd/pkg/ssh/config" ) type Client interface { - Attach() error + Attach(shell, envName string) error ExecWithOutput(cmd string) ([]byte, error) LocalForward(localAddress, targetAddress string) error Close() error @@ -181,7 +180,7 @@ func (c generalClient) ExecWithOutput(cmd string) ([]byte, error) { return session.CombinedOutput(cmd) } -func (c generalClient) Attach() error { +func (c generalClient) Attach(shell, envName string) error { // open session session, err := c.cli.NewSession() if err != nil { @@ -263,13 +262,12 @@ func (c generalClient) Attach() error { } }() - // TODO(gaocegege): Refactor it to avoid direct access to DefaultGraph - cmd := shellescape.QuoteCommand([]string{ir.DefaultGraph.Shell}) + cmd := shellescape.QuoteCommand([]string{shell}) logrus.Debugf("executing command over ssh: '%s'", cmd) err = session.Run(cmd) if err == nil { logrus.Infof("Detached successfully. You can attach to the container with command `ssh %s.envd`\n", - ir.DefaultGraph.EnvironmentName) + envName) return nil } if strings.Contains(err.Error(), "status 130") || strings.Contains(err.Error(), "4294967295") {