Skip to content

Commit

Permalink
Support using private keys that are not stored in files. (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Sep 11, 2024
1 parent ae4add5 commit 7130d4a
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 62 deletions.
88 changes: 81 additions & 7 deletions src/Tmds.Ssh/PrivateKeyCredential.cs
Original file line number Diff line number Diff line change
@@ -1,20 +1,94 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System.Security.Cryptography;

namespace Tmds.Ssh;

public sealed class PrivateKeyCredential : Credential
public class PrivateKeyCredential : Credential
{
internal string FilePath { get; }
internal string Identifier { get; }

private Func<CancellationToken, ValueTask<Key>> LoadKey { get; }

internal Func<string?> PasswordPrompt { get; }
public PrivateKeyCredential(string path, string? password = null, string? identifier = null) :
this(path, () => password, identifier)
{ }

public PrivateKeyCredential(string path, string? password = null) : this(path, () => password)
public PrivateKeyCredential(string path, Func<string?> passwordPrompt, string? identifier = null) :
this(LoadKeyFromFile(path ?? throw new ArgumentNullException(nameof(path)), passwordPrompt), identifier ?? path)
{ }

public PrivateKeyCredential(string path, Func<string?> passwordPrompt)
// Allows the user to implement derived classes that represent a private key.
protected PrivateKeyCredential(Func<CancellationToken, ValueTask<Key>> loadKey, string identifier)
{
LoadKey = loadKey;
Identifier = identifier;
}

private static Func<CancellationToken, ValueTask<Key>> LoadKeyFromFile(string path, Func<string?> passwordPrompt)
=> (CancellationToken cancellationToken) =>
{
if (PrivateKeyParser.TryParsePrivateKeyFile(path, passwordPrompt, out PrivateKey? privateKey, out Exception? error))
{
return ValueTask.FromResult(new Key(privateKey));
}
if (error is FileNotFoundException or DirectoryNotFoundException)
{
return ValueTask.FromResult(default(Key));
}
throw error;
};

// This is a type we expose to our derive types to avoid having to expose PrivateKey and a bunch of other internals.
protected readonly struct Key
{
internal PrivateKey? PrivateKey { get; }

public Key(RSA rsa)
{
PrivateKey = new RsaPrivateKey(rsa);
}

public Key(ECDsa ecdsa)
{
ECParameters parameters = ecdsa.ExportParameters(includePrivateParameters: false);
Oid oid = parameters.Curve.Oid;

Name algorithm;
Name curveName;
HashAlgorithmName hashAlgorithm;
if (oid.Equals(ECCurve.NamedCurves.nistP256.Oid))
{
(algorithm, curveName, hashAlgorithm) = (AlgorithmNames.EcdsaSha2Nistp256, AlgorithmNames.Nistp256, HashAlgorithmName.SHA256);
}
else if (oid.Equals(ECCurve.NamedCurves.nistP384.Oid))
{
(algorithm, curveName, hashAlgorithm) = (AlgorithmNames.EcdsaSha2Nistp384, AlgorithmNames.Nistp384, HashAlgorithmName.SHA384);
}
else if (oid.Equals(ECCurve.NamedCurves.nistP521.Oid))
{
(algorithm, curveName, hashAlgorithm) = (AlgorithmNames.EcdsaSha2Nistp521, AlgorithmNames.Nistp521, HashAlgorithmName.SHA512);
}
else
{
throw new NotSupportedException($"Curve {oid} is not known.");
}

PrivateKey = new ECDsaPrivateKey(ecdsa, algorithm, curveName, hashAlgorithm);
}

internal Key(PrivateKey key)
{
PrivateKey = key;
}
}

internal async ValueTask<PrivateKey?> LoadKeyAsync(CancellationToken cancellationToken)
{
FilePath = path ?? throw new ArgumentNullException(nameof(path));
PasswordPrompt = passwordPrompt;
Key key = await LoadKey(cancellationToken);
return key.PrivateKey;
}
}
16 changes: 8 additions & 8 deletions src/Tmds.Ssh/SshClientLogger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ Name compressionAlgorithmServerToClient
[LoggerMessage(
EventId = 19,
Level = LogLevel.Information,
Message = "Auth using publickey from '{FileName}' with {SignatureAlgorithm} signature")]
public static partial void PublicKeyAuth(this ILogger<SshClient> logger, string fileName, Name signatureAlgorithm);
Message = "Auth using publickey '{keyIdentifier}' with {SignatureAlgorithm} signature")]
public static partial void PublicKeyAuth(this ILogger<SshClient> logger, string keyIdentifier, Name signatureAlgorithm);

[LoggerMessage(
EventId = 20,
Expand All @@ -211,20 +211,20 @@ Name compressionAlgorithmServerToClient
[LoggerMessage(
EventId = 22,
Level = LogLevel.Information,
Message = "Public key file '{FileName}' not found.")]
public static partial void PublicKeyFileNotFound(this ILogger<SshClient> logger, string fileName);
Message = "Public key '{KeyIdentifier}' not found.")]
public static partial void PublicKeyFileNotFound(this ILogger<SshClient> logger, string keyIdentifier);

