diff --git a/test/Tmds.Ssh.Tests/SftpClientTests.cs b/test/Tmds.Ssh.Tests/SftpClientTests.cs index d4d0848..783236f 100644 --- a/test/Tmds.Ssh.Tests/SftpClientTests.cs +++ b/test/Tmds.Ssh.Tests/SftpClientTests.cs @@ -729,6 +729,26 @@ public async Task UploadDownloadFileWithStream(int size) Assert.Equal(sourceData, downloadStream.ToArray()); } + [InlineData(0)] + [InlineData(10)] + [InlineData(10 * MultiPacketSize)] + [Theory] + public async Task UploadDownloadFileWithAsyncStream(int size) + { + using var sftpClient = await _sshServer.CreateSftpClientAsync(); + + byte[] sourceData = new byte[size]; + Random.Shared.NextBytes(sourceData); + Stream uploadStream = new NonSeekableAsyncStream(sourceData); + + string remotePath = $"/tmp/{Path.GetRandomFileName()}"; + await sftpClient.UploadFileAsync(uploadStream, remotePath); + + await using var downloadStream = new NonSeekableAsyncStream(); + await sftpClient.DownloadFileAsync(remotePath, downloadStream); + Assert.Equal(sourceData, downloadStream.ToArray()); + } + [Fact] public async Task DownloadFileThrowsWhenNotFound() { @@ -1228,4 +1248,61 @@ public async Task AutoReconnect(bool autoReconnect) await Assert.ThrowsAsync(() => client.GetFullPathAsync("").AsTask()); } } + + sealed class NonSeekableAsyncStream : Stream + { + private readonly MemoryStream _innerStream = new(); + + public NonSeekableAsyncStream() + { + _innerStream = new(); + } + + public NonSeekableAsyncStream(byte[] data) + { + _innerStream = new(data); + } + + public byte[] ToArray() + => _innerStream.ToArray(); + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => throw new NotImplementedException(); + + public override long Position + { + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); + } + + public override void Flush() + { + throw new NotImplementedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return _innerStream.Read(buffer, offset, count); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _innerStream.Write(buffer, offset, count); + } + } }