Skip to content

Commit

Permalink
Continue trying other Credentials when PrivateKeyCredential fails to …
Browse files Browse the repository at this point in the history
…load. (#276)
  • Loading branch information
tmds authored Dec 14, 2024
1 parent e4e54e2 commit ea3a348
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 32 deletions.
3 changes: 2 additions & 1 deletion src/Tmds.Ssh/AuthResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ enum AuthResult
{
Failure,
Success,
Partial
Partial,
Skipped
}
14 changes: 0 additions & 14 deletions src/Tmds.Ssh/PrivateKeyLoadException.cs

This file was deleted.

28 changes: 25 additions & 3 deletions src/Tmds.Ssh/SshSession.Authentication.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,18 @@ private async Task AuthenticateAsync(SshConnection connection, CancellationToken

UserAuthContext context = new UserAuthContext(connection, _settings.UserName, _settings.PublicKeyAcceptedAlgorithms, _settings.MinimumRSAKeySize, Logger);

HashSet<Name>? rejectedMethods = null;
HashSet<Name>? failedMethods = null;
HashSet<Name>? skippedMethods = null;

int partialAuthAttempts = 0;
// Try credentials.
List<Credential> credentials = new(_settings.CredentialsOrDefault);
for (int i = 0; i < credentials.Count; i++)
{
Credential credential = credentials[i];

AuthResult authResult = AuthResult.Failure;
AuthResult authResult = AuthResult.Skipped;
bool? methodAccepted;
Name method;
if (credential is PasswordCredential passwordCredential)
Expand Down Expand Up @@ -72,6 +76,8 @@ private async Task AuthenticateAsync(SshConnection connection, CancellationToken
// We didn't try the method, skip to the next credential.
if (methodAccepted == false)
{
rejectedMethods ??= new();
rejectedMethods.Add(method);
continue;
}

Expand All @@ -80,8 +86,18 @@ private async Task AuthenticateAsync(SshConnection connection, CancellationToken
return;
}

if (authResult == AuthResult.Failure)
if (authResult is AuthResult.Failure or AuthResult.Skipped)
{
if (authResult == AuthResult.Failure)
{
failedMethods ??= new();
failedMethods.Add(method);
}
else
{
skippedMethods ??= new();
skippedMethods.Add(method);
}
// If we didn't know if the method was accepted before, check the context which was updated by SSH_MSG_USERAUTH_FAILURE.
if (methodAccepted == null)
{
Expand Down Expand Up @@ -122,7 +138,13 @@ bool TryMethod(Name credentialMethod)
}
}

throw new ConnectFailedException(ConnectFailedReason.AuthenticationFailed, "Authentication failed.", ConnectionInfo);
throw new ConnectFailedException(
ConnectFailedReason.AuthenticationFailed,
$"Authentication failed. {DescribeMethodListBehavior("failed", failedMethods)} {DescribeMethodListBehavior("were skipped", skippedMethods)} {DescribeMethodListBehavior("were rejected", rejectedMethods)}", ConnectionInfo);

static string DescribeMethodListBehavior(string state, IEnumerable<Name>? methods)
=> methods is null ? $"No methods {state}."
: $"These methods {state}: {string.Join(", ", methods)}.";
}

private static Packet CreateServiceRequestMessage(SequencePool sequencePool)
Expand Down
19 changes: 10 additions & 9 deletions src/Tmds.Ssh/UserAuthentication.PasswordAuth.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@ public sealed class PasswordAuth
public static async Task<AuthResult> TryAuthenticate(PasswordCredential passwordCredential, UserAuthContext context, SshConnectionInfo connectionInfo, ILogger<SshClient> logger, CancellationToken ct)
{
string? password = passwordCredential.GetPassword();
if (password is not null)

if (password is null)
{
context.StartAuth(AlgorithmNames.Password);
return AuthResult.Skipped;
}

logger.PasswordAuth();
context.StartAuth(AlgorithmNames.Password);

{
using var userAuthMsg = CreatePasswordRequestMessage(context.SequencePool, context.UserName, password);
await context.SendPacketAsync(userAuthMsg.Move(), ct).ConfigureAwait(false);
}
logger.PasswordAuth();

return await context.ReceiveAuthResultAsync(ct).ConfigureAwait(false);
{
using var userAuthMsg = CreatePasswordRequestMessage(context.SequencePool, context.UserName, password);
await context.SendPacketAsync(userAuthMsg.Move(), ct).ConfigureAwait(false);
}

return AuthResult.Failure;
return await context.ReceiveAuthResultAsync(ct).ConfigureAwait(false);
}

private static Packet CreatePasswordRequestMessage(SequencePool sequencePool, string userName, string password)
Expand Down
7 changes: 4 additions & 3 deletions src/Tmds.Ssh/UserAuthentication.PublicKeyAuth.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ public static async Task<AuthResult> TryAuthenticate(PrivateKeyCredential keyCre
if (pk is null)
{
logger.PrivateKeyNotFound(keyCredential.Identifier);
return AuthResult.Failure;
return AuthResult.Skipped;
}
}
catch (Exception error)
{
logger.PrivateKeyCanNotLoad(keyCredential.Identifier, error);
throw new PrivateKeyLoadException(keyCredential.Identifier, error);
return AuthResult.Skipped;
}

using (pk)
Expand All @@ -36,7 +36,7 @@ public static async Task<AuthResult> TryAuthenticate(PrivateKeyCredential keyCre
if (rsaKey.KeySize < context.MinimumRSAKeySize)
{
// TODO: log
return AuthResult.Failure;
return AuthResult.Skipped;
}
}

Expand Down Expand Up @@ -69,6 +69,7 @@ public static async Task<AuthResult> TryAuthenticate(PrivateKeyCredential keyCre
if (!acceptedAlgorithm)
{
logger.PrivateKeyAlgorithmsNotAccepted(keyCredential.Identifier, context.PublicKeyAcceptedAlgorithms);
return AuthResult.Skipped;
}
}

Expand Down
2 changes: 0 additions & 2 deletions test/Tmds.Ssh.Tests/PrivateKeyCredentialTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ await RunWithKeyConversion(_sshServer.TestUserIdentityFile, async (string localK
}, async (SshClient client) =>
{
var exc = await Assert.ThrowsAnyAsync<ConnectFailedException>(() => client.ConnectAsync());
Assert.IsType<PrivateKeyLoadException>(exc.InnerException);
});
}

Expand All @@ -252,7 +251,6 @@ await RunWithKeyConversion(_sshServer.TestUserIdentityFile, async (string localK
}, async (SshClient client) =>
{
var exc = await Assert.ThrowsAnyAsync<ConnectFailedException>(() => client.ConnectAsync());
Assert.IsType<PrivateKeyLoadException>(exc.InnerException);
});
}

Expand Down

0 comments on commit ea3a348

Please sign in to comment.