Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SftpClient.CopyFileAsync. #240

Merged
merged 12 commits into from
Oct 19, 2024
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ class SftpClient : IDisposable

ValueTask RenameAsync(string oldpath, string newpath, CancellationToken cancellationToken = default);

ValueTask CopyFileAsync(string sourcePath, string destinationPath, bool overwrite = false, CancellationToken cancellationToken = default);

ValueTask<FileEntryAttributes?> GetAttributesAsync(string path, bool followLinks = true, CancellationToken cancellationToken = default);
ValueTask SetAttributesAsync(
string path,
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/SftpChannel.PacketType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,7 @@ internal enum PacketType : byte
SSH_FXP_ATTRS = 105,
SSH_FXP_EXTENDED = 200,
SSH_FXP_EXTENDED_REPLY = 201,

SSH_SFTP_STATUS_RESPONSE = 0
}
}
6 changes: 6 additions & 0 deletions src/Tmds.Ssh/SftpChannel.Writer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ internal void WriteInt64(long value)
_length += 8;
}

internal void WriteUInt64(ulong value)
{
BinaryPrimitives.WriteUInt64BigEndian(_buffer.AsSpan(_length), value);
_length += 8;
}

public void WriteAttributes(
long? length = default,
(int Uid, int Gid)? ids = default,
Expand Down
406 changes: 341 additions & 65 deletions src/Tmds.Ssh/SftpChannel.cs

Large diffs are not rendered by default.

16 changes: 15 additions & 1 deletion src/Tmds.Ssh/SftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ enum State
// For testing.
internal SshClient SshClient => _client;
internal bool IsDisposed => _state == State.Disposed;
internal SftpExtension EnabledExtensions
{
get
{
SftpChannel channel = _channel ?? throw new InvalidOperationException();
return channel.EnabledExtensions;
}
}

public SftpClient(string destination, ILoggerFactory? loggerFactory = null, SftpClientOptions? options = null) :
this(destination, SshConfigSettings.NoConfig, loggerFactory, options)
Expand Down Expand Up @@ -175,7 +183,7 @@ private async Task<SftpChannel> DoOpenAsync(bool explicitConnect, CancellationTo
bool success = false;
try
{
SftpChannel channel = await _client.OpenSftpChannelAsync(OnChannelAbort, explicitConnect, cancellationToken).ConfigureAwait(false);
SftpChannel channel = await _client.OpenSftpChannelAsync(OnChannelAbort, explicitConnect, _options, cancellationToken).ConfigureAwait(false);
_channel = channel;
success = true;
return channel;
Expand Down Expand Up @@ -267,6 +275,12 @@ public async ValueTask RenameAsync(string oldPath, string newPath, CancellationT
await channel.RenameAsync(oldPath, newPath, cancellationToken).ConfigureAwait(false);
}

public async ValueTask CopyFileAsync(string sourcePath, string destinationPath, bool overwrite = false, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.CopyFileAsync(sourcePath, destinationPath, overwrite, cancellationToken).ConfigureAwait(false);
}

public async ValueTask<FileEntryAttributes?> GetAttributesAsync(string path, bool followLinks = true, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
Expand Down
5 changes: 4 additions & 1 deletion src/Tmds.Ssh/SftpClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@
namespace Tmds.Ssh;

public sealed partial class SftpClientOptions
{ }
{
// For testing.
internal SftpExtension DisabledExtensions { get; set; }
}
9 changes: 9 additions & 0 deletions src/Tmds.Ssh/SftpExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Tmds.Ssh;

[Flags]
enum SftpExtension
{
None = 0,
// https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-extensions-00https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-extensions-00#section-7
CopyData = 1 // copy-data 1
}
4 changes: 2 additions & 2 deletions src/Tmds.Ssh/SshClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,13 @@ public async Task<SftpClient> OpenSftpClientAsync(SftpClientOptions? options = n
}
}

internal async Task<SftpChannel> OpenSftpChannelAsync(Action<SshChannel> onAbort, bool explicitConnect, CancellationToken cancellationToken)
internal async Task<SftpChannel> OpenSftpChannelAsync(Action<SshChannel> onAbort, bool explicitConnect, SftpClientOptions options, CancellationToken cancellationToken)
{
SshSession session = await GetSessionAsync(cancellationToken, explicitConnect).ConfigureAwait(false);

var channel = await session.OpenSftpClientChannelAsync(onAbort, cancellationToken).ConfigureAwait(false);

var sftpChannel = new SftpChannel(channel);
var sftpChannel = new SftpChannel(channel, options);

try
{
Expand Down
101 changes: 99 additions & 2 deletions test/Tmds.Ssh.Tests/SftpClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Tmds.Ssh.Tests;
public class SftpClientTests
{
const int PacketSize = 32768; // roughly amount of bytes sent/received in a single sftp packet.
const int MultiPacketSize = 2 * PacketSize + 1024;

private readonly SshServer _sshServer;

Expand Down Expand Up @@ -54,7 +55,7 @@ public async Task SftpClientCtorFromSshClientSettings()

[InlineData(10)]
[InlineData(10 * 1024)] // 10 kiB
[InlineData(2 * PacketSize + 1024)]
[InlineData(MultiPacketSize)]
[Theory]
public async Task ReadWriteFile(int fileSize)
{
Expand Down Expand Up @@ -615,7 +616,7 @@ public async Task UploadDownloadDirectory()

[InlineData(0)]
[InlineData(10)]
[InlineData(2 * PacketSize + 1024)]
[InlineData(MultiPacketSize)]
[Theory]
public async Task UploadDownloadFile(int fileSize)
{
Expand Down Expand Up @@ -993,6 +994,102 @@ public async Task CacheLength()
}
}

[InlineData(0, SftpExtension.CopyData)]
[InlineData(10, SftpExtension.CopyData)]
[InlineData(MultiPacketSize, SftpExtension.CopyData)]
[InlineData(0, SftpExtension.None)]
[InlineData(10, SftpExtension.None)]
[InlineData(MultiPacketSize, SftpExtension.None)]
[SkippableTheory]
public async Task CopyFile(int fileSize, SftpExtension sftpExtensions)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions);

(string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, fileSize);

string destinationFileName = $"/tmp/{Path.GetRandomFileName()}";
await sftpClient.CopyFileAsync(sourceFileName, destinationFileName);

await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, destinationFileName);
}

[InlineData(true, SftpExtension.CopyData)]
[InlineData(false, SftpExtension.CopyData)]
[InlineData(true, SftpExtension.None)]
[InlineData(false, SftpExtension.None)]
[SkippableTheory]
public async Task CopyFileOverwrite(bool overwrite, SftpExtension sftpExtensions)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions);

(string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10);
(string destinationFileName, byte[] destinationData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10);

Task copyTask = sftpClient.CopyFileAsync(sourceFileName, destinationFileName, overwrite).AsTask();

if (overwrite)
{
await copyTask;
}
else
{
await Assert.ThrowsAsync<SftpException>(() => copyTask);
}

byte[] expectedData = overwrite ? sourceData : destinationData;
await AssertRemoteFileContentEqualsAsync(sftpClient, expectedData, destinationFileName);
}

[InlineData(SftpExtension.CopyData)]
[InlineData(SftpExtension.None)]
[SkippableTheory]
public async Task CopyFileToSelfDoesntLooseData(SftpExtension sftpExtensions)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions);

