From ba41694cf3560eb77db4c1ffc0cf17ee954f5142 Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Wed, 29 May 2024 17:11:42 +0200 Subject: [PATCH] SftpFile: disallow concurrent reads/writes. (#171) --- src/Common/SftpFile.cs | 79 +++++++++++++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 21 deletions(-) diff --git a/src/Common/SftpFile.cs b/src/Common/SftpFile.cs index 8d3d057..17220cd 100644 --- a/src/Common/SftpFile.cs +++ b/src/Common/SftpFile.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using System.IO; +using System.Diagnostics; namespace Tmds.Ssh; @@ -20,6 +21,8 @@ public sealed class SftpFile : Stream // The position is updated at the start of the operation to support concurrent requests. private long _position; + private int _inProgress; + internal SftpFile(SftpClient client, byte[] handle) { _client = client; @@ -68,17 +71,16 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation { ThrowIfDisposed(); - long readOffset = Interlocked.Add(ref _position, buffer.Length) - buffer.Length; - int bytesRead = 0; + SetInProgress(true); try { - bytesRead = await _client.ReadFileAsync(Handle, readOffset, buffer, cancellationToken).ConfigureAwait(false); - + int bytesRead = await _client.ReadFileAsync(Handle, _position, buffer, cancellationToken).ConfigureAwait(false); + _position += bytesRead; return bytesRead; } finally { - Interlocked.Add(ref _position, bytesRead - buffer.Length); + SetInProgress(false); } } @@ -99,15 +101,15 @@ public async override ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella { ThrowIfDisposed(); - long writeOffset = Interlocked.Add(ref _position, buffer.Length) - buffer.Length; + SetInProgress(true); try { - await _client.WriteFileAsync(Handle, writeOffset, buffer, cancellationToken).ConfigureAwait(false); + await _client.WriteFileAsync(Handle, _position, buffer, cancellationToken).ConfigureAwait(false); + _position += buffer.Length; } - catch + finally { - Interlocked.Add(ref _position, -buffer.Length); - throw; + SetInProgress(false); } } @@ -147,18 +149,32 @@ public async ValueTask SetAttributesAsync( { ThrowIfDisposed(); - await _client.SetAttributesForHandleAsync( - handle: Handle, - length: length, - ids: ids, - permissions: permissions, - times: times, - extendedAttributes: extendedAttributes, - cancellationToken).ConfigureAwait(false); - - if (_position > length) + if (length.HasValue) + { + SetInProgress(true); + } + try { - _position = length.Value; + await _client.SetAttributesForHandleAsync( + handle: Handle, + length: length, + ids: ids, + permissions: permissions, + times: times, + extendedAttributes: extendedAttributes, + cancellationToken).ConfigureAwait(false); + + if (_position > length) + { + _position = length.Value; + } + } + finally + { + if (length.HasValue) + { + SetInProgress(false); + } } } @@ -194,4 +210,25 @@ public override long Seek(long offset, SeekOrigin origin) public override void SetLength(long value) => throw new NotSupportedException(); + + private void SetInProgress(bool value) + { + if (value) + { + if (Interlocked.CompareExchange(ref _inProgress, 1, 0) != 0) + { + ThrowConcurrentOperations(); + } + } + else + { + Debug.Assert(_inProgress == 1); + Volatile.Write(ref _inProgress, 0); + } + } + + private void ThrowConcurrentOperations() + { + throw new InvalidOperationException("Concurrent read/write operations are not allowed."); + } }