Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Aug 27, 2024
1 parent 280e43b commit 6b467f8
Show file tree
Hide file tree
Showing 14 changed files with 436 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ Supported key exchange methods:
Supported encryption algorithms:
- aes256-gcm@openssh.com
- aes128-gcm@openssh.com
- chacha20-poly1305@openssh.com

Supported message authentication code algorithms:
- none
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/AlgorithmNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ static class AlgorithmNames // TODO: rename to KnownNames
public static Name Aes128Gcm => new Name(Aes128GcmBytes);
private static readonly byte[] Aes256GcmBytes = "[email protected]"u8.ToArray();
public static Name Aes256Gcm => new Name(Aes256GcmBytes);
private static readonly byte[] ChaCha20Poly1305Bytes = "[email protected]"u8.ToArray();
public static Name ChaCha20Poly1305 => new Name(ChaCha20Poly1305Bytes);

// KDF algorithms:
private static readonly byte[] BCryptBytes = "bcrypt"u8.ToArray();
Expand Down
130 changes: 130 additions & 0 deletions src/Tmds.Ssh/ChaCha20Poly1305PacketDecoder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Security.Cryptography;

namespace Tmds.Ssh;

sealed class ChaCha20Poly1305PacketDecoder : ChaCha20Poly1305PacketEncDecBase, IPacketDecoder
{
private readonly SequencePool _sequencePool;
private int _currentPacketLength = -1;

public ChaCha20Poly1305PacketDecoder(SequencePool sequencePool, byte[] key) :
base(key)
{
_sequencePool = sequencePool;
}

public void Dispose()
{ }

public bool TryDecodePacket(Sequence receiveBuffer, uint sequenceNumber, int maxLength, out Packet packet)
{
packet = new Packet(null);

// Wait for the length.
if (receiveBuffer.Length < LengthSize)
{
return false;
}

// Decrypt length.
int packetLength = _currentPacketLength;
Span<byte> length_unencrypted = stackalloc byte[LengthSize];
if (packetLength == -1)
{
ConfigureCiphers(sequenceNumber);

Span<byte> length_encrypted = stackalloc byte[LengthSize];
if (receiveBuffer.FirstSpan.Length >= LengthSize)
{
receiveBuffer.FirstSpan.Slice(0, LengthSize).CopyTo(length_encrypted);
}
else
{
receiveBuffer.AsReadOnlySequence().Slice(0, LengthSize).CopyTo(length_encrypted);
}

LengthCipher.ProcessBytes(length_encrypted, length_unencrypted);

// Verify the packet length isn't too long and properly padded.
uint packet_length = BinaryPrimitives.ReadUInt32BigEndian(length_unencrypted);
if (packet_length > maxLength || (packet_length % PaddTo) != 0)
{
ThrowHelper.ThrowProtocolPacketTooLong();
}

_currentPacketLength = packetLength = (int)packet_length;
}
else
{
BinaryPrimitives.WriteInt32BigEndian(length_unencrypted, _currentPacketLength);
}

// Wait for the full encrypted packet.
int total_length = LengthSize + packetLength + TagSize;
if (receiveBuffer.Length < total_length)
{
return false;
}

// Check the mac.
ReadOnlySequence<byte> receiveBufferROSequence = receiveBuffer.AsReadOnlySequence();
ReadOnlySequence<byte> hashed = receiveBufferROSequence.Slice(0, LengthSize + packetLength);
Span<byte> packetTag = stackalloc byte[TagSize];
receiveBufferROSequence.Slice(LengthSize + packetLength, TagSize).CopyTo(packetTag);
if (hashed.IsSingleSegment)
{
Mac.BlockUpdate(hashed.FirstSpan);
}
else
{
foreach (var memory in hashed)
{
Mac.BlockUpdate(memory.Span);
}
}
Span<byte> tag = stackalloc byte[TagSize];
Mac.DoFinal(tag);
if (!CryptographicOperations.FixedTimeEquals(packetTag, tag))
{
throw new CryptographicException();
}

int decodedLength = total_length - TagSize;
Sequence decoded = _sequencePool.RentSequence();
Span<byte> dst = decoded.AllocGetSpan(decodedLength);

// Decrypt length.
length_unencrypted.CopyTo(dst);

// Decrypt payload.
Span<byte> plaintext = dst.Slice(LengthSize, packetLength);
ReadOnlySequence<byte> ciphertext = receiveBufferROSequence.Slice(LengthSize, packetLength);
if (ciphertext.IsSingleSegment)
{
PayloadCipher.ProcessBytes(ciphertext.FirstSpan, plaintext);
}
else
{
foreach (var memory in ciphertext)
{
PayloadCipher.ProcessBytes(memory.Span, plaintext);
plaintext = plaintext.Slice(memory.Length);
}
}

decoded.AppendAlloced(decodedLength);
packet = new Packet(decoded);

receiveBuffer.Remove(total_length);

_currentPacketLength = -1; // start decoding a new packet

return true;
}
}
60 changes: 60 additions & 0 deletions src/Tmds.Ssh/ChaCha20Poly1305PacketEncDecBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System;
using System.Buffers.Binary;
using Org.BouncyCastle.Crypto.Engines;
using Org.BouncyCastle.Crypto.Macs;
using Org.BouncyCastle.Crypto.Parameters;

