diff --git a/src/DocumentDbTests/Reading/query_by_sql.cs b/src/DocumentDbTests/Reading/query_by_sql.cs index e39c6a6c10..d24b4a989a 100644 --- a/src/DocumentDbTests/Reading/query_by_sql.cs +++ b/src/DocumentDbTests/Reading/query_by_sql.cs @@ -379,4 +379,23 @@ public async Task get_count_asynchronously() var sum = sumResults.Single(); sum.ShouldBe(4); } + + [Fact] + public async Task can_query_using_with_select() + { + await using var session = theStore.LightweightSession(); + var time = DateTimeOffset.UtcNow; + var u = new User { FirstName = "Jeremy", LastName = "Miller", ModifiedAt = time, Age = 28 }; + session.Store(u); + await session.SaveChangesAsync(); + + var users = + await + session.QueryAsync( + "with my_with_query as (select data from mt_doc_user where data ->> 'FirstName' = 'Jeremy') select data from my_with_query"); + var user = users.Single(); + + user.LastName.ShouldBe("Miller"); + user.Id.ShouldBe(u.Id); + } } diff --git a/src/Marten/Linq/QueryHandlers/UserSuppliedQueryHandler.cs b/src/Marten/Linq/QueryHandlers/UserSuppliedQueryHandler.cs index 71a26e6506..b81113ddb3 100644 --- a/src/Marten/Linq/QueryHandlers/UserSuppliedQueryHandler.cs +++ b/src/Marten/Linq/QueryHandlers/UserSuppliedQueryHandler.cs @@ -26,7 +26,8 @@ public UserSuppliedQueryHandler(IMartenSession session, string sql, object[] par { _sql = sql.TrimStart(); _parameters = parameters; - SqlContainsCustomSelect = _sql.StartsWith("select", StringComparison.OrdinalIgnoreCase); + SqlContainsCustomSelect = _sql.StartsWith("select", StringComparison.OrdinalIgnoreCase) + || IsWithFollowedBySelect(_sql); _selectClause = GetSelectClause(session); _selector = (ISelector)_selectClause.BuildSelector(session); @@ -136,4 +137,39 @@ private ISelectClause GetSelectClause(IMartenSession session) return session.StorageFor(typeof(T)); } + + private static bool IsWithFollowedBySelect(string sql) + { + var parenthesesLevel = 0; + var isWithBlockDetected = false; + + for (var i = 0; i < sql.Length; i++) + { + var c = sql[i]; + + // Check for parentheses to handle nested structures + if (c == '(') + { + parenthesesLevel++; + } + else if (c == ')') + { + parenthesesLevel--; + } + + // Detect the beginning of the WITH block + if (!isWithBlockDetected && i < sql.Length - 4 && sql.Substring(i, 4).Equals("with", StringComparison.OrdinalIgnoreCase)) + { + isWithBlockDetected = true; + } + + // Detect the beginning of the SELECT block only if WITH block is detected and at top-level + if (isWithBlockDetected && i < sql.Length - 6 && sql.Substring(i, 6).Equals("select", StringComparison.OrdinalIgnoreCase) && parenthesesLevel == 0) + { + return true; + } + } + + return false; + } }