Skip to content

Commit

Permalink
[v16] Support double dash delimiter in tsh ssh (#47493)
Browse files Browse the repository at this point in the history
* Support double dash delimiter in tsh ssh

This PR extends the tsh ssh command by adding support for the
double dash (--) delimiter before remote commands (e.g.
`tsh ssh -- echo test`), aligning its behavior with the standard
ssh binary. This improves compatibility with tools that rely on the
standard ssh binary behavior, such as sshuttle.

Fixes #18453, #16589.

Signed-off-by: Tim Ross <[email protected]>

* Support double dash delimiter in tsh ssh

This PR extends the tsh ssh command by adding support for the
double dash (--) delimiter before remote commands (e.g.
`tsh ssh -- echo test`), aligning its behavior with the standard
ssh binary. This improves compatibility with tools that rely on the
standard ssh binary behavior, such as sshuttle.

Fixes #18453, #16589.

Signed-off-by: Tim Ross <[email protected]>

* fix: update tests

---------

Signed-off-by: Tim Ross <[email protected]>
Co-authored-by: Samir Aguiar <[email protected]>
  • Loading branch information
rosstimothy and ns-sjorgedeaguiar authored Nov 4, 2024
1 parent 2133890 commit 9eac207
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3754,6 +3754,11 @@ func onSSH(cf *CLIConf) error {

tc.AllowHeadless = true

// Support calling `tsh ssh -- <command>` (with a double dash before the command)
if len(cf.RemoteCommand) > 0 && strings.TrimSpace(cf.RemoteCommand[0]) == "--" {
cf.RemoteCommand = cf.RemoteCommand[1:]
}

tc.Stdin = os.Stdin
err = retryWithAccessRequest(cf, tc, func() error {
sshFunc := func() error {
Expand Down
177 changes: 177 additions & 0 deletions tool/tsh/common/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2303,6 +2303,183 @@ func TestAccessRequestOnLeaf(t *testing.T) {
require.NoError(t, err)
}

// TestSSHCommand tests that a user can access a single SSH node and run commands.
func TestSSHCommands(t *testing.T) {
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

accessRoleName := "access"
sshHostname := "test-ssh-server"

accessUser, err := types.NewUser(accessRoleName)
require.NoError(t, err)
accessUser.SetRoles([]string{accessRoleName})

user, err := user.Current()
require.NoError(t, err)
accessUser.SetLogins([]string{user.Username})

traits := map[string][]string{
constants.TraitLogins: {user.Username},
}
accessUser.SetTraits(traits)

connector := mockConnector(t)
rootServerOpts := []testserver.TestServerOptFunc{
testserver.WithBootstrap(connector, accessUser),
testserver.WithHostname(sshHostname),
testserver.WithClusterName(t, "root"),
testserver.WithSSHLabel(accessRoleName, "true"),
testserver.WithSSHPublicAddrs("127.0.0.1:0"),
testserver.WithConfig(func(cfg *servicecfg.Config) {
cfg.SSH.Enabled = true
cfg.SSH.PublicAddrs = []utils.NetAddr{cfg.SSH.Addr}
cfg.SSH.DisableCreateHostUser = true
}),
}
rootServer := testserver.MakeTestServer(t, rootServerOpts...)

rootProxyAddr, err := rootServer.ProxyWebAddr()
require.NoError(t, err)

require.EventuallyWithT(t, func(t *assert.CollectT) {
rootNodes, err := rootServer.GetAuthServer().GetNodes(ctx, apidefaults.Namespace)
if !assert.NoError(t, err) || !assert.Len(t, rootNodes, 1) {
return
}
}, 10*time.Second, 100*time.Millisecond)

tmpHomePath := t.TempDir()
rootAuth := rootServer.GetAuthServer()

err = Run(ctx, []string{
"login",
"--insecure",
"--proxy", rootProxyAddr.String(),
"--user", user.Username,
}, setHomePath(tmpHomePath), setMockSSOLogin(rootAuth, accessUser, connector.GetName()))
require.NoError(t, err)

tests := []struct {
name string
args []string
expected string
shouldErr bool
}{
{
// Test that a simple echo works.
name: "ssh simple command",
expected: "this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo",
"this is a test message",
},
shouldErr: false,
},
{
// Test that commands can be prefixed with a double dash.
name: "ssh command with double dash",
expected: "this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"--",
"echo",
"this is a test message",
},
shouldErr: false,
},
{
// Test that a double dash is not removed from the middle of a command.
name: "ssh command with double dash in the middle",
expected: "-- this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo",
"--",
"this is a test message",
},
shouldErr: false,
},
{
// Test that quoted commands work (e.g. `tsh ssh 'echo test'`)
name: "ssh command literal",
expected: "this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo this is a test message",
},
shouldErr: false,
},
{
// Test that a double dash is passed as-is in a quoted command (which should fail).
name: "ssh command literal with double dash err",
expected: "",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"-- echo this is a test message",
},
shouldErr: true,
},
{
// Test that a double dash is not removed from the middle of a quoted command.
name: "ssh command literal with double dash in the middle",
expected: "-- this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo", "-- this is a test message",
},
shouldErr: false,
},
{
// Test tsh ssh -- hostname command
name: "delimiter before host and command",
expected: "this is a test message",
args: []string{
"--", sshHostname, "echo", "this is a test message",
},
shouldErr: false,
},
}

for _, test := range tests {
test := test
ctx := context.Background()
t.Run(test.name, func(t *testing.T) {
t.Parallel()

stdout := &output{buf: bytes.Buffer{}}
stderr := &output{buf: bytes.Buffer{}}
args := append(
[]string{
"ssh",
"--insecure",
"--proxy", rootProxyAddr.String(),
},
test.args...,
)

err := Run(ctx, args, setHomePath(tmpHomePath),
func(conf *CLIConf) error {
conf.overrideStdin = &bytes.Buffer{}
conf.OverrideStdout = stdout
conf.overrideStderr = stderr
return nil
},
)

if test.shouldErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, strings.TrimSpace(stdout.String()))
require.Empty(t, stderr.String())
}
})
}
}

// tryCreateTrustedCluster performs several attempts to create a trusted cluster,
// retries on connection problems and access denied errors to let caches
// propagate and services to start
Expand Down

0 comments on commit 9eac207

Please sign in to comment.