namespace Tmds.Ssh;

class ChaCha20Poly1305PacketEncDecBase
{
public const int TagSize = 16; // Poly1305 hash length.
protected const int PaddTo = 8; // We're not a block cipher. Padd to 8 octets per rfc4253.
protected const int LengthSize = 4; // SSH packet length field is 4 bytes.

protected readonly MyChaCha20 LengthCipher;
protected readonly MyChaCha20 PayloadCipher;
protected readonly Poly1305 Mac;
private readonly byte[] _iv;

protected ChaCha20Poly1305PacketEncDecBase(byte[] key)
{
_iv = new byte[12];
byte[] K_1 = key.AsSpan(32, 32).ToArray();
byte[] K_2 = key.AsSpan(0, 32).ToArray();
LengthCipher = new(K_1, _iv);
PayloadCipher = new(K_2, _iv);
Mac = new();
}

protected void ConfigureCiphers(uint sequenceNumber)
{
BinaryPrimitives.WriteUInt64BigEndian(_iv.AsSpan(4), sequenceNumber);
LengthCipher.SetIv(_iv);
PayloadCipher.SetIv(_iv);

// note: encrypting 64 bytes increments the ChaCha20 block counter.
Span<byte> polyKey = stackalloc byte[64];
PayloadCipher.ProcessBytes(input: polyKey, output: polyKey);
Mac.Init(new KeyParameter(polyKey[..32]));
}

// This class eliminates per packet ParametersWithIV/KeyParameter allocations.
sealed protected class MyChaCha20 : ChaCha7539Engine
{
public MyChaCha20(byte[] key, byte[] dummyIv)
{
Init(forEncryption: true, new ParametersWithIV(new KeyParameter(key), dummyIv));
}

public void SetIv(byte[] iv)
{
SetKey(null, iv);

Reset();
}
}
}
68 changes: 68 additions & 0 deletions src/Tmds.Ssh/ChaCha20Poly1305PacketEncoder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System;
using System.Buffers;

namespace Tmds.Ssh;

// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.chacha20poly1305?annotate=HEAD
sealed class ChaCha20Poly1305PacketEncoder : ChaCha20Poly1305PacketEncDecBase, IPacketEncoder
{
public ChaCha20Poly1305PacketEncoder(byte[] key) :
base(key)
{ }

public void Dispose()
{ }

public void Encode(uint sequenceNumber, Packet packet, Sequence output)
{
using var pkt = packet.Move(); // Dispose the packet.

ConfigureCiphers(sequenceNumber);

// Padding.
uint payload_length = (uint)pkt.PayloadLength;
// PT (Plain Text)
// byte padding_length; // 4 <= padding_length < 256
// byte[n1] payload; // n1 = packet_length-padding_length-1
// byte[n2] random_padding; // n2 = padding_length
byte padding_length = IPacketEncoder.DeterminePaddingLength(payload_length + 1, multipleOf: PaddTo);
pkt.WriteHeaderAndPadding(padding_length);

var unencrypted_packet = pkt.AsReadOnlySequence();
ReadOnlySpan<byte> packet_length = unencrypted_packet.FirstSpan.Slice(0, LengthSize); // packet_length
ReadOnlySequence<byte> pt = unencrypted_packet.Slice(LengthSize); // PT (Plain Text)

int textLength = (int)pt.Length;
int encodedLength = LengthSize + textLength + TagSize;
Span<byte> dst = output.AllocGetSpan(encodedLength);

// Encrypt length.
Span<byte> length_encrypted = dst.Slice(0, LengthSize);
LengthCipher.ProcessBytes(packet_length, length_encrypted);

// Encrypt payload.
Span<byte> ciphertext = dst.Slice(LengthSize, textLength);
if (pt.IsSingleSegment)
{
PayloadCipher.ProcessBytes(pt.FirstSpan, ciphertext);
}
else
{
foreach (var memory in pt)
{
PayloadCipher.ProcessBytes(memory.Span, ciphertext);
ciphertext = ciphertext.Slice(memory.Length);
}
}

// Mac.
Span<byte> tag = dst.Slice(LengthSize + textLength, TagSize);
Mac.BlockUpdate(dst.Slice(0, LengthSize + textLength));
Mac.DoFinal(tag);

output.AppendAlloced(encodedLength);
}
}
41 changes: 26 additions & 15 deletions src/Tmds.Ssh/ECDHKeyExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ public async Task<KeyExchangeOutput> TryExchangeAsync(SshConnection connection,
}