(string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10);

await sftpClient.CopyFileAsync(sourceFileName, sourceFileName, overwrite: true);

await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, sourceFileName);
}

[InlineData(SftpExtension.CopyData)]
[InlineData(SftpExtension.None)]
[SkippableTheory]
public async Task CopyFileOverwriteToLargerTruncates(SftpExtension sftpExtensions)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions);

const int SourceLength = 10;
(string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: SourceLength);
const int DestinationLength = SourceLength + SourceLength;
(string destinationFileName, byte[] destinationData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: DestinationLength);

await sftpClient.CopyFileAsync(sourceFileName, destinationFileName, overwrite: true).AsTask();

await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, destinationFileName);
}

private async Task AssertRemoteFileContentEqualsAsync(SftpClient client, byte[] expected, string remoteFileName)
{
using var readFile = await client.OpenFileAsync(remoteFileName, FileAccess.Read);
Assert.NotNull(readFile);
var memoryStream = new MemoryStream();
await readFile.CopyToAsync(memoryStream);
Assert.Equal(expected, memoryStream.ToArray());
}

private async Task<(string filename, byte[] data)> CreateRemoteFileWithRandomDataAsync(SftpClient client, int length)
{
string filename = $"/tmp/{Path.GetRandomFileName()}";
byte[] data = new byte[10];
Random.Shared.NextBytes(data);
using var writeFile = await client.CreateNewFileAsync(filename, FileAccess.Write);
await writeFile.WriteAsync(data.AsMemory());
return (filename, data);
}

[InlineData(true)]
[InlineData(false)]
[Theory]
Expand Down
9 changes: 9 additions & 0 deletions test/Tmds.Ssh.Tests/SftpExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Tmds.Ssh.Tests;

// Copy of Tmds.Ssh.SftpExtensions with public access.
[Flags]
public enum SftpExtension
{
None = Tmds.Ssh.SftpExtension.None,
CopyData = Tmds.Ssh.SftpExtension.CopyData
}
33 changes: 31 additions & 2 deletions test/Tmds.Ssh.Tests/SshServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Xunit;
using Xunit.Abstractions;
using Xunit.Sdk;
using SkipException = Xunit.SkipException;

namespace Tmds.Ssh.Tests;

Expand Down Expand Up @@ -347,11 +348,39 @@ public async Task<SshClient> CreateClientAsync(Action<SshClientSettings>? config
return client;
}

public async Task<SftpClient> CreateSftpClientAsync(Action<SshClientSettings>? configureSsh = null, CancellationToken cancellationToken = default, bool connect = true)
public async Task<SftpClient> CreateSftpClientAsync(Tmds.Ssh.Tests.SftpExtension enabledExtensions, Action<SshClientSettings>? configureSsh = null, CancellationToken cancellationToken = default)
{
var settings = CreateSshClientSettings(configureSsh);

var client = new SftpClient(settings);
SftpClientOptions? options = new()
{
DisabledExtensions = (Tmds.Ssh.SftpExtension)~enabledExtensions
};

var client = new SftpClient(settings, options: options);

await client.ConnectAsync(cancellationToken);

if (client.EnabledExtensions != (Tmds.Ssh.SftpExtension)enabledExtensions)
{
throw new SkipException($"The test server does not support the required {((Tmds.Ssh.SftpExtension)enabledExtensions) & ~client.EnabledExtensions} extensions.");
}

return client;
}

public async Task<SftpClient> CreateSftpClientAsync(Action<SshClientSettings>? configureSsh = null, Action<SftpClientOptions>? configureSftp = null, CancellationToken cancellationToken = default, bool connect = true)
{
var settings = CreateSshClientSettings(configureSsh);

SftpClientOptions? sftpClientOptions = null;
if (configureSftp is not null)
{
sftpClientOptions = new();
configureSftp.Invoke(sftpClientOptions);
}

var client = new SftpClient(settings, options: sftpClientOptions);

if (connect)
{
Expand Down
Loading