Skip to content

Commit

Permalink
SftpClient: support uploading file from Stream.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds committed Dec 11, 2024
1 parent 1599735 commit e2db13f
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 57 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class SftpClient : IDisposable

ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, CancellationToken cancellationToken);
ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions, CancellationToken cancellationToken = default);
ValueTask UploadFileAsync(Stream source, string remoteFilePath, CancellationToken cancellationToken);
ValueTask UploadFileAsync(Stream source, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions = null, CancellationToken cancellationToken = default);
ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, CancellationToken cancellationToken = default);
ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, UploadEntriesOptions? options, CancellationToken cancellationToken = default);

Expand Down
145 changes: 101 additions & 44 deletions src/Tmds.Ssh/SftpChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -782,78 +782,135 @@ private static UnixFilePermissions GetPermissionsForFile(SafeFileHandle fileHand

public async ValueTask UploadFileAsync(string localPath, string remotePath, long? length, bool overwrite, UnixFilePermissions? permissions, CancellationToken cancellationToken)
{
using SafeFileHandle localFile = File.OpenHandle(localPath, FileMode.Open, FileAccess.Read, FileShare.Read);
using FileStream localFile = new FileStream(localPath, FileMode.Open, FileAccess.Read, FileShare.Read, bufferSize: 0);

permissions ??= GetPermissionsForFile(localFile);
permissions ??= GetPermissionsForFile(localFile.SafeFileHandle);

using SftpFile remoteFile = (await OpenFileCoreAsync(remotePath, (overwrite ? SftpOpenFlags.OpenOrCreate : SftpOpenFlags.CreateNew) | SftpOpenFlags.Write, permissions.Value, SftpClient.DefaultFileOpenOptions, cancellationToken).ConfigureAwait(false))!;
await UploadFileAsync(localFile, remotePath, length, overwrite, permissions.Value, cancellationToken).ConfigureAwait(false);
}

length ??= RandomAccess.GetLength(localFile);
public async ValueTask UploadFileAsync(Stream source, string remotePath, long? length, bool overwrite, UnixFilePermissions permissions, CancellationToken cancellationToken)
{
using SftpFile remoteFile = (await OpenFileCoreAsync(remotePath, (overwrite ? SftpOpenFlags.OpenOrCreate : SftpOpenFlags.CreateNew) | SftpOpenFlags.Write, permissions, SftpClient.DefaultFileOpenOptions, cancellationToken).ConfigureAwait(false))!;

ValueTask previous = default;
// Pipeline the writes when the source is a sync, seekable Stream.
bool pipelineSyncWrites = source.CanSeek && IsSyncStream(source);

CancellationTokenSource? breakLoop = length > 0 ? new() : null;
if (!pipelineSyncWrites)
{
await source.CopyToAsync(remoteFile, GetMaxWritePayload(remoteFile.Handle)).ConfigureAwait(false);

for (long offset = 0; offset < length; offset += GetMaxWritePayload(remoteFile.Handle))
await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false);
}
else
{
Debug.Assert(breakLoop is not null);
if (!breakLoop.IsCancellationRequested)
length ??= source.Length;
if (length == 0)
{
await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
previous = CopyBuffer(previous, offset, GetMaxWritePayload(remoteFile.Handle));
return;
}
}

await previous.ConfigureAwait(false);
ValueTask previous = default;
long startOffset = source.Position;
long bytesSuccesfullyWritten = 0;
CancellationTokenSource breakLoop = new();
int maxWritePayload = GetMaxWritePayload(remoteFile.Handle);

await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false);
for (long offset = 0; offset < length; offset += maxWritePayload)
{
if (!breakLoop.IsCancellationRequested)
{
await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
int copyLength = (int)Math.Min((long)maxWritePayload, length.Value - offset);
previous = CopyBuffer(previous, offset, copyLength);
}
}

async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length)
{
bool ignorePositionUpdateException = false;
try
{
byte[]? buffer = null;
await previous.ConfigureAwait(false);

await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false);
}
catch
{
ignorePositionUpdateException = true;

throw;
}
finally
{
// Set the position to what was succesfully written.
try
{
if (breakLoop.IsCancellationRequested)
source.Position = startOffset + bytesSuccesfullyWritten;
}
catch when (ignorePositionUpdateException)
{ }
}