byte[] sessionId = input.ConnectionInfo.SessionId ?? exchangeHash;
byte[] initialIVC2S = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'A', sessionId, input.InitialIVC2SLength);
byte[] initialIVS2C = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'B', sessionId, input.InitialIVS2CLength);
byte[] encryptionKeyC2S = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'C', sessionId, input.EncryptionKeyC2SLength);
byte[] encryptionKeyS2C = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'D', sessionId, input.EncryptionKeyS2CLength);
byte[] integrityKeyC2S = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'E', sessionId, input.IntegrityKeyC2SLength);
byte[] integrityKeyS2C = Hash(sequencePool, sharedSecret, exchangeHash, (byte)'F', sessionId, input.IntegrityKeyS2CLength);
byte[] initialIVC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'A', sessionId, input.InitialIVC2SLength);
byte[] initialIVS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'B', sessionId, input.InitialIVS2CLength);
byte[] encryptionKeyC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'C', sessionId, input.EncryptionKeyC2SLength);
byte[] encryptionKeyS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'D', sessionId, input.EncryptionKeyS2CLength);
byte[] integrityKeyC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'E', sessionId, input.IntegrityKeyC2SLength);
byte[] integrityKeyS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'F', sessionId, input.IntegrityKeyS2CLength);

return new KeyExchangeOutput(exchangeHash,
initialIVS2C, encryptionKeyS2C, integrityKeyS2C,
Expand Down Expand Up @@ -117,14 +117,13 @@ private byte[] CalculateExchangeHash(SequencePool sequencePool, SshConnectionInf
return hash.GetHashAndReset();
}

private byte[] Hash(SequencePool sequencePool, BigInteger sharedSecret, byte[] exchangeHash, byte c, byte[] sessionId, int hashLength)
private byte[] CalculateKey(SequencePool sequencePool, BigInteger sharedSecret, byte[] exchangeHash, byte c, byte[] sessionId, int keyLength)
{
// https://tools.ietf.org/html/rfc4253#section-7.2

byte[] hashRv = new byte[hashLength];
int hashOffset = 0;
byte[] key = new byte[keyLength];
int keyOffset = 0;

// TODO: handle 'If the key length needed is longer than the output of the HASH'
// HASH(K || H || c || session_id)
using Sequence sequence = sequencePool.RentSequence();
var writer = new SequenceWriter(sequence);
Expand All @@ -139,16 +138,28 @@ private byte[] Hash(SequencePool sequencePool, BigInteger sharedSecret, byte[] e
hash.AppendData(segment.Span);
}
byte[] K1 = hash.GetHashAndReset();
Append(hashRv, K1, ref hashOffset);
Append(key, K1, ref keyOffset);

while (hashOffset != hashRv.Length)
while (keyOffset != key.Length)
{
// TODO: handle 'If the key length needed is longer than the output of the HASH'
sequence.Clear();

// K3 = HASH(K || H || K1 || K2)
throw new NotSupportedException();
writer = new SequenceWriter(sequence);
writer.WriteMPInt(sharedSecret);
writer.Write(exchangeHash);
writer.Write(key.AsSpan(0, keyOffset));

foreach (var segment in sequence.AsReadOnlySequence())
{
hash.AppendData(segment.Span);
}
byte[] Kn = hash.GetHashAndReset();

Append(key, Kn, ref keyOffset);
}

return hashRv;
return key;

static void Append(byte[] key, byte[] append, ref int offset)
{
Expand Down
8 changes: 8 additions & 0 deletions src/Tmds.Ssh/EncryptionAlgorithm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,13 @@ public static EncryptionAlgorithm Find(Name name)
=> new AesGcmPacketDecoder(sequencePool, key, iv, algorithm.TagLength),
isAuthenticated: true,
tagLength: 16) },
{ AlgorithmNames.ChaCha20Poly1305,
new EncryptionAlgorithm(keyLength: 512 / 8, ivLength: 0,
(EncryptionAlgorithm algorithm, byte[] key, byte[] iv, HMacAlgorithm? hmac, byte[] hmacKey)
=> new ChaCha20Poly1305PacketEncoder(key),
(EncryptionAlgorithm algorithm, SequencePool sequencePool, byte[] key, byte[] iv, HMacAlgorithm? hmac, byte[] hmacKey)
=> new ChaCha20Poly1305PacketDecoder(sequencePool, key),
isAuthenticated: true,
tagLength: ChaCha20Poly1305PacketEncoder.TagSize) },
};
}
12 changes: 10 additions & 2 deletions src/Tmds.Ssh/SshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,15 @@ public async ValueTask WriteAsync(ReadOnlyMemory<byte> memory, CancellationToken
int sendWindow = Volatile.Read(ref _sendWindow);
if (sendWindow > 0)
{
// We need to check the cancellation token in case we send a huge amount of data
// and the peer can keep up (and the send window never becomes zero).
if (cancellationToken.IsCancellationRequested)
{
Cancel();

cancellationToken.ThrowIfCancellationRequested();
}

int toSend = Math.Min(sendWindow, memory.Length);
toSend = Math.Min(toSend, SendMaxPacket);
if (Interlocked.CompareExchange(ref _sendWindow, sendWindow - toSend, sendWindow) == sendWindow)
Expand All @@ -213,8 +222,7 @@ public async ValueTask WriteAsync(ReadOnlyMemory<byte> memory, CancellationToken
{
Cancel();

cancellationToken.ThrowIfCancellationRequested();
throw CreateCloseException();
throw;
}
}
}
Expand Down
Loading

0 comments on commit 6b467f8

Please sign in to comment.