Skip to content

Commit

Permalink
chore(dotnet): house-keeping, fixed Rider warnings,style. Added extra…
Browse files Browse the repository at this point in the history
… comments (#468)

Code quality improvement on dotnet package.
  • Loading branch information
irvingoujAtDevolution authored May 24, 2024
1 parent ed8883e commit 63743e9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 61 deletions.
63 changes: 33 additions & 30 deletions ffi/dotnet/Devolutions.IronRdp/src/Connection.cs
Original file line number Diff line number Diff line change
@@ -1,48 +1,46 @@

using System.Net;
using System.Net.Security;
using System.Net.Sockets;

namespace Devolutions.IronRdp;

public static class Connection
{

public static async Task<(ConnectionResult, Framed<SslStream>)> Connect(Config config, string servername, CliprdrBackendFactory? factory, int port = 3389)
public static async Task<(ConnectionResult, Framed<SslStream>)> Connect(Config config, string serverName,
CliprdrBackendFactory? factory, int port = 3389)
{

var stream = await CreateTcpConnection(servername, port);
var stream = await CreateTcpConnection(serverName, port);
var framed = new Framed<NetworkStream>(stream);

ClientConnector connector = ClientConnector.New(config);
var connector = ClientConnector.New(config);

var ip = await Dns.GetHostAddressesAsync(servername);
var ip = await Dns.GetHostAddressesAsync(serverName);
if (ip.Length == 0)
{
throw new IronRdpLibException(IronRdpLibExceptionType.CannotResolveDns, "Cannot resolve DNS to " + servername);
throw new IronRdpLibException(IronRdpLibExceptionType.CannotResolveDns,
"Cannot resolve DNS to " + serverName);
}

var socketAddrString = ip[0].ToString() + ":" + port;
connector.WithServerAddr(socketAddrString);

var serverAddr = ip[0] + ":" + port;
connector.WithServerAddr(serverAddr);
if (factory != null)
{
var cliprdr = factory.BuildCliprdr();
connector.AttachStaticCliprdr(cliprdr);
}

await connectBegin(framed, connector);
var (serverPublicKey, framedSsl) = await securityUpgrade(servername, framed, connector);
var result = await ConnectFinalize(servername, connector, serverPublicKey, framedSsl);
await ConnectBegin(framed, connector);
var (serverPublicKey, framedSsl) = await SecurityUpgrade(framed, connector);
var result = await ConnectFinalize(serverName, connector, serverPublicKey, framedSsl);
return (result, framedSsl);
}

private static async Task<(byte[], Framed<SslStream>)> securityUpgrade(string servername, Framed<NetworkStream> framed, ClientConnector connector)
private static async Task<(byte[], Framed<SslStream>)> SecurityUpgrade(Framed<NetworkStream> framed,
ClientConnector connector)
{
byte[] serverPublicKey;
Framed<SslStream> framedSsl;
var (streamRequireUpgrade, _) = framed.GetInner();
var promise = new TaskCompletionSource<byte[]>();
var sslStream = new SslStream(streamRequireUpgrade, false, (sender, certificate, chain, sslPolicyErrors) =>
var sslStream = new SslStream(streamRequireUpgrade, false, (_, certificate, _, _) =>
{
promise.SetResult(certificate!.GetPublicKey());
return true;
Expand All @@ -51,14 +49,14 @@ await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions()
{
AllowTlsResume = false
});
serverPublicKey = await promise.Task;
framedSsl = new Framed<SslStream>(sslStream);
var serverPublicKey = await promise.Task;
Framed<SslStream> framedSsl = new(sslStream);
connector.MarkSecurityUpgradeAsDone();

return (serverPublicKey, framedSsl);
}

private static async Task connectBegin(Framed<NetworkStream> framed, ClientConnector connector)
private static async Task ConnectBegin(Framed<NetworkStream> framed, ClientConnector connector)
{
var writeBuf = WriteBuf.New();
while (!connector.ShouldPerformSecurityUpgrade())
Expand All @@ -68,13 +66,15 @@ private static async Task connectBegin(Framed<NetworkStream> framed, ClientConne
}


private static async Task<ConnectionResult> ConnectFinalize(string servername, ClientConnector connector, byte[] serverpubkey, Framed<SslStream> framedSsl)
private static async Task<ConnectionResult> ConnectFinalize(string serverName, ClientConnector connector,
byte[] serverPubKey, Framed<SslStream> framedSsl)
{
var writeBuf2 = WriteBuf.New();
if (connector.ShouldPerformCredssp())
{
await PerformCredsspSteps(connector, servername, writeBuf2, framedSsl, serverpubkey);
await PerformCredsspSteps(connector, serverName, writeBuf2, framedSsl, serverPubKey);
}

while (!connector.GetDynState().IsTerminal())
{
await SingleConnectStep(connector, writeBuf2, framedSsl);
Expand All @@ -92,12 +92,13 @@ private static async Task<ConnectionResult> ConnectFinalize(string servername, C
}
}

private static async Task PerformCredsspSteps(ClientConnector connector, string serverName, WriteBuf writeBuf, Framed<SslStream> framedSsl, byte[] serverpubkey)
private static async Task PerformCredsspSteps(ClientConnector connector, string serverName, WriteBuf writeBuf,
Framed<SslStream> framedSsl, byte[] serverpubkey)
{
var credsspSequenceInitResult = CredsspSequence.Init(connector, serverName, serverpubkey, null);
var credsspSequence = credsspSequenceInitResult.GetCredsspSequence();
var tsRequest = credsspSequenceInitResult.GetTsRequest();
TcpClient tcpClient = new TcpClient();
var tcpClient = new TcpClient();
while (true)
{
var generator = credsspSequence.ProcessTsRequest(tsRequest);
Expand All @@ -122,6 +123,7 @@ private static async Task PerformCredsspSteps(ClientConnector connector, string
var pdu = await framedSsl.ReadByHint(pduHint);
var decoded = credsspSequence.DecodeServerMessage(pdu);

// Don't remove, DecodeServerMessage is generated, and it can return null
if (null == decoded)
{
break;
Expand Down Expand Up @@ -149,8 +151,8 @@ private static async Task<ClientState> ResolveGenerator(CredsspProcessGenerator
var split = url.Split(":");
await tcpClient.ConnectAsync(split[0], int.Parse(split[1]));
stream = tcpClient.GetStream();

}

if (protocol == NetworkRequestProtocol.Tcp)
{
stream.Write(Utils.VecU8ToByte(data));
Expand All @@ -174,12 +176,14 @@ private static async Task<ClientState> ResolveGenerator(CredsspProcessGenerator
}

static async Task SingleConnectStep<T>(ClientConnector connector, WriteBuf buf, Framed<T> framed)
where T : Stream
where T : Stream
{
buf.Clear();

var pduHint = connector.NextPduHint();
Written written;

// Don't remove, NextPduHint is generated, and it can return null
if (pduHint != null)
{
byte[] pdu = await framed.ReadByHint(pduHint);
Expand Down Expand Up @@ -227,16 +231,15 @@ static async Task<NetworkStream> CreateTcpConnection(String servername, int port

return stream;
}

}

public static class Utils
{
public static byte[] VecU8ToByte(VecU8 vecU8)
{
var len = vecU8.GetSize();
byte[] buffer = new byte[len];
var buffer = new byte[len];
vecU8.Fill(buffer);
return buffer;
}
}
}
8 changes: 5 additions & 3 deletions ffi/dotnet/Devolutions.IronRdp/src/Exceptions.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
namespace Devolutions.IronRdp;

[Serializable]
public class IronRdpLibException : Exception
{
public IronRdpLibExceptionType type { get; private set; }
public IronRdpLibExceptionType ErrorType { get; private set; }

public IronRdpLibException(IronRdpLibExceptionType type, string message) : base(message)
public IronRdpLibException(IronRdpLibExceptionType errorType, string message) : base(message)
{
this.type = type;
ErrorType = errorType;
}

}
Expand Down
54 changes: 26 additions & 28 deletions ffi/dotnet/Devolutions.IronRdp/src/Framed.cs
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
using System.Linq.Expressions;
using System.Runtime.InteropServices;
using Devolutions.IronRdp;

public class Framed<S> where S : Stream
namespace Devolutions.IronRdp;

public class Framed<TS> where TS : Stream
{
private S stream;
private List<byte> buffer;
private readonly Mutex writeLock = new Mutex();
private readonly TS _stream;
private List<byte> _buffer;
private readonly Mutex _writeLock = new();

public Framed(S stream)
public Framed(TS stream)
{
this.stream = stream;
this.buffer = new List<byte>();
_stream = stream;
_buffer = new List<byte>();
}

public (S, List<byte>) GetInner()
public (TS, List<byte>) GetInner()
{
return (this.stream, this.buffer);
return (_stream, _buffer);
}

public async Task<(Devolutions.IronRdp.Action, byte[])> ReadPdu()
public async Task<(Action, byte[])> ReadPdu()
{
while (true)
{
var pduInfo = IronRdpPdu.New().FindSize(this.buffer.ToArray());
var pduInfo = IronRdpPdu.New().FindSize(this._buffer.ToArray());

// Don't remove, FindSize is generated and can return null
if (null != pduInfo)
{
var frame = await this.ReadExact(pduInfo.GetLength());
Expand All @@ -48,7 +50,7 @@ public Framed(S stream)
/// <returns>A span that represents a portion of the underlying buffer.</returns>
public Span<byte> Peek()
{
return CollectionsMarshal.AsSpan(this.buffer);
return CollectionsMarshal.AsSpan(this._buffer);
}

/// <summary>
Expand All @@ -60,10 +62,10 @@ public async Task<byte[]> ReadExact(nuint size)
{
while (true)
{
if (buffer.Count >= (int)size)
if (_buffer.Count >= (int)size)
{
var res = this.buffer.Take((int)size).ToArray();
this.buffer = this.buffer.Skip((int)size).ToList();
var res = this._buffer.Take((int)size).ToArray();
this._buffer = this._buffer.Skip((int)size).ToList();
return res;
}

Expand All @@ -79,26 +81,22 @@ async Task<int> Read()
{
var buffer = new byte[8096];
Memory<byte> memory = buffer;
var size = await this.stream.ReadAsync(memory);
this.buffer.AddRange(buffer.Take(size));
var size = await this._stream.ReadAsync(memory);
this._buffer.AddRange(buffer.Take(size));
return size;
}

public async Task Write(byte[] data)
{
writeLock.WaitOne();
_writeLock.WaitOne();
try
{
ReadOnlyMemory<byte> memory = data;
await this.stream.WriteAsync(memory);
}
catch (Exception e)
{
throw e;
await _stream.WriteAsync(memory);
}
finally
{
writeLock.ReleaseMutex();
_writeLock.ReleaseMutex();
}
}

Expand All @@ -112,7 +110,7 @@ public async Task<byte[]> ReadByHint(PduHint pduHint)
{
while (true)
{
var size = pduHint.FindSize(this.buffer.ToArray());
var size = pduHint.FindSize(this._buffer.ToArray());
if (size.IsSome())
{
return await this.ReadExact(size.Get());
Expand All @@ -128,4 +126,4 @@ public async Task<byte[]> ReadByHint(PduHint pduHint)

}
}
}
}

0 comments on commit 63743e9

Please sign in to comment.