async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length)
{
try
{
byte[]? buffer = null;
try
{
return;
}
if (breakLoop.IsCancellationRequested)
{
return;
}

buffer = ArrayPool<byte>.Shared.Rent(length);
do
buffer = ArrayPool<byte>.Shared.Rent(length);
int remaining = length;
do
{
int bytesRead;
lock (breakLoop) // Ensure only one thread is reading the Stream concurrently.
{
source.Position = startOffset + offset;
bytesRead = source.Read(buffer.AsSpan(length - remaining));
}
if (bytesRead == 0)
{
throw new IOException("Unexpected end of file. The source was truncated during the upload.");
}
remaining -= bytesRead;
} while (remaining > 0);

await remoteFile.WriteAtAsync(buffer.AsMemory(0, length), offset, cancellationToken).ConfigureAwait(false);
}
catch
{
int bytesRead = RandomAccess.Read(localFile, buffer.AsSpan(0, length), offset);
if (bytesRead == 0)
length = 0; // Assume nothing was written succesfully.
breakLoop.Cancel();
throw;
}
finally
{
if (buffer != null)
{
break;
ArrayPool<byte>.Shared.Return(buffer);
}
await remoteFile.WriteAtAsync(buffer.AsMemory(0, bytesRead), offset, cancellationToken).ConfigureAwait(false);
length -= bytesRead;
offset += bytesRead;
} while (length > 0);
}
catch
{
breakLoop.Cancel();
throw;
s_uploadBufferSemaphore.Release();
}
}
finally
{
if (buffer != null)
{
ArrayPool<byte>.Shared.Return(buffer);
}
s_uploadBufferSemaphore.Release();
await previousCopy.ConfigureAwait(false);

// Update with our length after the previous write completed succesfully.
bytesSuccesfullyWritten += length;
}
}
finally
{
await previousCopy.ConfigureAwait(false);
}
}
}

// Consider it okay to do sync operation on these types of streams.
private static bool IsSyncStream(Stream stream)
=> stream is MemoryStream or FileStream;

private IAsyncEnumerable<T> GetDirectoryEntriesAsync<T>(string path, SftpFileEntryTransform<T> transform, EnumerationOptions options)
=> new SftpFileSystemEnumerable<T>(this, path, transform, options);

Expand Down Expand Up @@ -1087,7 +1144,7 @@ private async ValueTask DownloadFileAsync(string remotePath, string? localPath,

Debug.Assert(destination is not null);

bool writeSync = destination is FileStream or MemoryStream;
bool writeSync = IsSyncStream(destination);

ValueTask previous = default;
CancellationTokenSource? breakLoop = length > 0 ? new() : null;
Expand Down
25 changes: 17 additions & 8 deletions src/Tmds.Ssh/SftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)

internal async ValueTask OpenAsync(CancellationToken cancellationToken)
{
await GetChannelAsync(cancellationToken);
await GetChannelAsync(cancellationToken).ConfigureAwait(false);
}

internal ValueTask<SftpChannel> GetChannelAsync(CancellationToken cancellationToken, bool explicitConnect = false)
Expand Down Expand Up @@ -330,7 +330,7 @@ public ValueTask CreateDirectoryAsync(string path, CancellationToken cancellatio
public async ValueTask CreateDirectoryAsync(string path, bool createParents = false, UnixFilePermissions permissions = DefaultCreateDirectoryPermissions, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.CreateDirectoryAsync(path, createParents, permissions, cancellationToken);
await channel.CreateDirectoryAsync(path, createParents, permissions, cancellationToken).ConfigureAwait(false);
}

public ValueTask CreateNewDirectoryAsync(string path, CancellationToken cancellationToken)
Expand All @@ -339,7 +339,7 @@ public ValueTask CreateNewDirectoryAsync(string path, CancellationToken cancella
public async ValueTask CreateNewDirectoryAsync(string path, bool createParents = false, UnixFilePermissions permissions = DefaultCreateDirectoryPermissions, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.CreateNewDirectoryAsync(path, createParents, permissions, cancellationToken);
await channel.CreateNewDirectoryAsync(path, createParents, permissions, cancellationToken).ConfigureAwait(false);
}

public ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, CancellationToken cancellationToken = default)
Expand All @@ -348,7 +348,7 @@ public ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteD
public async ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, UploadEntriesOptions? options, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.UploadDirectoryEntriesAsync(localDirPath, remoteDirPath, options, cancellationToken);
await channel.UploadDirectoryEntriesAsync(localDirPath, remoteDirPath, options, cancellationToken).ConfigureAwait(false);
}

public ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, CancellationToken cancellationToken)
Expand All @@ -357,7 +357,16 @@ public ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, Ca
public async ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions = default, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.UploadFileAsync(localFilePath, remoteFilePath, length: null, overwrite, createPermissions, cancellationToken);
await channel.UploadFileAsync(localFilePath, remoteFilePath, length: null, overwrite, createPermissions, cancellationToken).ConfigureAwait(false);
}

public ValueTask UploadFileAsync(Stream source, string remoteFilePath, CancellationToken cancellationToken)
=> UploadFileAsync(source, remoteFilePath, overwrite: false, createPermissions: null, cancellationToken);

public async ValueTask UploadFileAsync(Stream source, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions = null, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.UploadFileAsync(source, remoteFilePath, length: null, overwrite, createPermissions ?? OwnershipPermissions, cancellationToken).ConfigureAwait(false);
}

public ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string localDirPath, CancellationToken cancellationToken = default)
Expand All @@ -366,7 +375,7 @@ public ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string loca
public async ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string localDirPath, DownloadEntriesOptions? options, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.DownloadDirectoryEntriesAsync(remoteDirPath, localDirPath, options, cancellationToken);
await channel.DownloadDirectoryEntriesAsync(remoteDirPath, localDirPath, options, cancellationToken).ConfigureAwait(false);
}

public ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath, CancellationToken cancellationToken)
Expand All @@ -375,13 +384,13 @@ public ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath,
public async ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath, bool overwrite = false, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.DownloadFileAsync(remoteFilePath, localFilePath, overwrite, cancellationToken);
await channel.DownloadFileAsync(remoteFilePath, localFilePath, overwrite, cancellationToken).ConfigureAwait(false);
}

public async ValueTask DownloadFileAsync(string remoteFilePath, Stream destination, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.DownloadFileAsync(remoteFilePath, destination, cancellationToken);
await channel.DownloadFileAsync(remoteFilePath, destination, cancellationToken).ConfigureAwait(false);
}

private ObjectDisposedException NewObjectDisposedException()
Expand Down
19 changes: 14 additions & 5 deletions test/Tmds.Ssh.Tests/SftpClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -708,15 +708,24 @@ public async Task UploadDownloadFile(int fileSize)
}
}

[Fact]
public async Task DownloadFileToStream()
[InlineData(0)]
[InlineData(10)]
[InlineData(10 * MultiPacketSize)] // Ensure some pipelined writing.
[Theory]
public async Task UploadDownloadFileWithStream(int size)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync();
var (sourceFileName, sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 100);

byte[] sourceData = new byte[size];
Random.Shared.NextBytes(sourceData);
MemoryStream uploadStream = new MemoryStream(sourceData);

await using var downloadStream = new MemoryStream();
await sftpClient.DownloadFileAsync(sourceFileName, downloadStream);
string remotePath = $"/tmp/{Path.GetRandomFileName()}";
await sftpClient.UploadFileAsync(uploadStream, remotePath);
Assert.Equal(sourceData.Length, uploadStream.Position);

await using var downloadStream = new MemoryStream();
await sftpClient.DownloadFileAsync(remotePath, downloadStream);
Assert.Equal(sourceData, downloadStream.ToArray());
}

Expand Down

0 comments on commit e2db13f

Please sign in to comment.