From d4544e55b64a1373753e9965670d015470912d83 Mon Sep 17 00:00:00 2001 From: Matthew Kelly Date: Wed, 28 Aug 2024 22:19:17 +0100 Subject: [PATCH 1/3] #1291: rewrite of socket receiver to resolve NET8 SSL I/O changes --- examples/sse.ps1 | 4 +- examples/web-pages-https.ps1 | 5 + examples/web-pages.ps1 | 14 +-- src/Listener/Pode.csproj | 1 + src/Listener/PodeConnector.cs | 3 +- src/Listener/PodeContext.cs | 158 +++++++++---------------- src/Listener/PodeEndpoint.cs | 6 +- src/Listener/PodeFileWatcher.cs | 10 +- src/Listener/PodeForm.cs | 2 +- src/Listener/PodeFormData.cs | 10 +- src/Listener/PodeHelpers.cs | 37 +++++- src/Listener/PodeHttpRequest.cs | 161 +++++++++++++------------ src/Listener/PodeListener.cs | 22 ++-- src/Listener/PodeReceiver.cs | 12 +- src/Listener/PodeRequest.cs | 169 +++++++++++++++------------ src/Listener/PodeResponse.cs | 99 ++++++++-------- src/Listener/PodeSignalRequest.cs | 14 ++- src/Listener/PodeSmtpRequest.cs | 72 ++++++------ src/Listener/PodeSocket.cs | 156 +++++++------------------ src/Listener/PodeTcpRequest.cs | 31 ++--- src/Listener/PodeWatcher.cs | 4 +- src/Listener/PodeWebSocket.cs | 57 ++++----- src/Listener/PodeWebSocketRequest.cs | 4 +- src/Private/PodeServer.ps1 | 2 +- src/Public/Responses.ps1 | 2 +- src/Public/SSE.ps1 | 4 +- src/Public/WebSockets.ps1 | 8 +- 27 files changed, 500 insertions(+), 567 deletions(-) diff --git a/examples/sse.ps1 b/examples/sse.ps1 index 0fa57b0bd..564e97696 100644 --- a/examples/sse.ps1 +++ b/examples/sse.ps1 @@ -15,9 +15,9 @@ Start-PodeServer -Threads 3 { # open local sse connection, and send back data Add-PodeRoute -Method Get -Path '/data' -ScriptBlock { ConvertTo-PodeSseConnection -Name 'Data' -Scope Local - Send-PodeSseEvent -Id 1234 -EventType Action -Data 'hello, there!' + Send-PodeSseEvent -Id 1234 -EventType Action -Data 'hello, there!' -FromEvent Start-Sleep -Seconds 3 - Send-PodeSseEvent -Id 1337 -EventType BoldOne -Data 'general kenobi' + Send-PodeSseEvent -Id 1337 -EventType BoldOne -Data 'general kenobi' -FromEvent } # home page to get sse events diff --git a/examples/web-pages-https.ps1 b/examples/web-pages-https.ps1 index a62a40c28..5a61ef689 100644 --- a/examples/web-pages-https.ps1 +++ b/examples/web-pages-https.ps1 @@ -11,6 +11,7 @@ Import-Module "$($path)/src/Pode.psm1" -Force -ErrorAction Stop # create a server, flagged to generate a self-signed cert for dev/testing Start-PodeServer { + New-PodeLoggingMethod -Terminal | Enable-PodeErrorLogging -Levels Error # bind to ip/port and set as https with self-signed cert Add-PodeEndpoint -Address * -Port 8443 -Protocol Https -SelfSigned @@ -31,4 +32,8 @@ Start-PodeServer { Set-PodeResponseStatus -Code 500 } + Add-PodeRoute -Method 'GET' -Path '/test' -ScriptBlock { + Write-PodeTextResponse -Value (Get-Date) + } + } diff --git a/examples/web-pages.ps1 b/examples/web-pages.ps1 index 2d1dcb53b..76d7bf98d 100644 --- a/examples/web-pages.ps1 +++ b/examples/web-pages.ps1 @@ -10,22 +10,22 @@ Import-Module "$($path)/src/Pode.psm1" -Force -ErrorAction Stop # Import-Module Pode # create a server, and start listening on port 8085 -Start-PodeServer -Threads 2 -Verbose { +Start-PodeServer -Threads 1 -Verbose { # listen on localhost:8085 Add-PodeEndpoint -Address * -Port 8090 -Protocol Http -Name '8090Address' Add-PodeEndpoint -Address * -Port $Port -Protocol Http -Name '8085Address' -RedirectTo '8090Address' # allow the local ip and some other ips - Add-PodeAccessRule -Access Allow -Type IP -Values @('127.0.0.1', '[::1]') - Add-PodeAccessRule -Access Allow -Type IP -Values @('192.169.0.1', '192.168.0.2') + # Add-PodeAccessRule -Access Allow -Type IP -Values @('127.0.0.1', '[::1]') + # Add-PodeAccessRule -Access Allow -Type IP -Values @('192.169.0.1', '192.168.0.2') # deny an ip - Add-PodeAccessRule -Access Deny -Type IP -Values 10.10.10.10 - Add-PodeAccessRule -Access Deny -Type IP -Values '10.10.0.0/24' - Add-PodeAccessRule -Access Deny -Type IP -Values all + # Add-PodeAccessRule -Access Deny -Type IP -Values 10.10.10.10 + # Add-PodeAccessRule -Access Deny -Type IP -Values '10.10.0.0/24' + # Add-PodeAccessRule -Access Deny -Type IP -Values all # limit - Add-PodeLimitRule -Type IP -Values all -Limit 100 -Seconds 5 + # Add-PodeLimitRule -Type IP -Values all -Limit 100 -Seconds 5 # log requests to the terminal New-PodeLoggingMethod -Terminal -Batch 10 -BatchTimeout 10 | Enable-PodeRequestLogging diff --git a/src/Listener/Pode.csproj b/src/Listener/Pode.csproj index ad20a3158..ee7d5af90 100644 --- a/src/Listener/Pode.csproj +++ b/src/Listener/Pode.csproj @@ -2,5 +2,6 @@ netstandard2.0;net6.0;net8.0 $(NoWarn);SYSLIB0001 + 9.0 diff --git a/src/Listener/PodeConnector.cs b/src/Listener/PodeConnector.cs index be7a3d5a8..b75aff220 100644 --- a/src/Listener/PodeConnector.cs +++ b/src/Listener/PodeConnector.cs @@ -15,9 +15,8 @@ public class PodeConnector : IDisposable { CancellationToken = cancellationToken == default(CancellationToken) ? cancellationToken - : (new CancellationTokenSource()).Token; + : new CancellationTokenSource().Token; - // IsConnected = true; IsDisposed = false; } diff --git a/src/Listener/PodeContext.cs b/src/Listener/PodeContext.cs index a2de17fe5..66ed462a7 100644 --- a/src/Listener/PodeContext.cs +++ b/src/Listener/PodeContext.cs @@ -5,6 +5,7 @@ using System.Net.Sockets; using System.Security.Cryptography; using System.Threading; +using System.Threading.Tasks; namespace Pode { @@ -18,13 +19,9 @@ public class PodeContext : PodeProtocol, IDisposable public PodeSocket PodeSocket { get; private set; } public DateTime Timestamp { get; private set; } public Hashtable Data { get; private set; } + public string EndpointName => PodeSocket.Name; - public string EndpointName - { - get => PodeSocket.Name; - } - - private object _lockable = new object(); + private object _lockable = new(); private PodeContextState _state; public PodeContextState State @@ -39,73 +36,25 @@ private set } } - public bool CloseImmediately - { - get => (State == PodeContextState.Error + public bool CloseImmediately => State == PodeContextState.Error || State == PodeContextState.Closing || State == PodeContextState.Timeout - || Request.CloseImmediately); - } - - public new bool IsWebSocket - { - get => (base.IsWebSocket || (base.IsUnknown && PodeSocket.IsWebSocket)); - } - - public bool IsWebSocketUpgraded - { - get => (IsWebSocket && Request is PodeSignalRequest); - } - - public new bool IsSmtp - { - get => (base.IsSmtp || (base.IsUnknown && PodeSocket.IsSmtp)); - } - - public new bool IsHttp - { - get => (base.IsHttp || (base.IsUnknown && PodeSocket.IsHttp)); - } - - public PodeSmtpRequest SmtpRequest - { - get => (PodeSmtpRequest)Request; - } - - public PodeHttpRequest HttpRequest - { - get => (PodeHttpRequest)Request; - } + || Request.CloseImmediately; - public PodeSignalRequest SignalRequest - { - get => (PodeSignalRequest)Request; - } - - public bool IsKeepAlive - { - get => ((Request.IsKeepAlive && Response.SseScope != PodeSseScope.Local) || Response.SseScope == PodeSseScope.Global); - } - - public bool IsErrored - { - get => (State == PodeContextState.Error || State == PodeContextState.SslError); - } + public new bool IsWebSocket => base.IsWebSocket || (IsUnknown && PodeSocket.IsWebSocket); + public bool IsWebSocketUpgraded => IsWebSocket && Request is PodeSignalRequest; + public new bool IsSmtp => base.IsSmtp || (IsUnknown && PodeSocket.IsSmtp); + public new bool IsHttp => base.IsHttp || (IsUnknown && PodeSocket.IsHttp); - public bool IsTimeout - { - get => (State == PodeContextState.Timeout); - } + public PodeSmtpRequest SmtpRequest => (PodeSmtpRequest)Request; + public PodeHttpRequest HttpRequest => (PodeHttpRequest)Request; + public PodeSignalRequest SignalRequest => (PodeSignalRequest)Request; - public bool IsClosed - { - get => (State == PodeContextState.Closed); - } - - public bool IsOpened - { - get => (State == PodeContextState.Open); - } + public bool IsKeepAlive => (Request.IsKeepAlive && Response.SseScope != PodeSseScope.Local) || Response.SseScope == PodeSseScope.Global; + public bool IsErrored => State == PodeContextState.Error || State == PodeContextState.SslError; + public bool IsTimeout => State == PodeContextState.Timeout; + public bool IsClosed => State == PodeContextState.Closed; + public bool IsOpened => State == PodeContextState.Open; public CancellationTokenSource ContextTimeoutToken { get; private set; } private Timer TimeoutTimer; @@ -121,14 +70,17 @@ public PodeContext(Socket socket, PodeSocket podeSocket, PodeListener listener) Type = PodeProtocolType.Unknown; State = PodeContextState.New; + } + public async Task Initialise() + { NewResponse(); - NewRequest(); + await NewRequest().ConfigureAwait(false); } private void TimeoutCallback(object state) { - if (Response.SseEnabled) + if (Response.SseEnabled || Request.IsWebSocket) { return; } @@ -140,54 +92,51 @@ private void TimeoutCallback(object state) Request.Error = new HttpRequestException("Request timeout"); Request.Error.Data.Add("PodeStatusCode", 408); - this.Dispose(); + Dispose(); } private void NewResponse() { - Response = new PodeResponse(); - Response.SetContext(this); + Response = new PodeResponse(this); } - private void NewRequest() + private async Task NewRequest() { // create a new request switch (PodeSocket.Type) { case PodeProtocolType.Smtp: - Request = new PodeSmtpRequest(Socket, PodeSocket); + Request = new PodeSmtpRequest(Socket, PodeSocket, this); break; case PodeProtocolType.Tcp: - Request = new PodeTcpRequest(Socket, PodeSocket); + Request = new PodeTcpRequest(Socket, PodeSocket, this); break; default: - Request = new PodeHttpRequest(Socket, PodeSocket); + Request = new PodeHttpRequest(Socket, PodeSocket, this); break; } - Request.SetContext(this); - // attempt to open the request stream try { - Request.Open(); + await Request.Open(CancellationToken.None).ConfigureAwait(false); State = PodeContextState.Open; } catch (AggregateException aex) { PodeHelpers.HandleAggregateException(aex, Listener, PodeLoggingLevel.Debug, true); - State = (Request.InputStream == default(Stream) + State = Request.InputStream == default(Stream) ? PodeContextState.Error - : PodeContextState.SslError); + : PodeContextState.SslError; } catch (Exception ex) { PodeHelpers.WriteException(ex, Listener, PodeLoggingLevel.Debug); - State = (Request.InputStream == default(Stream) + State = Request.InputStream == default(Stream) ? PodeContextState.Error - : PodeContextState.SslError); + : PodeContextState.SslError; } // if request is SMTP or TCP, send ACK if available @@ -195,11 +144,11 @@ private void NewRequest() { if (PodeSocket.IsSmtp) { - SmtpRequest.SendAck(); + await SmtpRequest.SendAck().ConfigureAwait(false); } else if (PodeSocket.IsTcp && !string.IsNullOrWhiteSpace(PodeSocket.AcknowledgeMessage)) { - Response.WriteLine(PodeSocket.AcknowledgeMessage, true); + await Response.WriteLine(PodeSocket.AcknowledgeMessage, true).ConfigureAwait(false); } } } @@ -261,21 +210,17 @@ private void SetContextType() } } - public void RenewTimeoutToken() - { - ContextTimeoutToken = new CancellationTokenSource(); - } - public void CancelTimeout() { TimeoutTimer.Dispose(); } - public async void Receive() + public async Task Receive() { try { // start timeout + ContextTimeoutToken = new CancellationTokenSource(); TimeoutTimer = new Timer(TimeoutCallback, null, Listener.RequestTimeout * 1000, Timeout.Infinite); // start receiving @@ -283,9 +228,9 @@ public async void Receive() try { PodeHelpers.WriteErrorMessage($"Receiving request", Listener, PodeLoggingLevel.Verbose, this); - var close = await Request.Receive(ContextTimeoutToken.Token); + var close = await Request.Receive(ContextTimeoutToken.Token).ConfigureAwait(false); SetContextType(); - EndReceive(close); + await EndReceive(close).ConfigureAwait(false); } catch (OperationCanceledException) { } } @@ -293,11 +238,11 @@ public async void Receive() { PodeHelpers.WriteException(ex, Listener, PodeLoggingLevel.Debug); State = PodeContextState.Error; - PodeSocket.HandleContext(this); + await PodeSocket.HandleContext(this).ConfigureAwait(false); } } - public void EndReceive(bool close) + public async Task EndReceive(bool close) { State = close ? PodeContextState.Closing : PodeContextState.Received; if (close) @@ -305,7 +250,7 @@ public void EndReceive(bool close) Response.StatusCode = 400; } - PodeSocket.HandleContext(this); + await PodeSocket.HandleContext(this).ConfigureAwait(false); } public void StartReceive() @@ -316,7 +261,7 @@ public void StartReceive() PodeHelpers.WriteErrorMessage($"Socket listening", Listener, PodeLoggingLevel.Verbose, this); } - public void UpgradeWebSocket(string clientId = null) + public async Task UpgradeWebSocket(string clientId = null) { PodeHelpers.WriteErrorMessage($"Upgrading Websocket", Listener, PodeLoggingLevel.Verbose, this); @@ -355,7 +300,7 @@ public void UpgradeWebSocket(string clientId = null) } // send message to upgrade web socket - Response.Send(); + await Response.Send().ConfigureAwait(false); // add open web socket to listener var signal = new PodeSignal(this, HttpRequest.Url.AbsolutePath, clientId); @@ -373,10 +318,12 @@ public void Dispose(bool force) { lock (_lockable) { + PodeHelpers.WriteErrorMessage($"Disposing Context", Listener, PodeLoggingLevel.Verbose, this); Listener.RemoveProcessingContext(this); if (IsClosed) { + PodeSocket.RemovePendingSocket(Socket); Request.Dispose(); Response.Dispose(); ContextTimeoutToken.Dispose(); @@ -402,7 +349,7 @@ public void Dispose(bool force) // are we awaiting for more info? if (IsHttp) { - _awaitingBody = (HttpRequest.AwaitingBody && !IsErrored && !IsTimeout); + _awaitingBody = HttpRequest.AwaitingBody && !IsErrored && !IsTimeout; } // only send a response if Http @@ -410,11 +357,11 @@ public void Dispose(bool force) { if (IsTimeout) { - Response.SendTimeout(); + Response.SendTimeout().Wait(); } else { - Response.Send(); + Response.Send().Wait(); } } @@ -431,7 +378,7 @@ public void Dispose(bool force) if (Response.SseEnabled) { - Response.CloseSseConnection(); + Response.CloseSseConnection().Wait(); } Request.Dispose(); @@ -447,8 +394,13 @@ public void Dispose(bool force) // if keep-alive, or awaiting body, setup for re-receive if ((_awaitingBody || (IsKeepAlive && !IsErrored && !IsTimeout && !Response.SseEnabled)) && !force) { + PodeHelpers.WriteErrorMessage($"Re-receiving Request", Listener, PodeLoggingLevel.Verbose, this); StartReceive(); } + else + { + PodeSocket.RemovePendingSocket(Socket); + } } } } diff --git a/src/Listener/PodeEndpoint.cs b/src/Listener/PodeEndpoint.cs index 917dfb643..ceaf167d3 100644 --- a/src/Listener/PodeEndpoint.cs +++ b/src/Listener/PodeEndpoint.cs @@ -1,6 +1,8 @@ using System; using System.Net; using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; namespace Pode { @@ -52,7 +54,7 @@ public void Listen() Socket.Listen(int.MaxValue); } - public bool AcceptAsync(SocketAsyncEventArgs args) + public bool Accept(SocketAsyncEventArgs args) { if (IsDisposed) { @@ -66,7 +68,7 @@ public void Dispose() { IsDisposed = true; PodeSocket.CloseSocket(Socket); - Socket = default(Socket); + Socket = default; } public new bool Equals(object obj) diff --git a/src/Listener/PodeFileWatcher.cs b/src/Listener/PodeFileWatcher.cs index da1249cdb..941b78be6 100644 --- a/src/Listener/PodeFileWatcher.cs +++ b/src/Listener/PodeFileWatcher.cs @@ -16,10 +16,12 @@ public PodeFileWatcher(string name, string path, bool includeSubdirectories, int { Name = name; - FileWatcher = new RecoveringFileSystemWatcher(path); - FileWatcher.IncludeSubdirectories = includeSubdirectories; - FileWatcher.InternalBufferSize = internalBufferSize; - FileWatcher.NotifyFilter = notifyFilters; + FileWatcher = new RecoveringFileSystemWatcher(path) + { + IncludeSubdirectories = includeSubdirectories, + InternalBufferSize = internalBufferSize, + NotifyFilter = notifyFilters + }; EventsRegistered = new HashSet(); RegisterEvent(PodeFileWatcherChangeType.Errored); diff --git a/src/Listener/PodeForm.cs b/src/Listener/PodeForm.cs index 5bfe7e170..b740e06c6 100644 --- a/src/Listener/PodeForm.cs +++ b/src/Listener/PodeForm.cs @@ -202,7 +202,7 @@ private static bool IsLineBoundary(byte[] bytes, string boundary, Encoding conte return false; } - return (contentEncoding.GetString(bytes).StartsWith(boundary)); + return contentEncoding.GetString(bytes).StartsWith(boundary); } public static bool IsLineBoundary(string line, string boundary) diff --git a/src/Listener/PodeFormData.cs b/src/Listener/PodeFormData.cs index 09e1dcf76..487fe8dda 100644 --- a/src/Listener/PodeFormData.cs +++ b/src/Listener/PodeFormData.cs @@ -11,15 +11,17 @@ public class PodeFormData public string[] Values => _values.ToArray(); public int Count => _values.Count; - public bool IsSingular => (_values.Count == 1); - public bool IsEmpty => (_values.Count == 0); + public bool IsSingular => _values.Count == 1; + public bool IsEmpty => _values.Count == 0; public PodeFormData(string key, string value) { Key = key; - _values = new List(); - _values.Add(value); + _values = new List + { + value + }; } public void AddValue(string value) diff --git a/src/Listener/PodeHelpers.cs b/src/Listener/PodeHelpers.cs index bb68f6742..44a7e8f95 100644 --- a/src/Listener/PodeHelpers.cs +++ b/src/Listener/PodeHelpers.cs @@ -5,6 +5,8 @@ using System.Security.Cryptography; using System.Reflection; using System.Runtime.Versioning; +using System.Threading.Tasks; +using System.Threading; namespace Pode { @@ -13,12 +15,14 @@ public class PodeHelpers public static readonly string[] HTTP_METHODS = new string[] { "CONNECT", "DELETE", "GET", "HEAD", "MERGE", "OPTIONS", "PATCH", "POST", "PUT", "TRACE" }; public const string WEB_SOCKET_MAGIC_KEY = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; public readonly static char[] NEW_LINE_ARRAY = new char[] { '\r', '\n' }; + public readonly static char[] SPACE_ARRAY = new char[] { ' ' }; public const string NEW_LINE = "\r\n"; public const string NEW_LINE_UNIX = "\n"; public const int BYTE_SIZE = sizeof(byte); public const byte NEW_LINE_BYTE = 10; public const byte CARRIAGE_RETURN_BYTE = 13; public const byte DASH_BYTE = 45; + public const byte PERIOD_BYTE = 46; private static string _dotnet_version = string.Empty; private static bool _is_net_framework = false; @@ -26,7 +30,7 @@ public static bool IsNetFramework { get { - if (String.IsNullOrWhiteSpace(_dotnet_version)) + if (string.IsNullOrWhiteSpace(_dotnet_version)) { _dotnet_version = Assembly.GetEntryAssembly()?.GetCustomAttribute()?.FrameworkName ?? "Framework"; _is_net_framework = _dotnet_version.Equals("Framework", StringComparison.InvariantCultureIgnoreCase); @@ -71,7 +75,7 @@ public static bool IsNetFramework return true; } - PodeHelpers.WriteException(ex, connector, level); + WriteException(ex, connector, level); return false; }); } @@ -114,27 +118,50 @@ public static string NewGuid(int length = 16) { var bytes = new byte[length]; rnd.GetBytes(bytes); - return (new Guid(bytes)).ToString(); + return new Guid(bytes).ToString(); } } - public static void WriteTo(MemoryStream stream, byte[] array, int startIndex, int count = 0) + public static async Task WriteTo(MemoryStream stream, byte[] array, int startIndex, int count, CancellationToken cancellationToken) { + // Validate startIndex and count to avoid unnecessary work + if (startIndex < 0 || startIndex > array.Length) + { + throw new ArgumentOutOfRangeException(nameof(startIndex)); + } + if (count <= 0 || startIndex + count > array.Length) { count = array.Length - startIndex; } - stream.Write(array, startIndex, count); + // Perform the asynchronous write operation + if (count > 0) + { + await stream.WriteAsync(array, startIndex, count, cancellationToken).ConfigureAwait(false); + } } public static byte[] Slice(byte[] array, int startIndex, int count = 0) { + // Validate startIndex and adjust count if needed + if (startIndex < 0 || startIndex > array.Length) + { + throw new ArgumentOutOfRangeException(nameof(startIndex)); + } + + // If count is zero or less, or exceeds the array bounds, adjust it if (count <= 0 || startIndex + count > array.Length) { count = array.Length - startIndex; } + // If the count is zero, return an empty array + if (count == 0) + { + return Array.Empty(); + } + var newArray = new byte[count]; Buffer.BlockCopy(array, startIndex * BYTE_SIZE, newArray, 0, count * BYTE_SIZE); return newArray; diff --git a/src/Listener/PodeHttpRequest.cs b/src/Listener/PodeHttpRequest.cs index 54b6ffbd4..2d713eb67 100644 --- a/src/Listener/PodeHttpRequest.cs +++ b/src/Listener/PodeHttpRequest.cs @@ -5,10 +5,11 @@ using System.Net.Http; using System.Net.Sockets; using System.Text; -using System.Text.RegularExpressions; using System.Web; using System.Linq; using System.IO; +using System.Threading; +using System.Threading.Tasks; namespace Pode { @@ -58,17 +59,17 @@ public string Body public override bool CloseImmediately { - get => (string.IsNullOrWhiteSpace(HttpMethod) - || (IsWebSocket && !HttpMethod.Equals("GET", StringComparison.InvariantCultureIgnoreCase))); + get => string.IsNullOrWhiteSpace(HttpMethod) + || (IsWebSocket && !HttpMethod.Equals("GET", StringComparison.InvariantCultureIgnoreCase)); } public override bool IsProcessable { - get => (!CloseImmediately && !AwaitingBody); + get => !CloseImmediately && !AwaitingBody; } - public PodeHttpRequest(Socket socket, PodeSocket podeSocket) - : base(socket, podeSocket) + public PodeHttpRequest(Socket socket, PodeSocket podeSocket, PodeContext context) + : base(socket, podeSocket, context) { Protocol = "HTTP/1.1"; Type = PodeProtocolType.Http; @@ -85,12 +86,11 @@ protected override bool ValidateInput(byte[] bytes) // wait until we have the rest of the payload if (AwaitingBody) { - return (bytes.Length >= (ContentLength - BodyStream.Length)); + return bytes.Length >= (ContentLength - BodyStream.Length); } - var lf = (byte)10; var previousIndex = -1; - var index = Array.IndexOf(bytes, lf); + var index = Array.IndexOf(bytes, PodeHelpers.NEW_LINE_BYTE); // do we have a request line yet? if (index == -1) @@ -102,7 +102,8 @@ protected override bool ValidateInput(byte[] bytes) if (!IsRequestLineValid) { var reqLine = Encoding.GetString(bytes, 0, index).Trim(); - var reqMeta = Regex.Split(reqLine, "\\s+"); + var reqMeta = reqLine.Split(PodeHelpers.SPACE_ARRAY, StringSplitOptions.RemoveEmptyEntries); + if (reqMeta.Length != 3) { throw new HttpRequestException($"Invalid request line: {reqLine} [{reqMeta.Length}]"); @@ -115,27 +116,17 @@ protected override bool ValidateInput(byte[] bytes) while (true) { previousIndex = index; - index = Array.IndexOf(bytes, lf, index + 1); + index = Array.IndexOf(bytes, PodeHelpers.NEW_LINE_BYTE, index + 1); - if (index - previousIndex <= 2) - { - if (index - previousIndex == 1) - { - break; - } - - if (bytes[previousIndex + 1] == (byte)13) - { - break; - } - } - - if (index == bytes.Length - 1) + // If the difference between indexes indicates the end of headers, exit the loop + if (index == previousIndex + 1 || + (index > previousIndex + 1 && bytes[previousIndex + 1] == PodeHelpers.CARRIAGE_RETURN_BYTE)) { break; } - if (index == -1) + // Return false if LF not found and end of array is reached + if (index == -1 || index >= bytes.Length - 1) { return false; } @@ -146,7 +137,7 @@ protected override bool ValidateInput(byte[] bytes) return true; } - protected override bool Parse(byte[] bytes) + protected override async Task Parse(byte[] bytes, CancellationToken cancellationToken) { // if there are no bytes, return (0 bytes read means we can close the socket) if (bytes.Length == 0) @@ -168,14 +159,14 @@ protected override bool Parse(byte[] bytes) var reqLines = content.Split(new string[] { newline }, StringSplitOptions.None); content = string.Empty; - bodyIndex = ParseHeaders(reqLines, newline); + bodyIndex = ParseHeaders(reqLines); bodyIndex = reqLines.Take(bodyIndex).Sum(x => x.Length) + (bodyIndex * newline.Length); - reqLines = default(string[]); + reqLines = default; } // parse the body - ParseBody(bytes, newline, bodyIndex); - AwaitingBody = (ContentLength > 0 && BodyStream.Length < ContentLength && Error == default(HttpRequestException)); + await ParseBody(bytes, newline, bodyIndex, cancellationToken).ConfigureAwait(false); + AwaitingBody = ContentLength > 0 && BodyStream.Length < ContentLength && Error == default(HttpRequestException); if (!AwaitingBody) { @@ -184,21 +175,21 @@ protected override bool Parse(byte[] bytes) if (BodyStream != default(MemoryStream)) { BodyStream.Dispose(); - BodyStream = default(MemoryStream); + BodyStream = default; } } - return (!AwaitingBody); + return !AwaitingBody; } - private int ParseHeaders(string[] reqLines, string newline) + private int ParseHeaders(string[] reqLines) { // reset raw body - RawBody = default(byte[]); + RawBody = default; _body = string.Empty; // first line is method/url - var reqMeta = Regex.Split(reqLines[0].Trim(), "\\s+"); + var reqMeta = reqLines[0].Trim().Split(' '); if (reqMeta.Length != 3) { throw new HttpRequestException($"Invalid request line: {reqLines[0]} [{reqMeta.Length}]"); @@ -206,18 +197,18 @@ private int ParseHeaders(string[] reqLines, string newline) // http method HttpMethod = reqMeta[0].Trim(); - if (Array.IndexOf(PodeHelpers.HTTP_METHODS, HttpMethod) == -1) + if (!PodeHelpers.HTTP_METHODS.Contains(HttpMethod)) { throw new HttpRequestException($"Invalid request HTTP method: {HttpMethod}"); } // query string var reqQuery = reqMeta[1].Trim(); - var qmIndex = string.IsNullOrEmpty(reqQuery) ? 0 : reqQuery.IndexOf("?"); + var qmIndex = reqQuery.IndexOf("?"); QueryString = qmIndex > 0 - ? HttpUtility.ParseQueryString(reqQuery.Substring(qmIndex)) - : default(NameValueCollection); + ? HttpUtility.ParseQueryString(reqQuery.Substring(qmIndex + 1)) + : default; // http protocol version Protocol = (reqMeta[2] ?? "HTTP/1.1").Trim(); @@ -226,7 +217,7 @@ private int ParseHeaders(string[] reqLines, string newline) throw new HttpRequestException($"Invalid request version: {Protocol}"); } - ProtocolVersion = Regex.Split(Protocol, "/")[1]; + ProtocolVersion = Protocol.Split('/')[1]; // headers Headers = new Hashtable(StringComparer.InvariantCultureIgnoreCase); @@ -246,38 +237,41 @@ private int ParseHeaders(string[] reqLines, string newline) } h_index = h_line.IndexOf(":"); - h_name = h_line.Substring(0, h_index).Trim(); - h_value = h_line.Substring(h_index + 1).Trim(); - Headers.Add(h_name, h_value); + if (h_index > 0) + { + h_name = h_line.Substring(0, h_index).Trim(); + h_value = h_line.Substring(h_index + 1).Trim(); + Headers.Add(h_name, h_value); + } } // build required URI details - var _proto = (IsSsl ? "https" : "http"); - Host = $"{Headers["Host"]}"; - Url = new Uri($"{_proto}://{Host}{reqQuery}"); + var _proto = IsSsl ? "https" : "http"; + Host = Headers["Host"]?.ToString(); // check the host header - if (!Context.PodeSocket.CheckHostname(Host)) + if (string.IsNullOrWhiteSpace(Host) || !Context.PodeSocket.CheckHostname(Host)) { - throw new HttpRequestException($"Invalid request Host: {Host}"); + throw new HttpRequestException($"Invalid Host header: {Host}"); } + // build the URL + Url = new Uri($"{_proto}://{Host}{reqQuery}"); + // get the content length - var strContentLength = $"{Headers["Content-Length"]}"; - if (string.IsNullOrWhiteSpace(strContentLength)) + ContentLength = 0; + if (int.TryParse(Headers["Content-Length"]?.ToString(), out int _contentLength)) { - strContentLength = "0"; + ContentLength = _contentLength; } - ContentLength = int.Parse(strContentLength); - // set the transfer encoding - TransferEncoding = $"{Headers["Transfer-Encoding"]}"; + TransferEncoding = Headers["Transfer-Encoding"]?.ToString(); // set other default headers - UrlReferrer = $"{Headers["Referer"]}"; - UserAgent = $"{Headers["User-Agent"]}"; - ContentType = $"{Headers["Content-Type"]}"; + UrlReferrer = Headers["Referer"]?.ToString(); + UserAgent = Headers["User-Agent"]?.ToString(); + ContentType = Headers["Content-Type"]?.ToString(); // set content encoding ContentEncoding = System.Text.Encoding.UTF8; @@ -286,9 +280,9 @@ private int ParseHeaders(string[] reqLines, string newline) var atoms = ContentType.Split(';'); foreach (var atom in atoms) { - if (atom.Trim().ToLowerInvariant().StartsWith("charset")) + if (atom.Trim().StartsWith("charset", StringComparison.InvariantCultureIgnoreCase)) { - ContentEncoding = System.Text.Encoding.GetEncoding((atom.Split('=')[1].Trim())); + ContentEncoding = System.Text.Encoding.GetEncoding(atom.Split('=')[1].Trim()); break; } } @@ -301,30 +295,29 @@ private int ParseHeaders(string[] reqLines, string newline) } // do we have an SSE ClientId? - SseClientId = $"{Headers["X-Pode-Sse-Client-Id"]}"; + SseClientId = Headers["X-Pode-Sse-Client-Id"]?.ToString(); if (HasSseClientId) { - SseClientName = $"{Headers["X-Pode-Sse-Name"]}"; - SseClientGroup = $"{Headers["X-Pode-Sse-Group"]}"; + SseClientName = Headers["X-Pode-Sse-Name"]?.ToString(); + SseClientGroup = Headers["X-Pode-Sse-Group"]?.ToString(); } // keep-alive? - IsKeepAlive = (IsWebSocket || + IsKeepAlive = IsWebSocket || (Headers.ContainsKey("Connection") - && $"{Headers["Connection"]}".Equals("keep-alive", StringComparison.InvariantCultureIgnoreCase))); + && Headers["Connection"]?.ToString().Equals("keep-alive", StringComparison.InvariantCultureIgnoreCase) == true); // return index where body starts in req return bodyIndex; } - private void ParseBody(byte[] bytes, string newline, int start) + private async Task ParseBody(byte[] bytes, string newline, int start, CancellationToken cancellationToken) { - if (BodyStream == default(MemoryStream)) - { - BodyStream = new MemoryStream(); - } + // set the body stream + BodyStream ??= new MemoryStream(); - var isChunked = (!string.IsNullOrWhiteSpace(TransferEncoding) && TransferEncoding.Contains("chunked")); + // are we chunked? + var isChunked = !string.IsNullOrWhiteSpace(TransferEncoding) && TransferEncoding.Contains("chunked"); // if chunked, and we have a content-length, fail if (isChunked && ContentLength > 0) @@ -346,12 +339,7 @@ private void ParseBody(byte[] bytes, string newline, int start) // get index of newline char, read start>index bytes as HEX for length c_index = Array.IndexOf(bytes, (byte)newline[0], start); c_hexBytes = PodeHelpers.Slice(bytes, start, c_index - start); - - c_hex = string.Empty; - foreach (var b in c_hexBytes) - { - c_hex += (char)b; - } + c_hex = Encoding.GetString(c_hexBytes.ToArray()); // if no length, continue c_length = Convert.ToInt32(c_hex, 16); @@ -368,19 +356,19 @@ private void ParseBody(byte[] bytes, string newline, int start) start = (start + c_length - 1) + newline.Length + 1; } - PodeHelpers.WriteTo(BodyStream, c_rawBytes.ToArray(), 0, c_rawBytes.Count); + await PodeHelpers.WriteTo(BodyStream, c_rawBytes.ToArray(), 0, c_rawBytes.Count, cancellationToken).ConfigureAwait(false); } // else use content length else if (ContentLength > 0) { - PodeHelpers.WriteTo(BodyStream, bytes, start, ContentLength); + await PodeHelpers.WriteTo(BodyStream, bytes, start, ContentLength, cancellationToken).ConfigureAwait(false); } // else just read all else { - PodeHelpers.WriteTo(BodyStream, bytes, start); + await PodeHelpers.WriteTo(BodyStream, bytes, start, 0, cancellationToken).ConfigureAwait(false); } // check body size @@ -398,9 +386,20 @@ public void ParseFormData() Form = PodeForm.Parse(RawBody, ContentType, ContentEncoding); } + public override void PartialDispose() + { + if (BodyStream != default(MemoryStream)) + { + BodyStream.Dispose(); + BodyStream = default; + } + + base.PartialDispose(); + } + public override void Dispose() { - RawBody = default(byte[]); + RawBody = default; _body = string.Empty; if (BodyStream != default(MemoryStream)) diff --git a/src/Listener/PodeListener.cs b/src/Listener/PodeListener.cs index f5a1dc655..64d8c1a06 100644 --- a/src/Listener/PodeListener.cs +++ b/src/Listener/PodeListener.cs @@ -46,7 +46,7 @@ public bool ShowServerDetails } } - public PodeListener(CancellationToken cancellationToken = default(CancellationToken)) + public PodeListener(CancellationToken cancellationToken = default) : base(cancellationToken) { Sockets = new List(); @@ -77,12 +77,12 @@ private void Bind(PodeSocket socket) Sockets.Add(socket); } - public PodeContext GetContext(CancellationToken cancellationToken = default(CancellationToken)) + public PodeContext GetContext(CancellationToken cancellationToken = default) { return Contexts.Get(cancellationToken); } - public Task GetContextAsync(CancellationToken cancellationToken = default(CancellationToken)) + public Task GetContextAsync(CancellationToken cancellationToken = default) { return Contexts.GetAsync(cancellationToken); } @@ -137,7 +137,7 @@ public void AddSseConnection(PodeServerEvent sse) public void SendSseEvent(string name, string[] groups, string[] clientIds, string eventType, string data, string id = null) { - Task.Factory.StartNew(() => + Task.Run(async () => { if (!ServerEvents.ContainsKey(name)) { @@ -158,7 +158,7 @@ public void SendSseEvent(string name, string[] groups, string[] clientIds, strin if (ServerEvents[name][clientId].IsForGroup(groups)) { - ServerEvents[name][clientId].Context.Response.SendSseEvent(eventType, data, id); + await ServerEvents[name][clientId].Context.Response.SendSseEvent(eventType, data, id).ConfigureAwait(false); } } }, CancellationToken); @@ -166,7 +166,7 @@ public void SendSseEvent(string name, string[] groups, string[] clientIds, strin public void CloseSseConnection(string name, string[] groups, string[] clientIds) { - Task.Factory.StartNew(() => + Task.Run(async () => { if (!ServerEvents.ContainsKey(name)) { @@ -187,7 +187,7 @@ public void CloseSseConnection(string name, string[] groups, string[] clientIds) if (ServerEvents[name][clientId].IsForGroup(groups)) { - ServerEvents[name][clientId].Context.Response.CloseSseConnection(); + await ServerEvents[name][clientId].Context.Response.CloseSseConnection().ConfigureAwait(false); } } }, CancellationToken); @@ -211,12 +211,12 @@ public bool TestSseConnectionExists(string name, string clientId) return true; } - public PodeServerSignal GetServerSignal(CancellationToken cancellationToken = default(CancellationToken)) + public PodeServerSignal GetServerSignal(CancellationToken cancellationToken = default) { return ServerSignals.Get(cancellationToken); } - public Task GetServerSignalAsync(CancellationToken cancellationToken = default(CancellationToken)) + public Task GetServerSignalAsync(CancellationToken cancellationToken = default) { return ServerSignals.GetAsync(cancellationToken); } @@ -231,12 +231,12 @@ public void RemoveProcessingServerSignal(PodeServerSignal signal) ServerSignals.RemoveProcessing(signal); } - public PodeClientSignal GetClientSignal(CancellationToken cancellationToken = default(CancellationToken)) + public PodeClientSignal GetClientSignal(CancellationToken cancellationToken = default) { return ClientSignals.Get(cancellationToken); } - public Task GetClientSignalAsync(CancellationToken cancellationToken = default(CancellationToken)) + public Task GetClientSignalAsync(CancellationToken cancellationToken = default) { return ClientSignals.GetAsync(cancellationToken); } diff --git a/src/Listener/PodeReceiver.cs b/src/Listener/PodeReceiver.cs index c723e2e2c..21cff8f2c 100644 --- a/src/Listener/PodeReceiver.cs +++ b/src/Listener/PodeReceiver.cs @@ -20,8 +20,10 @@ public class PodeReceiver : PodeConnector Start(); } - public void ConnectWebSocket(string name, string url, string contentType) + public async Task ConnectWebSocket(string name, string url, string contentType) { + var socket = default(PodeWebSocket); + lock (WebSockets) { if (WebSockets.ContainsKey(name)) @@ -29,16 +31,16 @@ public void ConnectWebSocket(string name, string url, string contentType) throw new Exception($"WebSocket connection with name {name} already defined"); } - var socket = new PodeWebSocket(name, url, contentType); - socket.BindReceiver(this); - socket.Connect(); + socket = new PodeWebSocket(name, url, contentType, this); WebSockets.Add(name, socket); } + + await socket.Connect().ConfigureAwait(false); } public PodeWebSocket GetWebSocket(string name) { - return (WebSockets.ContainsKey(name) ? WebSockets[name] : default(PodeWebSocket)); + return WebSockets.ContainsKey(name) ? WebSockets[name] : default; } public void DisconnectWebSocket(string name) diff --git a/src/Listener/PodeRequest.cs b/src/Listener/PodeRequest.cs index e0633e3a7..bc3bfc07e 100644 --- a/src/Listener/PodeRequest.cs +++ b/src/Listener/PodeRequest.cs @@ -30,20 +30,14 @@ public class PodeRequest : PodeProtocol, IDisposable public SslPolicyErrors ClientCertificateErrors { get; set; } public SslProtocols Protocols { get; private set; } public HttpRequestException Error { get; set; } - public bool IsAborted => (Error != default(HttpRequestException)); + public bool IsAborted => Error != default(HttpRequestException); public bool IsDisposed { get; private set; } - public virtual string Address - { - get => (Context.PodeSocket.HasHostnames + public virtual string Address => Context.PodeSocket.HasHostnames ? $"{Context.PodeSocket.Hostname}:{((IPEndPoint)LocalEndPoint).Port}" - : $"{((IPEndPoint)LocalEndPoint).Address}:{((IPEndPoint)LocalEndPoint).Port}"); - } + : $"{((IPEndPoint)LocalEndPoint).Address}:{((IPEndPoint)LocalEndPoint).Port}"; - public virtual string Scheme - { - get => (SslUpgraded ? $"{Context.PodeSocket.Type}s" : $"{Context.PodeSocket.Type}"); - } + public virtual string Scheme => SslUpgraded ? $"{Context.PodeSocket.Type}s" : $"{Context.PodeSocket.Type}"; private Socket Socket; protected PodeContext Context; @@ -53,16 +47,17 @@ public virtual string Scheme private MemoryStream BufferStream; private const int BufferSize = 16384; - public PodeRequest(Socket socket, PodeSocket podeSocket) + public PodeRequest(Socket socket, PodeSocket podeSocket, PodeContext context) { Socket = socket; RemoteEndPoint = socket.RemoteEndPoint; LocalEndPoint = socket.LocalEndPoint; TlsMode = podeSocket.TlsMode; Certificate = podeSocket.Certificate; - IsSsl = (Certificate != default(X509Certificate)); + IsSsl = Certificate != default(X509Certificate); AllowClientCertificate = podeSocket.AllowClientCertificate; Protocols = podeSocket.Protocols; + Context = context; } public PodeRequest(PodeRequest request) @@ -81,7 +76,7 @@ public PodeRequest(PodeRequest request) TlsMode = request.TlsMode; } - public void Open() + public async Task Open(CancellationToken cancellationToken) { // open the socket's stream InputStream = new NetworkStream(Socket, true); @@ -92,20 +87,41 @@ public void Open() } // otherwise, convert the stream to an ssl stream - UpgradeToSSL(); + await UpgradeToSSL(cancellationToken).ConfigureAwait(false); } - public void UpgradeToSSL() + public async Task UpgradeToSSL(CancellationToken cancellationToken) { + // if we've already upgraded, return if (SslUpgraded) { return; } + // create the ssl stream var ssl = new SslStream(InputStream, false, new RemoteCertificateValidationCallback(ValidateCertificateCallback)); - ssl.AuthenticateAsServerAsync(Certificate, AllowClientCertificate, Protocols, false).Wait(Context.Listener.CancellationToken); - InputStream = ssl; - SslUpgraded = true; + + using (cancellationToken.Register(() => ssl.Dispose())) + { + try + { + // authenticate the stream + await ssl.AuthenticateAsServerAsync(Certificate, AllowClientCertificate, Protocols, false).ConfigureAwait(false); + + // if we've upgraded, set the stream + InputStream = ssl; + SslUpgraded = true; + } + catch (OperationCanceledException) { } + catch (IOException) { } + catch (ObjectDisposedException) { } + catch (Exception ex) + { + PodeHelpers.WriteException(ex, Context.Listener, PodeLoggingLevel.Error); + Error = new HttpRequestException(ex.Message, ex); + Error.Data.Add("PodeStatusCode", 502); + } + } } private bool ValidateCertificateCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) @@ -113,69 +129,71 @@ private bool ValidateCertificateCallback(object sender, X509Certificate certific ClientCertificateErrors = sslPolicyErrors; ClientCertificate = certificate == default(X509Certificate) - ? default(X509Certificate2) + ? default : new X509Certificate2(certificate); return true; } - protected async Task BeginRead(byte[] buffer, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - return await Task.Factory.FromAsync(InputStream.BeginRead, InputStream.EndRead, buffer, 0, BufferSize, null); - } - public async Task Receive(CancellationToken cancellationToken) { try { - Error = default(HttpRequestException); + Error = default; Buffer = new byte[BufferSize]; - BufferStream = new MemoryStream(); - - var read = 0; - var close = true; - - while ((read = await BeginRead(Buffer, cancellationToken)) > 0) + using (BufferStream = new MemoryStream()) { - cancellationToken.ThrowIfCancellationRequested(); - BufferStream.Write(Buffer, 0, read); + var close = true; - if (Socket.Available > 0 || !ValidateInput(BufferStream.ToArray())) + while (true) { - continue; - } - - if (!Parse(BufferStream.ToArray())) - { - BufferStream.Dispose(); - BufferStream = new MemoryStream(); - continue; + // read the input stream + var read = await InputStream.ReadAsync(Buffer, 0, BufferSize, cancellationToken).ConfigureAwait(false); + if (read <= 0) + { + break; + } + + // write the buffer to the stream + await BufferStream.WriteAsync(Buffer, 0, read, cancellationToken).ConfigureAwait(false); + + // if we have more data, or the input is invalid, continue + if (Socket.Available > 0 || !ValidateInput(BufferStream.ToArray())) + { + continue; + } + + // parse the buffer + if (!await Parse(BufferStream.ToArray(), cancellationToken).ConfigureAwait(false)) + { + BufferStream.SetLength(0); + continue; + } + + close = false; + break; } - close = false; - break; + return close; } - - cancellationToken.ThrowIfCancellationRequested(); - return close; } + catch (OperationCanceledException) { } + catch (IOException) { } catch (HttpRequestException httpex) { + PodeHelpers.WriteException(httpex, Context.Listener, PodeLoggingLevel.Error); Error = httpex; } catch (Exception ex) { - cancellationToken.ThrowIfCancellationRequested(); + PodeHelpers.WriteException(ex, Context.Listener, PodeLoggingLevel.Error); Error = new HttpRequestException(ex.Message, ex); Error.Data.Add("PodeStatusCode", 400); } finally { - BufferStream.Dispose(); - BufferStream = default(MemoryStream); - Buffer = default(byte[]); + PartialDispose(); } return false; @@ -184,16 +202,21 @@ public async Task Receive(CancellationToken cancellationToken) public async Task Read(byte[] checkBytes, CancellationToken cancellationToken) { var buffer = new byte[BufferSize]; - var bufferStream = new MemoryStream(); - - try + using (var bufferStream = new MemoryStream()) { - var read = 0; - while ((read = await BeginRead(buffer, cancellationToken)) > 0) + while (true) { - cancellationToken.ThrowIfCancellationRequested(); - bufferStream.Write(buffer, 0, read); + // read the input stream + var read = await InputStream.ReadAsync(buffer, 0, BufferSize, cancellationToken).ConfigureAwait(false); + if (read <= 0) + { + break; + } + + // write the buffer to the stream + await bufferStream.WriteAsync(buffer, 0, read, cancellationToken).ConfigureAwait(false); + // if we have more data, or the input is invalid, continue if (Socket.Available > 0 || !ValidateInputInternal(bufferStream.ToArray(), checkBytes)) { continue; @@ -202,15 +225,8 @@ public async Task Read(byte[] checkBytes, CancellationToken cancellation break; } - cancellationToken.ThrowIfCancellationRequested(); return Encoding.GetString(bufferStream.ToArray()).Trim(); } - finally - { - bufferStream.Dispose(); - bufferStream = default(MemoryStream); - buffer = default(byte[]); - } } private bool ValidateInputInternal(byte[] bytes, byte[] checkBytes) @@ -245,7 +261,7 @@ private bool ValidateInputInternal(byte[] bytes, byte[] checkBytes) return true; } - protected virtual bool Parse(byte[] bytes) + protected virtual Task Parse(byte[] bytes, CancellationToken cancellationToken) { throw new NotImplementedException(); } @@ -255,9 +271,15 @@ protected virtual bool ValidateInput(byte[] bytes) return true; } - public void SetContext(PodeContext context) + public virtual void PartialDispose() { - Context = context; + if (BufferStream != default(MemoryStream)) + { + BufferStream.Dispose(); + BufferStream = default; + } + + Buffer = default; } public virtual void Dispose() @@ -277,15 +299,10 @@ public virtual void Dispose() if (InputStream != default(Stream)) { InputStream.Dispose(); - InputStream = default(Stream); - } - - if (BufferStream != default(MemoryStream)) - { - BufferStream.Dispose(); - BufferStream = default(MemoryStream); + InputStream = default; } + PartialDispose(); PodeHelpers.WriteErrorMessage($"Request disposed", Context.Listener, PodeLoggingLevel.Verbose, Context); } } diff --git a/src/Listener/PodeResponse.cs b/src/Listener/PodeResponse.cs index 3c42a9865..786545994 100644 --- a/src/Listener/PodeResponse.cs +++ b/src/Listener/PodeResponse.cs @@ -5,6 +5,7 @@ using System.IO; using System.Net; using System.Text; +using System.Threading.Tasks; namespace Pode { @@ -77,13 +78,14 @@ public string HttpResponseLine private static UTF8Encoding Encoding = new UTF8Encoding(); - public PodeResponse() + public PodeResponse(PodeContext context) { Headers = new PodeResponseHeaders(); OutputStream = new MemoryStream(); + Context = context; } - public void Send() + public async Task Send() { if (Sent || IsDisposed || (SentHeaders && SseEnabled)) { @@ -94,12 +96,12 @@ public void Send() try { - SendHeaders(Context.IsTimeout); - SendBody(Context.IsTimeout); + await SendHeaders(Context.IsTimeout).ConfigureAwait(false); + await SendBody(Context.IsTimeout).ConfigureAwait(false); PodeHelpers.WriteErrorMessage($"Response sent", Context.Listener, PodeLoggingLevel.Verbose, Context); } - catch (OperationCanceledException) {} - catch (IOException) {} + catch (OperationCanceledException) { } + catch (IOException) { } catch (AggregateException aex) { PodeHelpers.HandleAggregateException(aex, Context.Listener); @@ -111,11 +113,11 @@ public void Send() } finally { - Flush(); + await Flush().ConfigureAwait(false); } } - public void SendTimeout() + public async Task SendTimeout() { if (SentHeaders || IsDisposed) { @@ -127,11 +129,11 @@ public void SendTimeout() try { - SendHeaders(true); + await SendHeaders(true).ConfigureAwait(false); PodeHelpers.WriteErrorMessage($"Response timed-out sent", Context.Listener, PodeLoggingLevel.Verbose, Context); } - catch (OperationCanceledException) {} - catch (IOException) {} + catch (OperationCanceledException) { } + catch (IOException) { } catch (AggregateException aex) { PodeHelpers.HandleAggregateException(aex, Context.Listener); @@ -143,11 +145,11 @@ public void SendTimeout() } finally { - Flush(); + await Flush().ConfigureAwait(false); } } - private void SendHeaders(bool timeout) + private async Task SendHeaders(bool timeout) { if (SentHeaders || !Request.InputStream.CanWrite) { @@ -164,12 +166,12 @@ private void SendHeaders(bool timeout) // stream response output var buffer = Encoding.GetBytes(BuildHeaders(Headers)); - Request.InputStream.WriteAsync(buffer, 0, buffer.Length).Wait(Context.Listener.CancellationToken); - buffer = default(byte[]); + await Request.InputStream.WriteAsync(buffer, 0, buffer.Length, Context.Listener.CancellationToken).ConfigureAwait(false); + buffer = default; SentHeaders = true; } - private void SendBody(bool timeout) + private async Task SendBody(bool timeout) { if (SentBody || SseEnabled || !Request.InputStream.CanWrite) { @@ -179,21 +181,21 @@ private void SendBody(bool timeout) // stream response output if (!timeout && OutputStream.Length > 0) { - OutputStream.WriteTo(Request.InputStream); + await Task.Run(() => OutputStream.WriteTo(Request.InputStream), Context.Listener.CancellationToken).ConfigureAwait(false); } SentBody = true; } - public void Flush() + public async Task Flush() { if (Request.InputStream.CanWrite) { - Request.InputStream.Flush(); + await Request.InputStream.FlushAsync().ConfigureAwait(false); } } - public string SetSseConnection(PodeSseScope scope, string clientId, string name, string group, int retry, bool allowAllOrigins) + public async Task SetSseConnection(PodeSseScope scope, string clientId, string name, string group, int retry, bool allowAllOrigins) { // do nothing for no scope if (scope == PodeSseScope.None) @@ -231,9 +233,9 @@ public string SetSseConnection(PodeSseScope scope, string clientId, string name, } // send headers, and open event - Send(); - SendSseRetry(retry); - SendSseEvent("pode.open", $"{{\"clientId\":\"{clientId}\",\"group\":\"{group}\",\"name\":\"{name}\"}}"); + await Send().ConfigureAwait(false); + await SendSseRetry(retry).ConfigureAwait(false); + await SendSseEvent("pode.open", $"{{\"clientId\":\"{clientId}\",\"group\":\"{group}\",\"name\":\"{name}\"}}").ConfigureAwait(false); // if global, cache connection in listener if (scope == PodeSseScope.Global) @@ -245,60 +247,60 @@ public string SetSseConnection(PodeSseScope scope, string clientId, string name, return clientId; } - public void CloseSseConnection() + public async Task CloseSseConnection() { - SendSseEvent("pode.close", string.Empty); + await SendSseEvent("pode.close", string.Empty).ConfigureAwait(false); } - public void SendSseEvent(string eventType, string data, string id = null) + public async Task SendSseEvent(string eventType, string data, string id = null) { if (!string.IsNullOrEmpty(id)) { - WriteLine($"id: {id}"); + await WriteLine($"id: {id}").ConfigureAwait(false); } if (!string.IsNullOrEmpty(eventType)) { - WriteLine($"event: {eventType}"); + await WriteLine($"event: {eventType}").ConfigureAwait(false); } - WriteLine($"data: {data}{PodeHelpers.NEW_LINE}", true); + await WriteLine($"data: {data}{PodeHelpers.NEW_LINE}", true).ConfigureAwait(false); } - public void SendSseRetry(int retry) + public async Task SendSseRetry(int retry) { if (retry <= 0) { return; } - WriteLine($"retry: {retry}", true); + await WriteLine($"retry: {retry}", true).ConfigureAwait(false); } - public void SendSignal(PodeServerSignal signal) + public async Task SendSignal(PodeServerSignal signal) { if (!string.IsNullOrEmpty(signal.Value)) { - Write(signal.Value); + await Write(signal.Value).ConfigureAwait(false); } } - public void Write(string message, bool flush = false) + public async Task Write(string message, bool flush = false) { // simple messages if (!Context.IsWebSocket) { - Write(Encoding.GetBytes(message), flush); + await Write(Encoding.GetBytes(message), flush).ConfigureAwait(false); } // web socket message else { - WriteFrame(message, PodeWsOpCode.Text, flush); + await WriteFrame(message, PodeWsOpCode.Text, flush).ConfigureAwait(false); } } - public void WriteFrame(string message, PodeWsOpCode opCode = PodeWsOpCode.Text, bool flush = false) + public async Task WriteFrame(string message, PodeWsOpCode opCode = PodeWsOpCode.Text, bool flush = false) { if (IsDisposed) { @@ -332,15 +334,15 @@ public void WriteFrame(string message, PodeWsOpCode opCode = PodeWsOpCode.Text, } buffer.AddRange(msgBytes); - Write(buffer.ToArray(), flush); + await Write(buffer.ToArray(), flush).ConfigureAwait(false); } - public void WriteLine(string message, bool flush = false) + public async Task WriteLine(string message, bool flush = false) { - Write(Encoding.GetBytes($"{message}{PodeHelpers.NEW_LINE}"), flush); + await Write(Encoding.GetBytes($"{message}{PodeHelpers.NEW_LINE}"), flush).ConfigureAwait(false); } - public void Write(byte[] buffer, bool flush = false) + public async Task Write(byte[] buffer, bool flush = false) { if (Request.IsDisposed || !Request.InputStream.CanWrite) { @@ -349,15 +351,15 @@ public void Write(byte[] buffer, bool flush = false) try { - Request.InputStream.WriteAsync(buffer, 0, buffer.Length).Wait(Context.Listener.CancellationToken); + await Request.InputStream.WriteAsync(buffer, 0, buffer.Length, Context.Listener.CancellationToken).ConfigureAwait(false); if (flush) { - Flush(); + await Flush().ConfigureAwait(false); } } - catch (OperationCanceledException) {} - catch (IOException) {} + catch (OperationCanceledException) { } + catch (IOException) { } catch (AggregateException aex) { PodeHelpers.HandleAggregateException(aex, Context.Listener); @@ -445,11 +447,6 @@ private string BuildHeaders(PodeResponseHeaders headers) return builder.ToString(); } - public void SetContext(PodeContext context) - { - Context = context; - } - public void Dispose() { if (IsDisposed) @@ -462,7 +459,7 @@ public void Dispose() if (OutputStream != default(MemoryStream)) { OutputStream.Dispose(); - OutputStream = default(MemoryStream); + OutputStream = default; } PodeHelpers.WriteErrorMessage($"Response disposed", Context.Listener, PodeLoggingLevel.Verbose, Context); diff --git a/src/Listener/PodeSignalRequest.cs b/src/Listener/PodeSignalRequest.cs index 903a7f35e..f7acc49cb 100644 --- a/src/Listener/PodeSignalRequest.cs +++ b/src/Listener/PodeSignalRequest.cs @@ -1,5 +1,7 @@ using System; using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; namespace Pode { @@ -27,12 +29,12 @@ public string CloseDescription public override bool CloseImmediately { - get => (OpCode == PodeWsOpCode.Close); + get => OpCode == PodeWsOpCode.Close; } public override bool IsProcessable { - get => (!CloseImmediately && OpCode != PodeWsOpCode.Pong && OpCode != PodeWsOpCode.Ping && !string.IsNullOrEmpty(Body)); + get => !CloseImmediately && OpCode != PodeWsOpCode.Pong && OpCode != PodeWsOpCode.Ping && !string.IsNullOrEmpty(Body); } public PodeSignalRequest(PodeHttpRequest request, PodeSignal signal) @@ -42,7 +44,7 @@ public PodeSignalRequest(PodeHttpRequest request, PodeSignal signal) IsKeepAlive = true; Type = PodeProtocolType.Ws; - var _proto = (IsSsl ? "wss" : "ws"); + var _proto = IsSsl ? "wss" : "ws"; Host = request.Host; Url = new Uri($"{_proto}://{request.Url.Authority}{request.Url.PathAndQuery}"); } @@ -52,7 +54,7 @@ public PodeClientSignal NewClientSignal() return new PodeClientSignal(Signal, Body, Context.Listener); } - protected override bool Parse(byte[] bytes) + protected override async Task Parse(byte[] bytes, CancellationToken cancellationToken) { // get the length and op-code var dataLength = bytes[1] - 128; @@ -118,7 +120,7 @@ protected override bool Parse(byte[] bytes) // send back a pong case PodeWsOpCode.Ping: - Context.Response.WriteFrame(string.Empty, PodeWsOpCode.Pong); + await Context.Response.WriteFrame(string.Empty, PodeWsOpCode.Pong).ConfigureAwait(false); break; } @@ -131,7 +133,7 @@ public override void Dispose() if (!IsDisposed) { PodeHelpers.WriteErrorMessage($"Closing Websocket", Context.Listener, PodeLoggingLevel.Verbose, Context); - Context.Response.WriteFrame(string.Empty, PodeWsOpCode.Close); + Context.Response.WriteFrame(string.Empty, PodeWsOpCode.Close).Wait(); } // remove client, and dispose diff --git a/src/Listener/PodeSmtpRequest.cs b/src/Listener/PodeSmtpRequest.cs index 8039d45f3..6bb1290d5 100644 --- a/src/Listener/PodeSmtpRequest.cs +++ b/src/Listener/PodeSmtpRequest.cs @@ -8,6 +8,8 @@ using System.Globalization; using _Encoding = System.Text.Encoding; using System.IO; +using System.Threading.Tasks; +using System.Threading; namespace Pode { @@ -29,17 +31,17 @@ public class PodeSmtpRequest : PodeRequest public override bool CloseImmediately { - get => (Command == PodeSmtpCommand.None || Command == PodeSmtpCommand.Quit); + get => Command == PodeSmtpCommand.None || Command == PodeSmtpCommand.Quit; } private bool _canProcess = false; public override bool IsProcessable { - get => (!CloseImmediately && _canProcess); + get => !CloseImmediately && _canProcess; } - public PodeSmtpRequest(Socket socket, PodeSocket podeSocket) - : base(socket, podeSocket) + public PodeSmtpRequest(Socket socket, PodeSocket podeSocket, PodeContext context) + : base(socket, podeSocket, context) { _canProcess = false; IsKeepAlive = true; @@ -58,13 +60,13 @@ private bool IsCommand(string content, string command) return content.StartsWith(command, true, CultureInfo.InvariantCulture); } - public void SendAck() + public async Task SendAck() { var ack = string.IsNullOrWhiteSpace(Context.PodeSocket.AcknowledgeMessage) ? $"{Context.PodeSocket.Hostname} -- Pode Proxy Server" : Context.PodeSocket.AcknowledgeMessage; - Context.Response.WriteLine($"220 {ack}", true); + await Context.Response.WriteLine($"220 {ack}", true).ConfigureAwait(false); } protected override bool ValidateInput(byte[] bytes) @@ -83,15 +85,15 @@ protected override bool ValidateInput(byte[] bytes) return false; } - return (bytes[bytes.Length - 3] == (byte)46 - && bytes[bytes.Length - 2] == (byte)13 - && bytes[bytes.Length - 1] == (byte)10); + return bytes[bytes.Length - 3] == PodeHelpers.PERIOD_BYTE + && bytes[bytes.Length - 2] == PodeHelpers.CARRIAGE_RETURN_BYTE + && bytes[bytes.Length - 1] == PodeHelpers.NEW_LINE_BYTE; } return true; } - protected override bool Parse(byte[] bytes) + protected override async Task Parse(byte[] bytes, CancellationToken cancellationToken) { // if there are no bytes, return (0 bytes read means we can close the socket) if (bytes.Length == 0) @@ -107,7 +109,7 @@ protected override bool Parse(byte[] bytes) if (string.IsNullOrWhiteSpace(content)) { Command = PodeSmtpCommand.None; - Context.Response.WriteLine("501 Invalid command received", true); + await Context.Response.WriteLine("501 Invalid command received", true).ConfigureAwait(false); return true; } @@ -115,7 +117,7 @@ protected override bool Parse(byte[] bytes) if (IsCommand(content, "QUIT")) { Command = PodeSmtpCommand.Quit; - Context.Response.WriteLine("221 OK", true); + await Context.Response.WriteLine("221 OK", true).ConfigureAwait(false); return true; } @@ -123,7 +125,7 @@ protected override bool Parse(byte[] bytes) if (StartType == PodeSmtpStartType.Ehlo && TlsMode == PodeTlsMode.Explicit && !SslUpgraded && !IsCommand(content, "STARTTLS")) { Command = PodeSmtpCommand.None; - Context.Response.WriteLine("530 Must issue a STARTTLS command first", true); + await Context.Response.WriteLine("530 Must issue a STARTTLS command first", true).ConfigureAwait(false); return true; } @@ -132,7 +134,7 @@ protected override bool Parse(byte[] bytes) { Command = PodeSmtpCommand.Helo; StartType = PodeSmtpStartType.Helo; - Context.Response.WriteLine("250 OK", true); + await Context.Response.WriteLine("250 OK", true).ConfigureAwait(false); return true; } @@ -141,14 +143,14 @@ protected override bool Parse(byte[] bytes) { Command = PodeSmtpCommand.Ehlo; StartType = PodeSmtpStartType.Ehlo; - Context.Response.WriteLine($"250-{Context.PodeSocket.Hostname} hello there", true); + await Context.Response.WriteLine($"250-{Context.PodeSocket.Hostname} hello there", true).ConfigureAwait(false); if (TlsMode == PodeTlsMode.Explicit && !SslUpgraded) { - Context.Response.WriteLine("250-STARTTLS", true); + await Context.Response.WriteLine("250-STARTTLS", true).ConfigureAwait(false); } - Context.Response.WriteLine("250 OK", true); + await Context.Response.WriteLine("250 OK", true).ConfigureAwait(false); return true; } @@ -158,14 +160,14 @@ protected override bool Parse(byte[] bytes) if (TlsMode != PodeTlsMode.Explicit) { Command = PodeSmtpCommand.None; - Context.Response.WriteLine("501 SMTP server not running on Explicit TLS for the STARTTLS command", true); + await Context.Response.WriteLine("501 SMTP server not running on Explicit TLS for the STARTTLS command", true).ConfigureAwait(false); return true; } Reset(); Command = PodeSmtpCommand.StartTls; - Context.Response.WriteLine("220 Ready to start TLS"); - UpgradeToSSL(); + await Context.Response.WriteLine("220 Ready to start TLS").ConfigureAwait(false); + await UpgradeToSSL(cancellationToken).ConfigureAwait(false); return true; } @@ -174,7 +176,7 @@ protected override bool Parse(byte[] bytes) { Reset(); Command = PodeSmtpCommand.Reset; - Context.Response.WriteLine("250 OK", true); + await Context.Response.WriteLine("250 OK", true).ConfigureAwait(false); return true; } @@ -182,7 +184,7 @@ protected override bool Parse(byte[] bytes) if (IsCommand(content, "NOOP")) { Command = PodeSmtpCommand.NoOp; - Context.Response.WriteLine("250 OK", true); + await Context.Response.WriteLine("250 OK", true).ConfigureAwait(false); return true; } @@ -190,7 +192,7 @@ protected override bool Parse(byte[] bytes) if (IsCommand(content, "RCPT TO")) { Command = PodeSmtpCommand.RcptTo; - Context.Response.WriteLine("250 OK", true); + await Context.Response.WriteLine("250 OK", true).ConfigureAwait(false); To.Add(ParseEmail(content)); return true; } @@ -199,7 +201,7 @@ protected override bool Parse(byte[] bytes) if (IsCommand(content, "MAIL FROM")) { Command = PodeSmtpCommand.MailFrom; - Context.Response.WriteLine("250 OK", true); + await Context.Response.WriteLine("250 OK", true).ConfigureAwait(false); From = ParseEmail(content); return true; } @@ -208,7 +210,7 @@ protected override bool Parse(byte[] bytes) if (IsCommand(content, "DATA")) { Command = PodeSmtpCommand.Data; - Context.Response.WriteLine("354 Start mail input; end with .", true); + await Context.Response.WriteLine("354 Start mail input; end with .", true).ConfigureAwait(false); return true; } @@ -217,17 +219,17 @@ protected override bool Parse(byte[] bytes) { case PodeSmtpCommand.Data: _canProcess = true; - Context.Response.WriteLine("250 OK", true); + await Context.Response.WriteLine("250 OK", true).ConfigureAwait(false); RawBody = bytes; Attachments = new List(); // parse the headers Headers = ParseHeaders(content); - Subject = $"{Headers["Subject"]}"; - IsUrgent = ($"{Headers["Priority"]}".Equals("urgent", StringComparison.InvariantCultureIgnoreCase) || $"{Headers["Importance"]}".Equals("high", StringComparison.InvariantCultureIgnoreCase)); - ContentEncoding = $"{Headers["Content-Transfer-Encoding"]}"; + Subject = Headers["Subject"]?.ToString(); + IsUrgent = $"{Headers["Priority"]}".Equals("urgent", StringComparison.InvariantCultureIgnoreCase) || $"{Headers["Importance"]}".Equals("high", StringComparison.InvariantCultureIgnoreCase); + ContentEncoding = Headers["Content-Transfer-Encoding"]?.ToString(); - ContentType = $"{Headers["Content-Type"]}"; + ContentType = Headers["Content-Type"]?.ToString(); if (!string.IsNullOrEmpty(Boundary) && !ContentType.Contains("boundary=")) { ContentType = ContentType.TrimEnd(';'); @@ -249,7 +251,7 @@ protected override bool Parse(byte[] bytes) else { Command = PodeSmtpCommand.None; - Context.Response.WriteLine("501 Invalid DATA received", true); + await Context.Response.WriteLine("501 Invalid DATA received", true).ConfigureAwait(false); return true; } break; @@ -270,7 +272,7 @@ public void Reset() From = string.Empty; To = new List(); Body = string.Empty; - RawBody = default(byte[]); + RawBody = default; Command = PodeSmtpCommand.None; ContentType = string.Empty; ContentEncoding = string.Empty; @@ -365,7 +367,7 @@ private Hashtable ParseHeaders(string value) private bool IsBodyValid(string value) { var lines = value.Split(new string[] { PodeHelpers.NEW_LINE }, StringSplitOptions.None); - return (Array.LastIndexOf(lines, ".") > -1); + return Array.LastIndexOf(lines, ".") > -1; } private void ParseBoundary() @@ -464,7 +466,7 @@ private byte[] ConvertBodyEncoding(string body, string contentEncoding) var match = default(Match); while ((match = Regex.Match(body, "(?=(?[0-9A-F]{2}))")).Success) { - body = (body.Replace(match.Groups["code"].Value, $"{(char)Convert.ToInt32(match.Groups["hex"].Value, 16)}")); + body = body.Replace(match.Groups["code"].Value, $"{(char)Convert.ToInt32(match.Groups["hex"].Value, 16)}"); } return _Encoding.UTF8.GetBytes(body); @@ -516,7 +518,7 @@ private string ConvertBodyType(byte[] bytes, string contentType) public override void Dispose() { - RawBody = default(byte[]); + RawBody = default; Body = string.Empty; if (Attachments != default(List)) diff --git a/src/Listener/PodeSocket.cs b/src/Listener/PodeSocket.cs index 8f8b5fd7d..7c337c9e3 100644 --- a/src/Listener/PodeSocket.cs +++ b/src/Listener/PodeSocket.cs @@ -25,15 +25,11 @@ public class PodeSocket : PodeProtocol, IDisposable public bool DualMode { get; private set; } private ConcurrentQueue AcceptConnections; - private ConcurrentQueue ReceiveConnections; private IDictionary PendingSockets; private PodeListener Listener; - public bool IsSsl - { - get => Certificate != default(X509Certificate); - } + public bool IsSsl => Certificate != default(X509Certificate); private int _receiveTimeout; public int ReceiveTimeout @@ -50,11 +46,7 @@ public int ReceiveTimeout } public bool HasHostnames => Hostnames.Any(); - - public string Hostname - { - get => HasHostnames ? Hostnames[0] : Endpoints[0].IPAddress.ToString(); - } + public string Hostname => HasHostnames ? Hostnames[0] : Endpoints[0].IPAddress.ToString(); public PodeSocket(string name, IPAddress[] ipAddress, int port, SslProtocols protocols, PodeProtocolType type, X509Certificate certificate = null, bool allowClientCertificate = false, PodeTlsMode tlsMode = PodeTlsMode.Implicit, bool dualMode = false) : base(type) @@ -68,7 +60,6 @@ public PodeSocket(string name, IPAddress[] ipAddress, int port, SslProtocols pro DualMode = dualMode; AcceptConnections = new ConcurrentQueue(); - ReceiveConnections = new ConcurrentQueue(); PendingSockets = new Dictionary(); Endpoints = new List(); @@ -95,13 +86,13 @@ public void Start() { foreach (var ep in Endpoints) { - StartEndpoint(ep); + _ = Task.Run(() => StartEndpoint(ep), Listener.CancellationToken); } } private void StartEndpoint(PodeEndpoint endpoint) { - if (endpoint.IsDisposed) + if (endpoint.IsDisposed || Listener.CancellationToken.IsCancellationRequested) { return; } @@ -117,7 +108,7 @@ private void StartEndpoint(PodeEndpoint endpoint) try { - raised = endpoint.AcceptAsync(args); + raised = endpoint.Accept(args); } catch (ObjectDisposedException) { @@ -130,49 +121,46 @@ private void StartEndpoint(PodeEndpoint endpoint) } } - private void StartReceive(Socket acceptedSocket) + private async Task StartReceive(Socket acceptedSocket) { + // add the socket to pending + AddPendingSocket(acceptedSocket); + + // create the context var context = new PodeContext(acceptedSocket, this, Listener); + PodeHelpers.WriteErrorMessage($"Opening Receive", Listener, PodeLoggingLevel.Verbose, context); + + // initialise the context + await context.Initialise().ConfigureAwait(false); if (context.IsErrored) { context.Dispose(true); return; } + // start receiving data StartReceive(context); } public void StartReceive(PodeContext context) { - var args = GetReceiveConnection(); - args.AcceptSocket = context.Socket; - args.UserToken = context; - StartReceive(args); - } - - private void StartReceive(SocketAsyncEventArgs args) - { - args.SetBuffer(new byte[0], 0, 0); - bool raised; + PodeHelpers.WriteErrorMessage($"Starting Receive", Listener, PodeLoggingLevel.Verbose, context); try { - AddPendingSocket(args.AcceptSocket); - raised = args.AcceptSocket.ReceiveAsync(args); + _ = Task.Run(async () => await context.Receive().ConfigureAwait(false), Listener.CancellationToken); } - catch (ObjectDisposedException) + catch (OperationCanceledException) { } + catch (IOException) { } + catch (AggregateException aex) { - return; + PodeHelpers.HandleAggregateException(aex, Listener, PodeLoggingLevel.Error, true); + context.Socket.Close(); } catch (Exception ex) { PodeHelpers.WriteException(ex, Listener); - throw; - } - - if (!raised) - { - ProcessReceive(args); + context.Socket.Close(); } } @@ -205,69 +193,28 @@ private void ProcessAccept(SocketAsyncEventArgs args) else { // start receive - StartReceive(args.AcceptSocket); - } - - // add args back to connections - ClearSocketAsyncEvent(args); - AcceptConnections.Enqueue(args); - } - - private void ProcessReceive(SocketAsyncEventArgs args) - { - // get details - var received = args.AcceptSocket; - var context = (PodeContext)args.UserToken; - var error = args.SocketError; - - // remove the socket from pending - RemovePendingSocket(received); - - // close socket if not successful, or if listener is stopped - close now! - if ((received == default(Socket)) || (error != SocketError.Success) || (!Listener.IsConnected)) - { - if (error != SocketError.Success) + try { - PodeHelpers.WriteErrorMessage($"Closing receiving socket: {error}", Listener, PodeLoggingLevel.Debug); + _ = Task.Run(async () => await StartReceive(accepted), Listener.CancellationToken).ConfigureAwait(false); } - - // close socket - if (received != default(Socket)) + catch (OperationCanceledException) { } + catch (IOException) { } + catch (AggregateException aex) { - received.Close(); + PodeHelpers.HandleAggregateException(aex, Listener, PodeLoggingLevel.Error, true); + } + catch (Exception ex) + { + PodeHelpers.WriteException(ex, Listener); } - - // close the context - context.Dispose(true); - - // add args back to connections - ClearSocketAsyncEvent(args); - ReceiveConnections.Enqueue(args); - return; - } - - try - { - context.RenewTimeoutToken(); - Task.Factory.StartNew(() => context.Receive(), context.ContextTimeoutToken.Token); - } - catch (OperationCanceledException) { } - catch (IOException) { } - catch (AggregateException aex) - { - PodeHelpers.HandleAggregateException(aex, Listener, PodeLoggingLevel.Error, true); - } - catch (Exception ex) - { - PodeHelpers.WriteException(ex, Listener); } // add args back to connections ClearSocketAsyncEvent(args); - ReceiveConnections.Enqueue(args); + AcceptConnections.Enqueue(args); } - public void HandleContext(PodeContext context) + public async Task HandleContext(PodeContext context) { try { @@ -287,7 +234,7 @@ public void HandleContext(PodeContext context) { if (!context.IsWebSocketUpgraded) { - context.UpgradeWebSocket(); + await context.UpgradeWebSocket().ConfigureAwait(false); process = false; context.Dispose(); } @@ -350,36 +297,11 @@ private SocketAsyncEventArgs NewAcceptConnection() } } - private SocketAsyncEventArgs NewReceiveConnection() - { - lock (ReceiveConnections) - { - var args = new SocketAsyncEventArgs(); - args.Completed += new EventHandler(Receive_Completed); - return args; - } - } - - private SocketAsyncEventArgs GetReceiveConnection() - { - if (!ReceiveConnections.TryDequeue(out SocketAsyncEventArgs args)) - { - args = NewReceiveConnection(); - } - - return args; - } - private void Accept_Completed(object sender, SocketAsyncEventArgs e) { ProcessAccept(e); } - private void Receive_Completed(object sender, SocketAsyncEventArgs e) - { - ProcessReceive(e); - } - private void AddPendingSocket(Socket socket) { lock (PendingSockets) @@ -392,7 +314,7 @@ private void AddPendingSocket(Socket socket) } } - private void RemovePendingSocket(Socket socket) + public void RemovePendingSocket(Socket socket) { lock (PendingSockets) { @@ -468,8 +390,8 @@ public static void CloseSocket(Socket socket) private void ClearSocketAsyncEvent(SocketAsyncEventArgs e) { - e.AcceptSocket = default(Socket); - e.UserToken = default(object); + e.AcceptSocket = default; + e.UserToken = default; } public new bool Equals(object obj) diff --git a/src/Listener/PodeTcpRequest.cs b/src/Listener/PodeTcpRequest.cs index af2a60aea..c48982538 100644 --- a/src/Listener/PodeTcpRequest.cs +++ b/src/Listener/PodeTcpRequest.cs @@ -1,4 +1,6 @@ using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; namespace Pode { @@ -22,11 +24,11 @@ public string Body public override bool CloseImmediately { - get => (IsDisposed || RawBody == default(byte[]) || RawBody.Length == 0); + get => IsDisposed || RawBody == default(byte[]) || RawBody.Length == 0; } - public PodeTcpRequest(Socket socket, PodeSocket podeSocket) - : base(socket, podeSocket) + public PodeTcpRequest(Socket socket, PodeSocket podeSocket, PodeContext context) + : base(socket, podeSocket, context) { IsKeepAlive = true; Type = PodeProtocolType.Tcp; @@ -43,31 +45,30 @@ protected override bool ValidateInput(byte[] bytes) // expect to end with ? if (Context.PodeSocket.CRLFMessageEnd) { - return (bytes[bytes.Length - 2] == (byte)13 - && bytes[bytes.Length - 1] == (byte)10); + return bytes[bytes.Length - 2] == PodeHelpers.CARRIAGE_RETURN_BYTE + && bytes[bytes.Length - 1] == PodeHelpers.NEW_LINE_BYTE; } return true; } - protected override bool Parse(byte[] bytes) + protected override Task Parse(byte[] bytes, CancellationToken cancellationToken) { - RawBody = bytes; + // check if the request is cancelled + cancellationToken.ThrowIfCancellationRequested(); - // if there are no bytes, return (0 bytes read means we can close the socket) - if (bytes.Length == 0) - { - return true; - } + // set the raw body + RawBody = bytes; - return true; + // return that we're done + return Task.FromResult(true); } public void Reset() { PodeHelpers.WriteErrorMessage($"Request reset", Context.Listener, PodeLoggingLevel.Verbose, Context); _body = string.Empty; - RawBody = default(byte[]); + RawBody = default; } public void Close() @@ -77,7 +78,7 @@ public void Close() public override void Dispose() { - RawBody = default(byte[]); + RawBody = default; _body = string.Empty; base.Dispose(); } diff --git a/src/Listener/PodeWatcher.cs b/src/Listener/PodeWatcher.cs index eebdec709..ed3134cb6 100644 --- a/src/Listener/PodeWatcher.cs +++ b/src/Listener/PodeWatcher.cs @@ -11,7 +11,7 @@ public class PodeWatcher : PodeConnector public PodeItemQueue FileEvents { get; private set; } - public PodeWatcher(CancellationToken cancellationToken = default(CancellationToken)) + public PodeWatcher(CancellationToken cancellationToken = default) : base(cancellationToken) { FileWatchers = new List(); @@ -24,7 +24,7 @@ public void AddFileWatcher(PodeFileWatcher watcher) FileWatchers.Add(watcher); } - public Task GetFileEventAsync(CancellationToken cancellationToken = default(CancellationToken)) + public Task GetFileEventAsync(CancellationToken cancellationToken = default) { return FileEvents.GetAsync(cancellationToken); } diff --git a/src/Listener/PodeWebSocket.cs b/src/Listener/PodeWebSocket.cs index 10d6a75f9..87cf4ed79 100644 --- a/src/Listener/PodeWebSocket.cs +++ b/src/Listener/PodeWebSocket.cs @@ -17,27 +17,23 @@ public class PodeWebSocket : IDisposable public string ContentType { get; private set; } public bool IsConnected { - get => (WebSocket != default(ClientWebSocket) && WebSocket.State == WebSocketState.Open); + get => WebSocket != default(ClientWebSocket) && WebSocket.State == WebSocketState.Open; } private ClientWebSocket WebSocket; - public PodeWebSocket(string name, string url, string contentType) + public PodeWebSocket(string name, string url, string contentType, PodeReceiver receiver) { Name = name; URL = new Uri(url); + Receiver = receiver; ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; } - public void BindReceiver(PodeReceiver receiver) - { - Receiver = receiver; - } - - public async void Connect() + public async Task Connect() { if (IsConnected) { @@ -46,29 +42,29 @@ public async void Connect() if (WebSocket != default(ClientWebSocket)) { - Disconnect(PodeWebSocketCloseFrom.Client); + await Disconnect(PodeWebSocketCloseFrom.Client).ConfigureAwait(false); WebSocket.Dispose(); } WebSocket = new ClientWebSocket(); WebSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(60); - await WebSocket.ConnectAsync(URL, Receiver.CancellationToken); - await Task.Factory.StartNew(Receive, Receiver.CancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); + await WebSocket.ConnectAsync(URL, Receiver.CancellationToken).ConfigureAwait(false); + await Task.Factory.StartNew(Receive, Receiver.CancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default).ConfigureAwait(false); } - public void Reconnect(string url) + public async Task Reconnect(string url) { if (!string.IsNullOrWhiteSpace(url)) { URL = new Uri(url); } - Disconnect(PodeWebSocketCloseFrom.Client); - Connect(); + await Disconnect(PodeWebSocketCloseFrom.Client).ConfigureAwait(false); + await Connect().ConfigureAwait(false); } - public async void Receive() + public async Task Receive() { var result = default(WebSocketReceiveResult); var buffer = _WebSocket.CreateClientBuffer(1024, 1024); @@ -80,7 +76,7 @@ public async void Receive() { do { - result = await WebSocket.ReceiveAsync(buffer, Receiver.CancellationToken); + result = await WebSocket.ReceiveAsync(buffer, Receiver.CancellationToken).ConfigureAwait(false); if (result.MessageType != WebSocketMessageType.Close) { bufferStream.Write(buffer.ToArray(), 0, result.Count); @@ -90,7 +86,7 @@ public async void Receive() if (result.MessageType == WebSocketMessageType.Close) { - Disconnect(PodeWebSocketCloseFrom.Server); + await Disconnect(PodeWebSocketCloseFrom.Server).ConfigureAwait(false); break; } @@ -105,7 +101,8 @@ public async void Receive() bufferStream = new MemoryStream(); } } - catch (TaskCanceledException) {} + catch (OperationCanceledException) { } + catch (IOException) { } catch (WebSocketException ex) { PodeHelpers.WriteException(ex, Receiver, PodeLoggingLevel.Debug); @@ -113,23 +110,27 @@ public async void Receive() } finally { - bufferStream.Dispose(); - bufferStream = default(MemoryStream); - buffer = default(ArraySegment); + if (bufferStream != default) + { + bufferStream.Dispose(); + bufferStream = default; + } + + buffer = default; } } - public void Send(string message, WebSocketMessageType type = WebSocketMessageType.Text) + public async Task Send(string message, WebSocketMessageType type = WebSocketMessageType.Text) { if (!IsConnected) { return; } - WebSocket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes(message)), type, true, Receiver.CancellationToken).Wait(); + await WebSocket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes(message)), type, true, Receiver.CancellationToken).ConfigureAwait(false); } - public void Disconnect(PodeWebSocketCloseFrom closeFrom) + public async Task Disconnect(PodeWebSocketCloseFrom closeFrom) { if (WebSocket == default(ClientWebSocket)) { @@ -143,26 +144,26 @@ public void Disconnect(PodeWebSocketCloseFrom closeFrom) // only close output in client closing if (closeFrom == PodeWebSocketCloseFrom.Client) { - WebSocket.CloseOutputAsync(WebSocketCloseStatus.Empty, string.Empty, CancellationToken.None).Wait(); + await WebSocket.CloseOutputAsync(WebSocketCloseStatus.Empty, string.Empty, CancellationToken.None).ConfigureAwait(false); } // if the server is closing, or client and netcore, then close properly if (closeFrom == PodeWebSocketCloseFrom.Server || !PodeHelpers.IsNetFramework) { - WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait(); + await WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); } PodeHelpers.WriteErrorMessage($"Closed client web socket: {Name}", Receiver, PodeLoggingLevel.Verbose); } WebSocket.Dispose(); - WebSocket = default(ClientWebSocket); + WebSocket = default; PodeHelpers.WriteErrorMessage($"Disconnected client web socket: {Name}", Receiver, PodeLoggingLevel.Verbose); } public void Dispose() { - Disconnect(PodeWebSocketCloseFrom.Client); + Disconnect(PodeWebSocketCloseFrom.Client).Wait(); } } } \ No newline at end of file diff --git a/src/Listener/PodeWebSocketRequest.cs b/src/Listener/PodeWebSocketRequest.cs index 7305d9666..39f623f29 100644 --- a/src/Listener/PodeWebSocketRequest.cs +++ b/src/Listener/PodeWebSocketRequest.cs @@ -13,7 +13,7 @@ public class PodeWebSocketRequest : IDisposable public int ContentLength { - get => (RawBody == default(byte[]) ? 0 : RawBody.Length); + get => RawBody == default(byte[]) ? 0 : RawBody.Length; } private string _body = string.Empty; @@ -39,7 +39,7 @@ public PodeWebSocketRequest(PodeWebSocket webSocket, MemoryStream bytes) public void Dispose() { WebSocket.Receiver.RemoveProcessingWebSocketRequest(this); - RawBody = default(byte[]); + RawBody = default; _body = string.Empty; } diff --git a/src/Private/PodeServer.ps1 b/src/Private/PodeServer.ps1 index 69dc7107c..d18adddf2 100644 --- a/src/Private/PodeServer.ps1 +++ b/src/Private/PodeServer.ps1 @@ -326,7 +326,7 @@ function Start-PodeWebServer { # send the message to all found sockets foreach ($socket in $sockets) { try { - $socket.Context.Response.SendSignal($message) + $null = Wait-PodeTask -Task $socket.Context.Response.SendSignal($message) } catch { $null = $Listener.Signals.Remove($socket.ClientId) diff --git a/src/Public/Responses.ps1 b/src/Public/Responses.ps1 index 0a3130e31..e46bb932a 100644 --- a/src/Public/Responses.ps1 +++ b/src/Public/Responses.ps1 @@ -1787,6 +1787,6 @@ function Send-PodeResponse { param() if ($null -ne $WebEvent.Response) { - $WebEvent.Response.Send() + $null = Wait-PodeTask -Task $WebEvent.Response.Send() } } \ No newline at end of file diff --git a/src/Public/SSE.ps1 b/src/Public/SSE.ps1 index 87ba03f80..07b708612 100644 --- a/src/Public/SSE.ps1 +++ b/src/Public/SSE.ps1 @@ -91,7 +91,7 @@ function ConvertTo-PodeSseConnection { $ClientId = New-PodeSseClientId -ClientId $ClientId # set and send SSE headers - $ClientId = $WebEvent.Response.SetSseConnection($Scope, $ClientId, $Name, $Group, $RetryDuration, $AllowAllOrigins.IsPresent) + $ClientId = Wait-PodeTask -Task $WebEvent.Response.SetSseConnection($Scope, $ClientId, $Name, $Group, $RetryDuration, $AllowAllOrigins.IsPresent) # create SSE property on WebEvent $WebEvent.Sse = @{ @@ -252,7 +252,7 @@ function Send-PodeSseEvent { # send directly back to current connection if ($FromEvent -and $WebEvent.Sse.IsLocal) { - $WebEvent.Response.SendSseEvent($EventType, $Data, $Id) + $null = Wait-PodeTask -Task $WebEvent.Response.SendSseEvent($EventType, $Data, $Id) return } diff --git a/src/Public/WebSockets.ps1 b/src/Public/WebSockets.ps1 index 9ea868852..e1d9f4958 100644 --- a/src/Public/WebSockets.ps1 +++ b/src/Public/WebSockets.ps1 @@ -133,7 +133,7 @@ function Connect-PodeWebSocket { # connect try { - $PodeContext.Server.WebSockets.Receiver.ConnectWebSocket($Name, $Url, $ContentType) + $null = Wait-PodeTask -Task $PodeContext.Server.WebSockets.Receiver.ConnectWebSocket($Name, $Url, $ContentType) } catch { # Failed to connect to websocket @@ -283,7 +283,7 @@ function Send-PodeWebSocket { $Message = ConvertTo-PodeResponseContent -InputObject $Message -ContentType $ws.ContentType -Depth $Depth # send message - $ws.Send($Message, $Type) + $null = Wait-PodeTask -Task $ws.Send($Message, $Type) } <# @@ -318,7 +318,7 @@ function Reset-PodeWebSocket { ) if ([string]::IsNullOrWhiteSpace($Name) -and ($null -ne $WsEvent)) { - $WsEvent.Request.WebSocket.Reconnect($Url) + $null = Wait-PodeTask -Task $WsEvent.Request.WebSocket.Reconnect($Url) return } @@ -328,7 +328,7 @@ function Reset-PodeWebSocket { } if (Test-PodeWebSocket -Name $Name) { - $PodeContext.Server.WebSockets.Receiver.GetWebSocket($Name).Reconnect($Url) + $null = Wait-PodeTask -Task $PodeContext.Server.WebSockets.Receiver.GetWebSocket($Name).Reconnect($Url) } } From 2942fe397cbba82354bf8f4fe6e89fc592d9e90a Mon Sep 17 00:00:00 2001 From: Matthew Kelly Date: Thu, 29 Aug 2024 09:35:35 +0100 Subject: [PATCH 2/3] #1291: set RestApi.Https.Tests to use both Invoke-RestMethod and curl, to help detect SSL issues in the future --- tests/integration/RestApi.Https.Tests.ps1 | 346 +++++++++++----------- 1 file changed, 165 insertions(+), 181 deletions(-) diff --git a/tests/integration/RestApi.Https.Tests.ps1 b/tests/integration/RestApi.Https.Tests.ps1 index 141b30c04..17dea15fe 100644 --- a/tests/integration/RestApi.Https.Tests.ps1 +++ b/tests/integration/RestApi.Https.Tests.ps1 @@ -5,46 +5,28 @@ param() Describe 'REST API Requests' { BeforeAll { $splatter = @{} - $UseCurl = $true $version = $PSVersionTable.PSVersion - if ( $version.Major -eq 5) { + + if ($version.Major -eq 5) { # Ignore SSL certificate validation errors Add-Type @' -using System.Net; -using System.Security.Cryptography.X509Certificates; -public class TrustAllCertsPolicy : ICertificatePolicy { -public bool CheckValidationResult( - ServicePoint srvPoint, X509Certificate certificate, - WebRequest request, int certificateProblem) { - return true; -} -} + using System.Net; + using System.Security.Cryptography.X509Certificates; + public class TrustAllCertsPolicy : ICertificatePolicy { + public bool CheckValidationResult( + ServicePoint srvPoint, X509Certificate certificate, + WebRequest request, int certificateProblem) { + return true; + } + } '@ [System.Net.ServicePointManager]::CertificatePolicy = New-Object TrustAllCertsPolicy - $UseCurl = $false - } - elseif ($PSVersionTable.OS -like '*Windows*') { - # OS check passed, now check PowerShell version - # Split version by '.' and compare major and minor version - if ( $version.Major -gt 7 -or ($version.Major -eq 7 -and $version.Minor -ge 4)) { - # Running on Windows with PowerShell Core 7.4 or greater. - $UseCurl = $true - } - else { - $UseCurl = $false - $splatter.SkipCertificateCheck = $true - # Running on Windows but with PowerShell version less than 7.4. - } - } else { - # Not running on Windows." - $UseCurl = $false $splatter.SkipCertificateCheck = $true } - $Port = 8080 $Endpoint = "https://127.0.0.1:$($Port)" @@ -136,139 +118,135 @@ public bool CheckValidationResult( AfterAll { Receive-Job -Name 'Pode' | Out-Default - if ($UseCurl) { - curl -s -X DELETE "$($Endpoint)/close" -k - } - else { - Invoke-RestMethod -Uri "$($Endpoint)/close" -Method Get @splatter | Out-Null - } + Invoke-RestMethod -Uri "$($Endpoint)/close" -Method Get @splatter | Out-Null Get-Job -Name 'Pode' | Remove-Job -Force } It 'responds back with pong' { - if ($UseCurl) { - $result = (curl -s -X GET "$($Endpoint)/ping" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/ping" -Method Get @splatter - } + # test curl + $result = (curl -s -X GET "$($Endpoint)/ping" -k) | ConvertFrom-Json + $result.Result | Should -Be 'Pong' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/ping" -Method Get @splatter $result.Result | Should -Be 'Pong' } It 'responds back with 404 for invalid route' { - if ($UseCurl) { - $status_code = (curl -s -o /dev/null -w '%{http_code}' "$Endpoint/eek" -k) - $status_code | Should -be 404 - } - else { - { Invoke-RestMethod -Uri "$($Endpoint)/eek" -Method Get -ErrorAction Stop @splatter } | Should -Throw -ExpectedMessage '*404*' - } + # test curl + $status_code = (curl -s -o /dev/null -w '%{http_code}' "$Endpoint/eek" -k) + $status_code | Should -be 404 + + # test Invoke-RestMethod + { Invoke-RestMethod -Uri "$($Endpoint)/eek" -Method Get -ErrorAction Stop @splatter } | Should -Throw -ExpectedMessage '*404*' } It 'responds back with 405 for incorrect method' { - if ($UseCurl) { - $status_code = (curl -X POST -s -o /dev/null -w '%{http_code}' "$Endpoint/ping" -k) - $status_code | Should -be 405 - } - else { - { Invoke-RestMethod -Uri "$($Endpoint)/ping" -Method Post -ErrorAction Stop @splatter } | Should -Throw -ExpectedMessage '*405*' - } + # test curl + $status_code = (curl -X POST -s -o /dev/null -w '%{http_code}' "$Endpoint/ping" -k) + $status_code | Should -be 405 + + # test Invoke-RestMethod + { Invoke-RestMethod -Uri "$($Endpoint)/ping" -Method Post -ErrorAction Stop @splatter } | Should -Throw -ExpectedMessage '*405*' } It 'responds with simple query parameter' { - if ($UseCurl) { - $result = (curl -s -X GET "$($Endpoint)/data/query?username=rick" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/data/query?username=rick" -Method Get @splatter - } + # test curl + $result = (curl -s -X GET "$($Endpoint)/data/query?username=rick" -k) | ConvertFrom-Json + $result.Username | Should -Be 'rick' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/data/query?username=rick" -Method Get @splatter $result.Username | Should -Be 'rick' } It 'responds with simple payload parameter - json' { - if ($UseCurl) { - $result = curl -s -X POST "$($Endpoint)/data/payload" -H 'Content-Type: application/json' -d '{"username":"rick"}' -k | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload" -Method Post -Body '{"username":"rick"}' -ContentType 'application/json' @splatter - } + # test curl + $result = curl -s -X POST "$($Endpoint)/data/payload" -H 'Content-Type: application/json' -d '{"username":"rick"}' -k | ConvertFrom-Json + $result.Username | Should -Be 'rick' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload" -Method Post -Body '{"username":"rick"}' -ContentType 'application/json' @splatter $result.Username | Should -Be 'rick' } It 'responds with simple payload parameter - xml' { - if ($UseCurl) { - $result = curl -s -X POST "$($Endpoint)/data/payload" -H 'Content-Type: application/xml' -d 'rick' -k | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload" -Method Post -Body 'rick' -ContentType 'application/xml' @splatter - } + # test curl + $result = curl -s -X POST "$($Endpoint)/data/payload" -H 'Content-Type: application/xml' -d 'rick' -k | ConvertFrom-Json + $result.Username | Should -Be 'rick' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload" -Method Post -Body 'rick' -ContentType 'application/xml' @splatter $result.Username | Should -Be 'rick' } It 'responds with simple payload parameter forced to json' { - if ($UseCurl) { - $result = curl -s -X POST "$($Endpoint)/data/payload-forced-type" -d '{"username":"rick"}' -k | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload-forced-type" -Method Post -Body '{"username":"rick"}' @splatter - } + # test curl + $result = curl -s -X POST "$($Endpoint)/data/payload-forced-type" -d '{"username":"rick"}' -k | ConvertFrom-Json + $result.Username | Should -Be 'rick' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload-forced-type" -Method Post -Body '{"username":"rick"}' @splatter $result.Username | Should -Be 'rick' } It 'responds with simple route parameter' { - if ($UseCurl) { - $result = (curl -s -X GET "$($Endpoint)/data/param/rick" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/data/param/rick" -Method Get @splatter - } + # test curl + $result = (curl -s -X GET "$($Endpoint)/data/param/rick" -k) | ConvertFrom-Json + $result.Username | Should -Be 'rick' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/data/param/rick" -Method Get @splatter $result.Username | Should -Be 'rick' } It 'responds with simple route parameter long' { - if ($UseCurl) { - $result = (curl -s -X GET "$($Endpoint)/data/param/rick/messages" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/data/param/rick/messages" -Method Get @splatter - } + # test curl + $result = (curl -s -X GET "$($Endpoint)/data/param/rick/messages" -k) | ConvertFrom-Json + $result.Messages[0] | Should -Be 'Hello, world!' + $result.Messages[1] | Should -Be 'Greetings' + $result.Messages[2] | Should -Be 'Wubba Lub' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/data/param/rick/messages" -Method Get @splatter $result.Messages[0] | Should -Be 'Hello, world!' $result.Messages[1] | Should -Be 'Greetings' $result.Messages[2] | Should -Be 'Wubba Lub' } It 'responds ok to remove account' { - if ($UseCurl) { - $result = (curl -s -X DELETE "$($Endpoint)/api/rick/remove" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/remove" -Method Delete @splatter - } + # test curl + $result = (curl -s -X DELETE "$($Endpoint)/api/rick/remove" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/remove" -Method Delete @splatter $result.Result | Should -Be 'OK' } It 'responds ok to replace account' { - if ($UseCurl) { - $result = (curl -s -X PUT "$($Endpoint)/api/rick/replace" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/replace" -Method Put @splatter - } + # test curl + $result = (curl -s -X PUT "$($Endpoint)/api/rick/replace" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/replace" -Method Put @splatter $result.Result | Should -Be 'OK' } It 'responds ok to update account' { - if ($UseCurl) { - $result = (curl -s -X PATCH "$($Endpoint)/api/rick/update" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/update" -Method Patch @splatter - } + # test curl + $result = (curl -s -X PATCH "$($Endpoint)/api/rick/update" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/update" -Method Patch @splatter $result.Result | Should -Be 'OK' } It 'decodes encoded payload parameter - gzip' { + # test curl $data = @{ username = 'rick' } $message = ($data | ConvertTo-Json) @@ -279,27 +257,29 @@ public bool CheckValidationResult( $gzip.Write($bytes, 0, $bytes.Length) $gzip.Close() - if ($UseCurl) { + try { + # get the compressed data + $ms.Position = 0 $compressedData = $ms.ToArray() - $ms.Dispose() + # Save the compressed data to a temporary file $tempFile = [System.IO.Path]::GetTempFileName() [System.IO.File]::WriteAllBytes($tempFile, $compressedData) + # make the request $result = curl -s -X POST "$Endpoint/encoding/transfer" -H 'Transfer-Encoding: gzip' -H 'Content-Type: application/json' --data-binary "@$tempFile" -k | ConvertFrom-Json # Cleanup the temporary file Remove-Item -Path $tempFile - } - else { + $result.Username | Should -Be 'rick' + # make the request - $ms.Position = 0 - $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer" -Method Post -Body $ms.ToArray() -Headers @{ 'Transfer-Encoding' = 'gzip' } -ContentType 'application/json' @splatter + $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer" -Method Post -Body $compressedData -Headers @{ 'Transfer-Encoding' = 'gzip' } -ContentType 'application/json' @splatter + $result.Username | Should -Be 'rick' + } + finally { $ms.Dispose() } - - $result.Username | Should -Be 'rick' - } It 'decodes encoded payload parameter - deflate' { @@ -309,13 +289,16 @@ public bool CheckValidationResult( # compress the message using deflate $bytes = [System.Text.Encoding]::UTF8.GetBytes($message) $ms = New-Object -TypeName System.IO.MemoryStream - $gzip = New-Object System.IO.Compression.DeflateStream($ms, [IO.Compression.CompressionMode]::Compress, $true) - $gzip.Write($bytes, 0, $bytes.Length) - $gzip.Close() - if ($UseCurl) { + $deflate = New-Object System.IO.Compression.DeflateStream($ms, [IO.Compression.CompressionMode]::Compress, $true) + $deflate.Write($bytes, 0, $bytes.Length) + $deflate.Close() + + try { + # get the compressed data + $ms.Position = 0 $compressedData = $ms.ToArray() - $ms.Dispose() + # test curl # Save the compressed data to a temporary file $tempFile = [System.IO.Path]::GetTempFileName() [System.IO.File]::WriteAllBytes($tempFile, $compressedData) @@ -325,15 +308,15 @@ public bool CheckValidationResult( # Cleanup the temporary file Remove-Item -Path $tempFile + $result.Username | Should -Be 'rick' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer" -Method Post -Body $compressedData -Headers @{ 'Transfer-Encoding' = 'deflate' } -ContentType 'application/json' @splatter + $result.Username | Should -Be 'rick' } - else { - # make the request - $ms.Position = 0 - $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer" -Method Post -Body $ms.ToArray() -Headers @{ 'Transfer-Encoding' = 'deflate' } -ContentType 'application/json' @splatter + finally { $ms.Dispose() } - - $result.Username | Should -Be 'rick' } It 'decodes encoded payload parameter forced to gzip' { @@ -346,93 +329,94 @@ public bool CheckValidationResult( $gzip = New-Object System.IO.Compression.GZipStream($ms, [IO.Compression.CompressionMode]::Compress, $true) $gzip.Write($bytes, 0, $bytes.Length) $gzip.Close() - if ($UseCurl) { + try { + # get the compressed data + $ms.Position = 0 $compressedData = $ms.ToArray() - $ms.Dispose() + # test curl # Save the compressed data to a temporary file $tempFile = [System.IO.Path]::GetTempFileName() [System.IO.File]::WriteAllBytes($tempFile, $compressedData) + # make the request $result = curl -s -X POST "$Endpoint/encoding/transfer-forced-type" -H 'Content-Type: application/json' --data-binary "@$tempFile" -k | ConvertFrom-Json # Cleanup the temporary file Remove-Item -Path $tempFile + $result.Username | Should -Be 'rick' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer-forced-type" -Method Post -Body $compressedData -ContentType 'application/json' @splatter + $result.Username | Should -Be 'rick' } - else { - # make the request - $ms.Position = 0 - $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer-forced-type" -Method Post -Body $ms.ToArray() -ContentType 'application/json' @splatter + finally { $ms.Dispose() } - - $result.Username | Should -Be 'rick' } It 'works with any method' { - if ($UseCurl) { - $result = (curl -s -X GET "$($Endpoint)/all" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + # test curl + $result = (curl -s -X GET "$($Endpoint)/all" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = (curl -s -X PUT "$($Endpoint)/all" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + $result = (curl -s -X PUT "$($Endpoint)/all" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = (curl -s -X PATCH "$($Endpoint)/all" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/all" -Method Get @splatter - $result.Result | Should -Be 'OK' + $result = (curl -s -X PATCH "$($Endpoint)/all" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = Invoke-RestMethod -Uri "$($Endpoint)/all" -Method Put @splatter - $result.Result | Should -Be 'OK' + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/all" -Method Get @splatter + $result.Result | Should -Be 'OK' - $result = Invoke-RestMethod -Uri "$($Endpoint)/all" -Method Patch @splatter - $result.Result | Should -Be 'OK' - } + $result = Invoke-RestMethod -Uri "$($Endpoint)/all" -Method Put @splatter + $result.Result | Should -Be 'OK' + + $result = Invoke-RestMethod -Uri "$($Endpoint)/all" -Method Patch @splatter + $result.Result | Should -Be 'OK' } It 'route with a wild card' { - if ($UseCurl) { - $result = (curl -s -X GET "$($Endpoint)/api/stuff/hello" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + # test curl + $result = (curl -s -X GET "$($Endpoint)/api/stuff/hello" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = (curl -s -X GET "$($Endpoint)/api/random/hello" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + $result = (curl -s -X GET "$($Endpoint)/api/random/hello" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = (curl -s -X GET "$($Endpoint)/api/123/hello" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/api/stuff/hello" -Method Get @splatter - $result.Result | Should -Be 'OK' + $result = (curl -s -X GET "$($Endpoint)/api/123/hello" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = Invoke-RestMethod -Uri "$($Endpoint)/api/random/hello" -Method Get @splatter - $result.Result | Should -Be 'OK' + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/api/stuff/hello" -Method Get @splatter + $result.Result | Should -Be 'OK' - $result = Invoke-RestMethod -Uri "$($Endpoint)/api/123/hello" -Method Get @splatter - $result.Result | Should -Be 'OK' - } + $result = Invoke-RestMethod -Uri "$($Endpoint)/api/random/hello" -Method Get @splatter + $result.Result | Should -Be 'OK' + + $result = Invoke-RestMethod -Uri "$($Endpoint)/api/123/hello" -Method Get @splatter + $result.Result | Should -Be 'OK' } It 'route importing outer function' { - if ($UseCurl) { - $result = (curl -s -X GET "$($Endpoint)/imported/func/outer" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/imported/func/outer" -Method Get @splatter - } + # test curl + $result = (curl -s -X GET "$($Endpoint)/imported/func/outer" -k) | ConvertFrom-Json + $result.Message | Should -Be 'Outer Hello' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/imported/func/outer" -Method Get @splatter $result.Message | Should -Be 'Outer Hello' } It 'route importing outer function' { - if ($UseCurl) { - $result = (curl -s -X GET "$($Endpoint)/imported/func/inner" -k) | ConvertFrom-Json - } - else { - $result = Invoke-RestMethod -Uri "$($Endpoint)/imported/func/inner" -Method Get @splatter - } + # test curl + $result = (curl -s -X GET "$($Endpoint)/imported/func/inner" -k) | ConvertFrom-Json + $result.Message | Should -Be 'Inner Hello' + + # test Invoke-RestMethod + $result = Invoke-RestMethod -Uri "$($Endpoint)/imported/func/inner" -Method Get @splatter $result.Message | Should -Be 'Inner Hello' } } \ No newline at end of file From 22e7d94a3fc6b14e204c5f04a994dbbb5caab8ee Mon Sep 17 00:00:00 2001 From: Matthew Kelly Date: Thu, 29 Aug 2024 09:52:52 +0100 Subject: [PATCH 3/3] #1291: only use curl on PS7.4+ --- tests/integration/RestApi.Https.Tests.ps1 | 177 +++++++++++++--------- 1 file changed, 109 insertions(+), 68 deletions(-) diff --git a/tests/integration/RestApi.Https.Tests.ps1 b/tests/integration/RestApi.Https.Tests.ps1 index 17dea15fe..0e9184124 100644 --- a/tests/integration/RestApi.Https.Tests.ps1 +++ b/tests/integration/RestApi.Https.Tests.ps1 @@ -6,6 +6,7 @@ Describe 'REST API Requests' { BeforeAll { $splatter = @{} $version = $PSVersionTable.PSVersion + $useCurl = $false if ($version.Major -eq 5) { # Ignore SSL certificate validation errors @@ -20,10 +21,13 @@ Describe 'REST API Requests' { } } '@ - [System.Net.ServicePointManager]::CertificatePolicy = New-Object TrustAllCertsPolicy } else { + if ($version -ge [version]'7.4.0') { + $useCurl = $true + } + $splatter.SkipCertificateCheck = $true } @@ -125,8 +129,10 @@ Describe 'REST API Requests' { It 'responds back with pong' { # test curl - $result = (curl -s -X GET "$($Endpoint)/ping" -k) | ConvertFrom-Json - $result.Result | Should -Be 'Pong' + if ($useCurl) { + $result = (curl -s -X GET "$($Endpoint)/ping" -k) | ConvertFrom-Json + $result.Result | Should -Be 'Pong' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/ping" -Method Get @splatter @@ -135,8 +141,10 @@ Describe 'REST API Requests' { It 'responds back with 404 for invalid route' { # test curl - $status_code = (curl -s -o /dev/null -w '%{http_code}' "$Endpoint/eek" -k) - $status_code | Should -be 404 + if ($useCurl) { + $status_code = (curl -s -o /dev/null -w '%{http_code}' "$Endpoint/eek" -k) + $status_code | Should -be 404 + } # test Invoke-RestMethod { Invoke-RestMethod -Uri "$($Endpoint)/eek" -Method Get -ErrorAction Stop @splatter } | Should -Throw -ExpectedMessage '*404*' @@ -144,8 +152,10 @@ Describe 'REST API Requests' { It 'responds back with 405 for incorrect method' { # test curl - $status_code = (curl -X POST -s -o /dev/null -w '%{http_code}' "$Endpoint/ping" -k) - $status_code | Should -be 405 + if ($useCurl) { + $status_code = (curl -X POST -s -o /dev/null -w '%{http_code}' "$Endpoint/ping" -k) + $status_code | Should -be 405 + } # test Invoke-RestMethod { Invoke-RestMethod -Uri "$($Endpoint)/ping" -Method Post -ErrorAction Stop @splatter } | Should -Throw -ExpectedMessage '*405*' @@ -153,8 +163,10 @@ Describe 'REST API Requests' { It 'responds with simple query parameter' { # test curl - $result = (curl -s -X GET "$($Endpoint)/data/query?username=rick" -k) | ConvertFrom-Json - $result.Username | Should -Be 'rick' + if ($useCurl) { + $result = (curl -s -X GET "$($Endpoint)/data/query?username=rick" -k) | ConvertFrom-Json + $result.Username | Should -Be 'rick' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/data/query?username=rick" -Method Get @splatter @@ -163,8 +175,10 @@ Describe 'REST API Requests' { It 'responds with simple payload parameter - json' { # test curl - $result = curl -s -X POST "$($Endpoint)/data/payload" -H 'Content-Type: application/json' -d '{"username":"rick"}' -k | ConvertFrom-Json - $result.Username | Should -Be 'rick' + if ($useCurl) { + $result = curl -s -X POST "$($Endpoint)/data/payload" -H 'Content-Type: application/json' -d '{"username":"rick"}' -k | ConvertFrom-Json + $result.Username | Should -Be 'rick' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload" -Method Post -Body '{"username":"rick"}' -ContentType 'application/json' @splatter @@ -173,8 +187,10 @@ Describe 'REST API Requests' { It 'responds with simple payload parameter - xml' { # test curl - $result = curl -s -X POST "$($Endpoint)/data/payload" -H 'Content-Type: application/xml' -d 'rick' -k | ConvertFrom-Json - $result.Username | Should -Be 'rick' + if ($useCurl) { + $result = curl -s -X POST "$($Endpoint)/data/payload" -H 'Content-Type: application/xml' -d 'rick' -k | ConvertFrom-Json + $result.Username | Should -Be 'rick' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload" -Method Post -Body 'rick' -ContentType 'application/xml' @splatter @@ -183,8 +199,10 @@ Describe 'REST API Requests' { It 'responds with simple payload parameter forced to json' { # test curl - $result = curl -s -X POST "$($Endpoint)/data/payload-forced-type" -d '{"username":"rick"}' -k | ConvertFrom-Json - $result.Username | Should -Be 'rick' + if ($useCurl) { + $result = curl -s -X POST "$($Endpoint)/data/payload-forced-type" -d '{"username":"rick"}' -k | ConvertFrom-Json + $result.Username | Should -Be 'rick' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/data/payload-forced-type" -Method Post -Body '{"username":"rick"}' @splatter @@ -193,8 +211,10 @@ Describe 'REST API Requests' { It 'responds with simple route parameter' { # test curl - $result = (curl -s -X GET "$($Endpoint)/data/param/rick" -k) | ConvertFrom-Json - $result.Username | Should -Be 'rick' + if ($useCurl) { + $result = (curl -s -X GET "$($Endpoint)/data/param/rick" -k) | ConvertFrom-Json + $result.Username | Should -Be 'rick' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/data/param/rick" -Method Get @splatter @@ -203,10 +223,12 @@ Describe 'REST API Requests' { It 'responds with simple route parameter long' { # test curl - $result = (curl -s -X GET "$($Endpoint)/data/param/rick/messages" -k) | ConvertFrom-Json - $result.Messages[0] | Should -Be 'Hello, world!' - $result.Messages[1] | Should -Be 'Greetings' - $result.Messages[2] | Should -Be 'Wubba Lub' + if ($useCurl) { + $result = (curl -s -X GET "$($Endpoint)/data/param/rick/messages" -k) | ConvertFrom-Json + $result.Messages[0] | Should -Be 'Hello, world!' + $result.Messages[1] | Should -Be 'Greetings' + $result.Messages[2] | Should -Be 'Wubba Lub' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/data/param/rick/messages" -Method Get @splatter @@ -217,8 +239,10 @@ Describe 'REST API Requests' { It 'responds ok to remove account' { # test curl - $result = (curl -s -X DELETE "$($Endpoint)/api/rick/remove" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + if ($useCurl) { + $result = (curl -s -X DELETE "$($Endpoint)/api/rick/remove" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/remove" -Method Delete @splatter @@ -227,8 +251,10 @@ Describe 'REST API Requests' { It 'responds ok to replace account' { # test curl - $result = (curl -s -X PUT "$($Endpoint)/api/rick/replace" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + if ($useCurl) { + $result = (curl -s -X PUT "$($Endpoint)/api/rick/replace" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/replace" -Method Put @splatter @@ -237,8 +263,10 @@ Describe 'REST API Requests' { It 'responds ok to update account' { # test curl - $result = (curl -s -X PATCH "$($Endpoint)/api/rick/update" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + if ($useCurl) { + $result = (curl -s -X PATCH "$($Endpoint)/api/rick/update" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/api/rick/update" -Method Patch @splatter @@ -246,7 +274,6 @@ Describe 'REST API Requests' { } It 'decodes encoded payload parameter - gzip' { - # test curl $data = @{ username = 'rick' } $message = ($data | ConvertTo-Json) @@ -262,16 +289,18 @@ Describe 'REST API Requests' { $ms.Position = 0 $compressedData = $ms.ToArray() - # Save the compressed data to a temporary file - $tempFile = [System.IO.Path]::GetTempFileName() - [System.IO.File]::WriteAllBytes($tempFile, $compressedData) + if ($useCurl) { + # Save the compressed data to a temporary file + $tempFile = [System.IO.Path]::GetTempFileName() + [System.IO.File]::WriteAllBytes($tempFile, $compressedData) - # make the request - $result = curl -s -X POST "$Endpoint/encoding/transfer" -H 'Transfer-Encoding: gzip' -H 'Content-Type: application/json' --data-binary "@$tempFile" -k | ConvertFrom-Json + # make the request + $result = curl -s -X POST "$Endpoint/encoding/transfer" -H 'Transfer-Encoding: gzip' -H 'Content-Type: application/json' --data-binary "@$tempFile" -k | ConvertFrom-Json - # Cleanup the temporary file - Remove-Item -Path $tempFile - $result.Username | Should -Be 'rick' + # Cleanup the temporary file + Remove-Item -Path $tempFile + $result.Username | Should -Be 'rick' + } # make the request $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer" -Method Post -Body $compressedData -Headers @{ 'Transfer-Encoding' = 'gzip' } -ContentType 'application/json' @splatter @@ -299,16 +328,18 @@ Describe 'REST API Requests' { $compressedData = $ms.ToArray() # test curl - # Save the compressed data to a temporary file - $tempFile = [System.IO.Path]::GetTempFileName() - [System.IO.File]::WriteAllBytes($tempFile, $compressedData) + if ($useCurl) { + # Save the compressed data to a temporary file + $tempFile = [System.IO.Path]::GetTempFileName() + [System.IO.File]::WriteAllBytes($tempFile, $compressedData) - # make the request - $result = curl -s -X POST "$Endpoint/encoding/transfer" -H 'Transfer-Encoding: deflate' -H 'Content-Type: application/json' --data-binary "@$tempFile" -k | ConvertFrom-Json + # make the request + $result = curl -s -X POST "$Endpoint/encoding/transfer" -H 'Transfer-Encoding: deflate' -H 'Content-Type: application/json' --data-binary "@$tempFile" -k | ConvertFrom-Json - # Cleanup the temporary file - Remove-Item -Path $tempFile - $result.Username | Should -Be 'rick' + # Cleanup the temporary file + Remove-Item -Path $tempFile + $result.Username | Should -Be 'rick' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer" -Method Post -Body $compressedData -Headers @{ 'Transfer-Encoding' = 'deflate' } -ContentType 'application/json' @splatter @@ -336,16 +367,18 @@ Describe 'REST API Requests' { $compressedData = $ms.ToArray() # test curl - # Save the compressed data to a temporary file - $tempFile = [System.IO.Path]::GetTempFileName() - [System.IO.File]::WriteAllBytes($tempFile, $compressedData) + if ($useCurl) { + # Save the compressed data to a temporary file + $tempFile = [System.IO.Path]::GetTempFileName() + [System.IO.File]::WriteAllBytes($tempFile, $compressedData) - # make the request - $result = curl -s -X POST "$Endpoint/encoding/transfer-forced-type" -H 'Content-Type: application/json' --data-binary "@$tempFile" -k | ConvertFrom-Json + # make the request + $result = curl -s -X POST "$Endpoint/encoding/transfer-forced-type" -H 'Content-Type: application/json' --data-binary "@$tempFile" -k | ConvertFrom-Json - # Cleanup the temporary file - Remove-Item -Path $tempFile - $result.Username | Should -Be 'rick' + # Cleanup the temporary file + Remove-Item -Path $tempFile + $result.Username | Should -Be 'rick' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/encoding/transfer-forced-type" -Method Post -Body $compressedData -ContentType 'application/json' @splatter @@ -358,14 +391,16 @@ Describe 'REST API Requests' { It 'works with any method' { # test curl - $result = (curl -s -X GET "$($Endpoint)/all" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + if ($useCurl) { + $result = (curl -s -X GET "$($Endpoint)/all" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = (curl -s -X PUT "$($Endpoint)/all" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + $result = (curl -s -X PUT "$($Endpoint)/all" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = (curl -s -X PATCH "$($Endpoint)/all" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + $result = (curl -s -X PATCH "$($Endpoint)/all" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/all" -Method Get @splatter @@ -380,14 +415,16 @@ Describe 'REST API Requests' { It 'route with a wild card' { # test curl - $result = (curl -s -X GET "$($Endpoint)/api/stuff/hello" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + if ($useCurl) { + $result = (curl -s -X GET "$($Endpoint)/api/stuff/hello" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = (curl -s -X GET "$($Endpoint)/api/random/hello" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + $result = (curl -s -X GET "$($Endpoint)/api/random/hello" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' - $result = (curl -s -X GET "$($Endpoint)/api/123/hello" -k) | ConvertFrom-Json - $result.Result | Should -Be 'OK' + $result = (curl -s -X GET "$($Endpoint)/api/123/hello" -k) | ConvertFrom-Json + $result.Result | Should -Be 'OK' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/api/stuff/hello" -Method Get @splatter @@ -402,8 +439,10 @@ Describe 'REST API Requests' { It 'route importing outer function' { # test curl - $result = (curl -s -X GET "$($Endpoint)/imported/func/outer" -k) | ConvertFrom-Json - $result.Message | Should -Be 'Outer Hello' + if ($useCurl) { + $result = (curl -s -X GET "$($Endpoint)/imported/func/outer" -k) | ConvertFrom-Json + $result.Message | Should -Be 'Outer Hello' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/imported/func/outer" -Method Get @splatter @@ -412,8 +451,10 @@ Describe 'REST API Requests' { It 'route importing outer function' { # test curl - $result = (curl -s -X GET "$($Endpoint)/imported/func/inner" -k) | ConvertFrom-Json - $result.Message | Should -Be 'Inner Hello' + if ($useCurl) { + $result = (curl -s -X GET "$($Endpoint)/imported/func/inner" -k) | ConvertFrom-Json + $result.Message | Should -Be 'Inner Hello' + } # test Invoke-RestMethod $result = Invoke-RestMethod -Uri "$($Endpoint)/imported/func/inner" -Method Get @splatter