[LoggerMessage(
EventId = 23,
Level = LogLevel.Error,
Message = "Failed to load public key file '{FileName}'.")]
public static partial void PublicKeyCanNotLoad(this ILogger<SshClient> logger, string fileName, Exception exception);
Message = "Failed to load public key '{KeyIdentifier}'.")]
public static partial void PublicKeyCanNotLoad(this ILogger<SshClient> logger, string keyIdentifier, Exception exception);

[LoggerMessage(
EventId = 24,
Level = LogLevel.Information,
Message = "Public key file '{FileName}' has no accepted algorithms. Accepted algorithms: {AcceptedAlgorithms}")]
public static partial void PublicKeyAlgorithmsNotAccepted(this ILogger<SshClient> logger, string fileName, List<Name> acceptedAlgorithms);
Message = "Public key '{KeyIdentifier}' has no accepted algorithms. Accepted algorithms: {AcceptedAlgorithms}")]
public static partial void PublicKeyAlgorithmsNotAccepted(this ILogger<SshClient> logger, string keyIdentifier, List<Name> acceptedAlgorithms);

[LoggerMessage(
EventId = 25,
Expand Down
89 changes: 42 additions & 47 deletions src/Tmds.Ssh/UserAuthentication.PublicKeyAuth.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,71 +19,66 @@ public static async Task<bool> TryAuthenticate(PrivateKeyCredential keyCredentia
return false;
}

string filename = keyCredential.FilePath;
if (!File.Exists(filename))
PrivateKey? pk;
try
{
return false;
pk = await keyCredential.LoadKeyAsync(ct);
if (pk is null)
{
logger.PublicKeyFileNotFound(keyCredential.Identifier);
return false;
}
}
catch (Exception error)
{
logger.PublicKeyCanNotLoad(keyCredential.Identifier, error);
throw new PrivateKeyLoadException(keyCredential.Identifier, error);
}

if (PrivateKeyParser.TryParsePrivateKeyFile(keyCredential.FilePath, keyCredential.PasswordPrompt, out PrivateKey? pk, out Exception? error))
using (pk)
{
using (pk)
if (pk is RsaPrivateKey rsaKey)
{
if (pk is RsaPrivateKey rsaKey)
if (rsaKey.KeySize < context.MinimumRSAKeySize)
{
if (rsaKey.KeySize < context.MinimumRSAKeySize)
{
// TODO: log
return false;
}
// TODO: log
return false;
}
}

bool acceptedAlgorithm = false;
foreach (var keyAlgorithm in pk.Algorithms)
bool acceptedAlgorithm = false;
foreach (var keyAlgorithm in pk.Algorithms)
{
if (!context.PublicKeyAcceptedAlgorithms.Contains(keyAlgorithm))
{
if (!context.PublicKeyAcceptedAlgorithms.Contains(keyAlgorithm))
{
continue;
}

if (!context.TryStartAuth(AlgorithmNames.PublicKey))
{
Debug.Assert(false); // Already did an eary SkipMethod check.
return false;
}
continue;
}

acceptedAlgorithm = true;
logger.PublicKeyAuth(keyCredential.FilePath, keyAlgorithm);
if (!context.TryStartAuth(AlgorithmNames.PublicKey))
{
Debug.Assert(false); // Already did an eary SkipMethod check.
return false;
}

{
using var userAuthMsg = CreatePublicKeyRequestMessage(
keyAlgorithm, context.SequencePool, context.UserName, connectionInfo.SessionId!, pk!);
await context.SendPacketAsync(userAuthMsg.Move(), ct).ConfigureAwait(false);
}
acceptedAlgorithm = true;
logger.PublicKeyAuth(keyCredential.Identifier, keyAlgorithm);

bool success = await context.ReceiveAuthIsSuccesfullAsync(ct).ConfigureAwait(false);
if (success)
{
return true;
}
{
using var userAuthMsg = CreatePublicKeyRequestMessage(
keyAlgorithm, context.SequencePool, context.UserName, connectionInfo.SessionId!, pk!);
await context.SendPacketAsync(userAuthMsg.Move(), ct).ConfigureAwait(false);
}

if (!acceptedAlgorithm)
bool success = await context.ReceiveAuthIsSuccesfullAsync(ct).ConfigureAwait(false);
if (success)
{
logger.PublicKeyAlgorithmsNotAccepted(keyCredential.FilePath, context.PublicKeyAcceptedAlgorithms);
return true;
}
}
}
else
{
if (error is FileNotFoundException or DirectoryNotFoundException)
{
logger.PublicKeyFileNotFound(filename);
}
else

if (!acceptedAlgorithm)
{
logger.PublicKeyCanNotLoad(filename, error);
throw new PrivateKeyLoadException(filename, error); // TODO: throw or skip?
logger.PublicKeyAlgorithmsNotAccepted(keyCredential.Identifier, context.PublicKeyAcceptedAlgorithms);
}
}

Expand Down

0 comments on commit 7130d4a

Please sign in to comment.