diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index d01617ac76d5..6fe28cabee06 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -74,16 +74,16 @@ checksum = "f410d3907b6b3647b9e7bca4551274b2e3d716aa940afb67b7287257401da921" dependencies = [ "ahash", "arrow-arith", - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", "arrow-csv", - "arrow-data", + "arrow-data 34.0.0", "arrow-ipc", "arrow-json", "arrow-ord", "arrow-row", - "arrow-schema", + "arrow-schema 34.0.0", "arrow-select", "arrow-string", "comfy-table", @@ -95,10 +95,10 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f87391cf46473c9bc53dab68cb8872c3a81d4dfd1703f1c8aa397dba9880a043" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "chrono", "half", "num", @@ -111,15 +111,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d35d5475e65c57cffba06d0022e3006b677515f99b54af33a7cd54f6cdd4a5b5" dependencies = [ "ahash", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "chrono", "half", "hashbrown 0.13.2", "num", ] +[[package]] +name = "arrow-array" +version = "35.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43489bbff475545b78b0e20bde1d22abd6c99e54499839f9e815a2fa5134a51b" +dependencies = [ + "ahash", + "arrow-buffer 35.0.0", + "arrow-data 35.0.0", + "arrow-schema 35.0.0", + "chrono", + "chrono-tz", + "half", + "hashbrown 0.13.2", + "num", +] + [[package]] name = "arrow-buffer" version = "34.0.0" @@ -130,16 +147,26 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-buffer" +version = "35.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3759e4a52c593281184787af5435671dc8b1e78333e5a30242b2e2d6e3c9d1f" +dependencies = [ + "half", + "num", +] + [[package]] name = "arrow-cast" version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a7285272c9897321dfdba59de29f5b05aeafd3cdedf104a941256d155f6d304" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "arrow-select", "chrono", "lexical-core", @@ -152,11 +179,11 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "981ee4e7f6a120da04e00d0b39182e1eeacccb59c8da74511de753c56b7fddf7" dependencies = [ - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", - "arrow-data", - "arrow-schema", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "chrono", "csv", "csv-core", @@ -171,8 +198,20 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27cc673ee6989ea6e4b4e8c7d461f7e06026a096c8f0b1a7288885ff71ae1e56" dependencies = [ - "arrow-buffer", - "arrow-schema", + "arrow-buffer 34.0.0", + "arrow-schema 34.0.0", + "half", + "num", +] + +[[package]] +name = "arrow-data" +version = "35.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19c7787c6cdbf9539b1ffb860bfc18c5848926ec3d62cbd52dc3b1ea35c874fd" +dependencies = [ + "arrow-buffer 35.0.0", + "arrow-schema 35.0.0", "half", "num", ] @@ -183,11 +222,11 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e37b8b69d9e59116b6b538e8514e0ec63a30f08b617ce800d31cb44e3ef64c1a" dependencies = [ - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", - "arrow-data", - "arrow-schema", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "flatbuffers", ] @@ -197,11 +236,11 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "80c3fa0bed7cfebf6d18e46b733f9cb8a1cb43ce8e6539055ca3e1e48a426266" dependencies = [ - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", - "arrow-data", - "arrow-schema", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "chrono", "half", "indexmap", @@ -216,10 +255,10 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d247dce7bed6a8d6a3c6debfa707a3a2f694383f0c692a39d736a593eae5ef94" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "arrow-select", "num", ] @@ -231,10 +270,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d609c0181f963cea5c70fddf9a388595b5be441f3aa1d1cdbf728ca834bbd3a" dependencies = [ "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "half", "hashbrown 0.13.2", ] @@ -245,16 +284,22 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64951898473bfb8e22293e83a44f02874d2257514d49cd95f9aa4afcff183fbc" +[[package]] +name = "arrow-schema" +version = "35.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6b26f6a6f8410e3b9531cbd1886399b99842701da77d4b4cf2013f7708f20f" + [[package]] name = "arrow-select" version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a513d89c2e1ac22b28380900036cf1f3992c6443efc5e079de631dcf83c6888" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "num", ] @@ -264,10 +309,10 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5288979b2705dae1114c864d73150629add9153b9b8f1d7ee3963db94c372ba5" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "arrow-select", "regex", "regex-syntax", @@ -440,9 +485,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.23" +version = "0.4.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" +checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b" dependencies = [ "iana-time-zone", "num-integer", @@ -451,6 +496,28 @@ dependencies = [ "winapi", ] +[[package]] +name = "chrono-tz" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa48fa079165080f11d7753fd0bc175b7d391f276b965fe4b55bfad67856e463" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9998fb9f7e9b2111641485bf8beb32f92945f97f92a3d061f744cfef335f751" +dependencies = [ + "parse-zoneinfo", + "phf", + "phf_codegen", +] + [[package]] name = "clap" version = "3.2.23" @@ -546,9 +613,9 @@ dependencies = [ [[package]] name = "constant_time_eq" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3ad85c1f65dc7b37604eb0e89748faf0b9653065f2a8ef69f96a687ec1e9279" +checksum = "13418e745008f7349ec7e449155f419a61b92b58a99cc3616942b926825ec76b" [[package]] name = "core-foundation-sys" @@ -737,6 +804,7 @@ name = "datafusion-common" version = "20.0.0" dependencies = [ "arrow", + "arrow-array 35.0.0", "chrono", "num_cpus", "object_store", @@ -792,8 +860,8 @@ version = "20.0.0" dependencies = [ "ahash", "arrow", - "arrow-buffer", - "arrow-schema", + "arrow-buffer 34.0.0", + "arrow-schema 34.0.0", "blake2", "blake3", "chrono", @@ -829,7 +897,7 @@ dependencies = [ name = "datafusion-sql" version = "20.0.0" dependencies = [ - "arrow-schema", + "arrow-schema 34.0.0", "datafusion-common", "datafusion-expr", "log", @@ -1022,9 +1090,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" +checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549" dependencies = [ "futures-channel", "futures-core", @@ -1037,9 +1105,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" +checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac" dependencies = [ "futures-core", "futures-sink", @@ -1047,15 +1115,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" +checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" [[package]] name = "futures-executor" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" +checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83" dependencies = [ "futures-core", "futures-task", @@ -1064,15 +1132,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" +checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91" [[package]] name = "futures-macro" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" +checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" dependencies = [ "proc-macro2", "quote", @@ -1081,21 +1149,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" +checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2" [[package]] name = "futures-task" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" +checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" [[package]] name = "futures-util" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" +checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" dependencies = [ "futures-channel", "futures-core", @@ -1204,6 +1272,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + [[package]] name = "http" version = "0.2.9" @@ -1342,10 +1416,11 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "io-lifetimes" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfa919a82ea574332e2de6e74b4c36e74d41982b335080fa59d4ef31be20fdf3" +checksum = "76e86b86ae312accbf05ade23ce76b625e0e47a255712b7414037385a1c05380" dependencies = [ + "hermit-abi 0.3.1", "libc", "windows-sys 0.45.0", ] @@ -1784,12 +1859,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ac135ecf63ebb5f53dda0921b0b76d6048b3ef631a5f4760b9e8f863ff00cfa" dependencies = [ "ahash", - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", - "arrow-data", + "arrow-data 34.0.0", "arrow-ipc", - "arrow-schema", + "arrow-schema 34.0.0", "arrow-select", "base64", "brotli", @@ -1810,6 +1885,15 @@ dependencies = [ "zstd 0.12.3+zstd.1.5.2", ] +[[package]] +name = "parse-zoneinfo" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c705f256449c60da65e11ff6626e0c16a0a0b96aaa348de61376b249bc340f41" +dependencies = [ + "regex", +] + [[package]] name = "paste" version = "1.0.12" @@ -1832,6 +1916,44 @@ dependencies = [ "indexmap", ] +[[package]] +name = "phf" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928c6535de93548188ef63bb7c4036bd415cd8f36ad25af44b9789b2ee72a48c" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56ac890c5e3ca598bbdeaa99964edb5b0258a583a9eb6ef4e89fc85d9224770" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1181c94580fa345f50f19d738aaa39c0ed30a600d95cb2d3e23f94266f14fbf" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1fb5f6f826b772a8d4c0394209441e7d37cbbb967ae9c7e0e8134365c9ee676" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -1888,9 +2010,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" dependencies = [ "unicode-ident", ] @@ -1907,9 +2029,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -2159,9 +2281,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" +checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "seq-macro" @@ -2171,18 +2293,18 @@ checksum = "e6b44e8fc93a14e66336d230954dda83d18b4605ccace8fe09bc7514a71ad0bc" [[package]] name = "serde" -version = "1.0.154" +version = "1.0.156" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cdd151213925e7f1ab45a9bbfb129316bd00799784b174b7cc7bcd16961c49e" +checksum = "314b5b092c0ade17c00142951e50ced110ec27cea304b1037c6969246c2469a4" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.154" +version = "1.0.156" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fc80d722935453bcafdc2c9a73cd6fac4dc1938f0346035d84bf99fa9e33217" +checksum = "d7e29c4601e36bcec74a223228dce795f4cd3616341a4af93520ca1a837c087d" dependencies = [ "proc-macro2", "quote", @@ -2223,6 +2345,12 @@ dependencies = [ "digest", ] +[[package]] +name = "siphasher" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" + [[package]] name = "slab" version = "0.4.8" @@ -2639,12 +2767,11 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "walkdir" -version = "2.3.2" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" dependencies = [ "same-file", - "winapi", "winapi-util", ] @@ -2829,9 +2956,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -2844,45 +2971,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_i686_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_x86_64_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "winreg" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 8a0a7042fcba..7d78ed70eb35 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -41,6 +41,7 @@ pyarrow = ["pyo3", "arrow/pyarrow"] [dependencies] apache-avro = { version = "0.14", default-features = false, features = ["snappy"], optional = true } arrow = { workspace = true, default-features = false } +arrow-array = { version = "35.0.0", default-features = false, features = ["chrono-tz"] } chrono = { version = "0.4", default-features = false } cranelift-module = { version = "0.92.0", optional = true } num_cpus = "1.13.0" @@ -48,3 +49,6 @@ object_store = { version = "0.5.4", default-features = false, optional = true } parquet = { workspace = true, default-features = false, optional = true } pyo3 = { version = "0.18.0", optional = true } sqlparser = "0.32" + +[dev-dependencies] +rand = "0.8.4" diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 73352941afa7..92cdab3ebba3 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -43,7 +43,14 @@ use arrow::{ DECIMAL128_MAX_PRECISION, }, }; -use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; +use arrow_array::timezone::Tz; +use chrono::{DateTime, Datelike, Duration, NaiveDate, NaiveDateTime, TimeZone}; + +// Constants we use throughout this file: +const MILLISECS_IN_ONE_DAY: i64 = 86_400_000; +const NANOSECS_IN_ONE_DAY: i64 = 86_400_000_000_000; +const MILLISECS_IN_ONE_MONTH: i64 = 2_592_000_000; // assuming 30 days. +const NANOSECS_IN_ONE_MONTH: i128 = 2_592_000_000_000_000; // assuming 30 days. /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part to arrow's [`Array`]. @@ -199,10 +206,28 @@ impl PartialEq for ScalarValue { (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), (TimestampNanosecond(_, _), _) => false, (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), + (IntervalYearMonth(v1), IntervalDayTime(v2)) => { + ym_to_milli(v1).eq(&dt_to_milli(v2)) + } + (IntervalYearMonth(v1), IntervalMonthDayNano(v2)) => { + ym_to_nano(v1).eq(&mdn_to_nano(v2)) + } (IntervalYearMonth(_), _) => false, (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), + (IntervalDayTime(v1), IntervalYearMonth(v2)) => { + dt_to_milli(v1).eq(&ym_to_milli(v2)) + } + (IntervalDayTime(v1), IntervalMonthDayNano(v2)) => { + dt_to_nano(v1).eq(&mdn_to_nano(v2)) + } (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), + (IntervalMonthDayNano(v1), IntervalYearMonth(v2)) => { + mdn_to_nano(v1).eq(&ym_to_nano(v2)) + } + (IntervalMonthDayNano(v1), IntervalDayTime(v2)) => { + mdn_to_nano(v1).eq(&dt_to_nano(v2)) + } (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, @@ -304,10 +329,28 @@ impl PartialOrd for ScalarValue { } (TimestampNanosecond(_, _), _) => None, (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(v1), IntervalDayTime(v2)) => { + ym_to_milli(v1).partial_cmp(&dt_to_milli(v2)) + } + (IntervalYearMonth(v1), IntervalMonthDayNano(v2)) => { + ym_to_nano(v1).partial_cmp(&mdn_to_nano(v2)) + } (IntervalYearMonth(_), _) => None, (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), + (IntervalDayTime(v1), IntervalYearMonth(v2)) => { + dt_to_milli(v1).partial_cmp(&ym_to_milli(v2)) + } + (IntervalDayTime(v1), IntervalMonthDayNano(v2)) => { + dt_to_nano(v1).partial_cmp(&mdn_to_nano(v2)) + } (IntervalDayTime(_), _) => None, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), + (IntervalMonthDayNano(v1), IntervalYearMonth(v2)) => { + mdn_to_nano(v1).partial_cmp(&ym_to_nano(v2)) + } + (IntervalMonthDayNano(v1), IntervalDayTime(v2)) => { + mdn_to_nano(v1).partial_cmp(&dt_to_nano(v2)) + } (IntervalMonthDayNano(_), _) => None, (Struct(v1, t1), Struct(v2, t2)) => { if t1.eq(t2) { @@ -332,6 +375,52 @@ impl PartialOrd for ScalarValue { } } +/// This function computes the duration (in milliseconds) of the given +/// year-month-interval. +#[inline] +fn ym_to_milli(val: &Option) -> Option { + val.map(|value| (value as i64) * MILLISECS_IN_ONE_MONTH) +} + +/// This function computes the duration (in nanoseconds) of the given +/// year-month-interval. +#[inline] +fn ym_to_nano(val: &Option) -> Option { + val.map(|value| (value as i128) * NANOSECS_IN_ONE_MONTH) +} + +/// This function computes the duration (in milliseconds) of the given +/// daytime-interval. +#[inline] +fn dt_to_milli(val: &Option) -> Option { + val.map(|val| { + let (days, millis) = IntervalDayTimeType::to_parts(val); + (days as i64) * MILLISECS_IN_ONE_DAY + (millis as i64) + }) +} + +/// This function computes the duration (in nanoseconds) of the given +/// daytime-interval. +#[inline] +fn dt_to_nano(val: &Option) -> Option { + val.map(|val| { + let (days, millis) = IntervalDayTimeType::to_parts(val); + (days as i128) * (NANOSECS_IN_ONE_DAY as i128) + (millis as i128) * 1_000_000 + }) +} + +/// This function computes the duration (in nanoseconds) of the given +/// month-day-nano-interval. Assumes a month is 30 days long. +#[inline] +fn mdn_to_nano(val: &Option) -> Option { + val.map(|val| { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(val); + (months as i128) * NANOSECS_IN_ONE_MONTH + + (days as i128) * (NANOSECS_IN_ONE_DAY as i128) + + (nanos as i128) + }) +} + impl Eq for ScalarValue {} // TODO implement this in arrow-rs with simd @@ -464,6 +553,71 @@ macro_rules! unsigned_subtraction_error { } macro_rules! impl_op { + ($LHS:expr, $RHS:expr, +) => { + impl_op_arithmetic!($LHS, $RHS, +) + }; + ($LHS:expr, $RHS:expr, -) => { + match ($LHS, $RHS) { + ( + ScalarValue::TimestampSecond(Some(ts_lhs), tz_lhs), + ScalarValue::TimestampSecond(Some(ts_rhs), tz_rhs), + ) => { + let err = || { + DataFusionError::Execution( + "Overflow while converting seconds to milliseconds".to_string(), + ) + }; + ts_sub_to_interval( + ts_lhs.checked_mul(1_000).ok_or_else(err)?, + ts_rhs.checked_mul(1_000).ok_or_else(err)?, + &tz_lhs, + &tz_rhs, + IntervalMode::Milli, + ) + }, + ( + ScalarValue::TimestampMillisecond(Some(ts_lhs), tz_lhs), + ScalarValue::TimestampMillisecond(Some(ts_rhs), tz_rhs), + ) => ts_sub_to_interval( + *ts_lhs, + *ts_rhs, + tz_lhs, + tz_rhs, + IntervalMode::Milli, + ), + ( + ScalarValue::TimestampMicrosecond(Some(ts_lhs), tz_lhs), + ScalarValue::TimestampMicrosecond(Some(ts_rhs), tz_rhs), + ) => { + let err = || { + DataFusionError::Execution( + "Overflow while converting microseconds to nanoseconds".to_string(), + ) + }; + ts_sub_to_interval( + ts_lhs.checked_mul(1_000).ok_or_else(err)?, + ts_rhs.checked_mul(1_000).ok_or_else(err)?, + tz_lhs, + tz_rhs, + IntervalMode::Nano, + ) + }, + ( + ScalarValue::TimestampNanosecond(Some(ts_lhs), tz_lhs), + ScalarValue::TimestampNanosecond(Some(ts_rhs), tz_rhs), + ) => ts_sub_to_interval( + *ts_lhs, + *ts_rhs, + tz_lhs, + tz_rhs, + IntervalMode::Nano, + ), + _ => impl_op_arithmetic!($LHS, $RHS, -) + } + }; +} + +macro_rules! impl_op_arithmetic { ($LHS:expr, $RHS:expr, $OPERATION:tt) => { match ($LHS, $RHS) { // Binary operations on arguments with the same type: @@ -503,6 +657,40 @@ macro_rules! impl_op { (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { primitive_op!(lhs, rhs, Int8, $OPERATION) } + ( + ScalarValue::IntervalYearMonth(Some(lhs)), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => Ok(ScalarValue::new_interval_ym( + 0, + lhs + rhs * get_sign!($OPERATION), + )), + ( + ScalarValue::IntervalDayTime(Some(lhs)), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => { + let sign = get_sign!($OPERATION); + let (lhs_days, lhs_millis) = IntervalDayTimeType::to_parts(*lhs); + let (rhs_days, rhs_millis) = IntervalDayTimeType::to_parts(*rhs); + Ok(ScalarValue::new_interval_dt( + lhs_days + rhs_days * sign, + lhs_millis + rhs_millis * sign, + )) + } + ( + ScalarValue::IntervalMonthDayNano(Some(lhs)), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => { + let sign = get_sign!($OPERATION); + let (lhs_months, lhs_days, lhs_nanos) = + IntervalMonthDayNanoType::to_parts(*lhs); + let (rhs_months, rhs_days, rhs_nanos) = + IntervalMonthDayNanoType::to_parts(*rhs); + Ok(ScalarValue::new_interval_mdn( + lhs_months + rhs_months * sign, + lhs_days + rhs_days * sign, + lhs_nanos + rhs_nanos * (sign as i64), + )) + } // Binary operations on arguments with different types: (ScalarValue::Date32(Some(days)), _) => { let value = date32_add(*days, $RHS, get_sign!($OPERATION))?; @@ -544,6 +732,30 @@ macro_rules! impl_op { let value = nanoseconds_add(*ts_ns, $LHS, get_sign!($OPERATION))?; Ok(ScalarValue::TimestampNanosecond(Some(value), zone.clone())) } + ( + ScalarValue::IntervalYearMonth(Some(lhs)), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => op_ym_dt(*lhs, *rhs, get_sign!($OPERATION), false), + ( + ScalarValue::IntervalYearMonth(Some(lhs)), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => op_ym_mdn(*lhs, *rhs, get_sign!($OPERATION), false), + ( + ScalarValue::IntervalDayTime(Some(lhs)), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => op_ym_dt(*rhs, *lhs, get_sign!($OPERATION), true), + ( + ScalarValue::IntervalDayTime(Some(lhs)), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => op_dt_mdn(*lhs, *rhs, get_sign!($OPERATION), false), + ( + ScalarValue::IntervalMonthDayNano(Some(lhs)), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => op_ym_mdn(*rhs, *lhs, get_sign!($OPERATION), true), + ( + ScalarValue::IntervalMonthDayNano(Some(lhs)), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => op_dt_mdn(*rhs, *lhs, get_sign!($OPERATION), true), _ => Err(DataFusionError::Internal(format!( "Operator {} is not implemented for types {:?} and {:?}", stringify!($OPERATION), @@ -554,6 +766,68 @@ macro_rules! impl_op { }; } +/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different +/// types ([`IntervalYearMonthType`] and [`IntervalDayTimeType`], respectively). +/// The argument `sign` chooses between addition and subtraction, the argument +/// `commute` swaps `lhs` and `rhs`. The return value is an interval [`ScalarValue`] +/// with type data type [`IntervalMonthDayNanoType`]. +#[inline] +fn op_ym_dt(mut lhs: i32, rhs: i64, sign: i32, commute: bool) -> Result { + let (mut days, millis) = IntervalDayTimeType::to_parts(rhs); + let mut nanos = (millis as i64) * 1_000_000; + if commute { + lhs *= sign; + } else { + days *= sign; + nanos *= sign as i64; + }; + Ok(ScalarValue::new_interval_mdn(lhs, days, nanos)) +} + +/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different +/// types ([`IntervalYearMonthType`] and [`IntervalMonthDayNanoType`], respectively). +/// The argument `sign` chooses between addition and subtraction, the argument +/// `commute` swaps `lhs` and `rhs`. The return value is an interval [`ScalarValue`] +/// with type data type [`IntervalMonthDayNanoType`]. +#[inline] +fn op_ym_mdn(lhs: i32, rhs: i128, sign: i32, commute: bool) -> Result { + let (mut months, mut days, mut nanos) = IntervalMonthDayNanoType::to_parts(rhs); + if commute { + months += lhs * sign; + } else { + months = lhs + (months * sign); + days *= sign; + nanos *= sign as i64; + } + Ok(ScalarValue::new_interval_mdn(months, days, nanos)) +} + +/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different +/// types ([`IntervalDayTimeType`] and [`IntervalMonthDayNanoType`], respectively). +/// The argument `sign` chooses between addition and subtraction, the argument +/// `commute` swaps `lhs` and `rhs`. The return value is an interval [`ScalarValue`] +/// with type data type [`IntervalMonthDayNanoType`]. +#[inline] +fn op_dt_mdn(lhs: i64, rhs: i128, sign: i32, commute: bool) -> Result { + let (lhs_days, lhs_millis) = IntervalDayTimeType::to_parts(lhs); + let (rhs_months, rhs_days, rhs_nanos) = IntervalMonthDayNanoType::to_parts(rhs); + + let result = if commute { + IntervalMonthDayNanoType::make_value( + rhs_months, + lhs_days * sign + rhs_days, + (lhs_millis * sign) as i64 * 1_000_000 + rhs_nanos, + ) + } else { + IntervalMonthDayNanoType::make_value( + rhs_months * sign, + lhs_days + rhs_days * sign, + (lhs_millis as i64) * 1_000_000 + rhs_nanos * (sign as i64), + ) + }; + Ok(ScalarValue::IntervalMonthDayNano(Some(result))) +} + macro_rules! get_sign { (+) => { 1 @@ -563,46 +837,138 @@ macro_rules! get_sign { }; } +#[derive(Clone, Copy)] +enum IntervalMode { + Milli, + Nano, +} + +/// This function computes subtracts `rhs_ts` from `lhs_ts`, taking timezones +/// into account when given. Units of the resulting interval is specified by +/// the argument `mode`. +/// The default behavior of Datafusion is the following: +/// - When subtracting timestamps at seconds/milliseconds precision, the output +/// interval will have the type [`IntervalDayTimeType`]. +/// - When subtracting timestamps at microseconds/nanoseconds precision, the +/// output interval will have the type [`IntervalMonthDayNanoType`]. +fn ts_sub_to_interval( + lhs_ts: i64, + rhs_ts: i64, + lhs_tz: &Option, + rhs_tz: &Option, + mode: IntervalMode, +) -> Result { + let lhs_dt = with_timezone_to_naive_datetime(lhs_ts, lhs_tz, mode)?; + let rhs_dt = with_timezone_to_naive_datetime(rhs_ts, rhs_tz, mode)?; + let delta_secs = lhs_dt.signed_duration_since(rhs_dt); + + match mode { + IntervalMode::Milli => { + let as_millisecs = delta_secs.num_milliseconds(); + Ok(ScalarValue::new_interval_dt( + (as_millisecs / MILLISECS_IN_ONE_DAY) as i32, + (as_millisecs % MILLISECS_IN_ONE_DAY) as i32, + )) + } + IntervalMode::Nano => { + let as_nanosecs = delta_secs.num_nanoseconds().ok_or_else(|| { + DataFusionError::Execution(String::from( + "Can not compute timestamp differences with nanosecond precision", + )) + })?; + Ok(ScalarValue::new_interval_mdn( + 0, + (as_nanosecs / NANOSECS_IN_ONE_DAY) as i32, + as_nanosecs % NANOSECS_IN_ONE_DAY, + )) + } + } +} + +/// This function creates the [`NaiveDateTime`] object corresponding to the +/// given timestamp using the units (tick size) implied by argument `mode`. +#[inline] +fn with_timezone_to_naive_datetime( + ts: i64, + tz: &Option, + mode: IntervalMode, +) -> Result { + let datetime = if let IntervalMode::Milli = mode { + ticks_to_naive_datetime::<1_000_000>(ts) + } else { + ticks_to_naive_datetime::<1>(ts) + }?; + + if let Some(tz) = tz { + let parsed_tz: Tz = FromStr::from_str(tz).map_err(|_| { + DataFusionError::Execution("cannot parse given timezone".to_string()) + })?; + let offset = parsed_tz + .offset_from_local_datetime(&datetime) + .single() + .ok_or_else(|| { + DataFusionError::Execution( + "error conversion result of timezone offset".to_string(), + ) + })?; + return Ok(DateTime::::from_local(datetime, offset).naive_utc()); + } + Ok(datetime) +} + +/// This function creates the [`NaiveDateTime`] object corresponding to the +/// given timestamp, whose tick size is specified by `UNIT_NANOS`. +#[inline] +fn ticks_to_naive_datetime(ticks: i64) -> Result { + NaiveDateTime::from_timestamp_opt( + (ticks * UNIT_NANOS) / 1_000_000_000, + ((ticks * UNIT_NANOS) % 1_000_000_000) as u32, + ) + .ok_or_else(|| { + DataFusionError::Execution( + "Can not convert given timestamp to a NaiveDateTime".to_string(), + ) + }) +} + #[inline] pub fn date32_add(days: i32, scalar: &ScalarValue, sign: i32) -> Result { let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); let prior = epoch.add(Duration::days(days as i64)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_days() as i32) + do_date_math(prior, scalar, sign).map(|d| d.sub(epoch).num_days() as i32) } #[inline] pub fn date64_add(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); let prior = epoch.add(Duration::milliseconds(ms)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_milliseconds()) + do_date_math(prior, scalar, sign).map(|d| d.sub(epoch).num_milliseconds()) } #[inline] pub fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { - Ok(do_date_time_math(ts_s, 0, scalar, sign)?.timestamp()) + do_date_time_math(ts_s, 0, scalar, sign).map(|dt| dt.timestamp()) } #[inline] pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { let secs = ts_ms / 1000; let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_millis()) + do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_millis()) } #[inline] pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { let secs = ts_us / 1_000_000; let nsecs = ((ts_us % 1_000_000) * 1000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos() / 1000) + do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_nanos() / 1000) } #[inline] pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { let secs = ts_ns / 1_000_000_000; let nsecs = (ts_ns % 1_000_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos()) + do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_nanos()) } #[inline] @@ -2921,6 +3287,7 @@ mod tests { use arrow::compute::kernels; use arrow::datatypes::ArrowPrimitiveType; + use rand::Rng; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; use crate::from_slice::FromSlice; @@ -3707,6 +4074,53 @@ mod tests { ])), None ); + // Different type of intervals can be compared. + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(1, 2))) + < IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 14, 0, 1 + ))), + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 4))) + >= IntervalDayTime(Some(IntervalDayTimeType::make_value(119, 1))) + ); + assert!( + IntervalDayTime(Some(IntervalDayTimeType::make_value(12, 86_399_999))) + >= IntervalDayTime(Some(IntervalDayTimeType::make_value(12, 0))) + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(2, 12))) + == IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 36, 0, 0 + ))), + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 0))) + != IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 1))) + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(1, 4))) + == IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 16))), + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 3))) + > IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 2, + 28, + 999_999_999 + ))), + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 1))) + > IntervalDayTime(Some(IntervalDayTimeType::make_value(29, 9_999))), + ); + assert!( + IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value(1, 12, 34))) + > IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 0, 142, 34 + ))) + ); } #[test] @@ -4486,4 +4900,513 @@ mod tests { assert!(distance.is_none()); } } + + #[test] + fn test_scalar_interval_add() { + let cases = [ + ( + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(2, 24), + ), + ( + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(2, 1998), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(24, 30, 246_912), + ), + ( + ScalarValue::new_interval_ym(0, 1), + ScalarValue::new_interval_dt(29, 86_390), + ScalarValue::new_interval_mdn(1, 29, 86_390_000_000), + ), + ( + ScalarValue::new_interval_ym(0, 1), + ScalarValue::new_interval_mdn(2, 10, 999_999_999), + ScalarValue::new_interval_mdn(3, 10, 999_999_999), + ), + ( + ScalarValue::new_interval_dt(400, 123_456), + ScalarValue::new_interval_ym(1, 1), + ScalarValue::new_interval_mdn(13, 400, 123_456_000_000), + ), + ( + ScalarValue::new_interval_dt(65, 321), + ScalarValue::new_interval_mdn(2, 5, 1_000_000), + ScalarValue::new_interval_mdn(2, 70, 322_000_000), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_ym(2, 0), + ScalarValue::new_interval_mdn(36, 15, 123_456), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 100_000), + ScalarValue::new_interval_dt(370, 1), + ScalarValue::new_interval_mdn(12, 385, 1_100_000), + ), + ]; + for (lhs, rhs, expected) in cases.iter() { + let result = lhs.add(rhs).unwrap(); + let result_commute = rhs.add(lhs).unwrap(); + assert_eq!(*expected, result, "lhs:{:?} + rhs:{:?}", lhs, rhs); + assert_eq!(*expected, result_commute, "lhs:{:?} + rhs:{:?}", rhs, lhs); + } + } + + #[test] + fn test_scalar_interval_sub() { + let cases = [ + ( + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(0, 0), + ), + ( + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(0, 0), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(0, 0, 0), + ), + ( + ScalarValue::new_interval_ym(0, 1), + ScalarValue::new_interval_dt(29, 999_999), + ScalarValue::new_interval_mdn(1, -29, -999_999_000_000), + ), + ( + ScalarValue::new_interval_ym(0, 1), + ScalarValue::new_interval_mdn(2, 10, 999_999_999), + ScalarValue::new_interval_mdn(-1, -10, -999_999_999), + ), + ( + ScalarValue::new_interval_dt(400, 123_456), + ScalarValue::new_interval_ym(1, 1), + ScalarValue::new_interval_mdn(-13, 400, 123_456_000_000), + ), + ( + ScalarValue::new_interval_dt(65, 321), + ScalarValue::new_interval_mdn(2, 5, 1_000_000), + ScalarValue::new_interval_mdn(-2, 60, 320_000_000), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_ym(2, 0), + ScalarValue::new_interval_mdn(-12, 15, 123_456), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 100_000), + ScalarValue::new_interval_dt(370, 1), + ScalarValue::new_interval_mdn(12, -355, -900_000), + ), + ]; + for (lhs, rhs, expected) in cases.iter() { + let result = lhs.sub(rhs).unwrap(); + assert_eq!(*expected, result, "lhs:{:?} - rhs:{:?}", lhs, rhs); + } + } + + #[test] + fn timestamp_op_tests() { + // positive interval, edge cases + let test_data = get_timestamp_test_data(1); + for (lhs, rhs, expected) in test_data.into_iter() { + assert_eq!(expected, lhs.sub(rhs).unwrap()) + } + + // negative interval, edge cases + let test_data = get_timestamp_test_data(-1); + for (rhs, lhs, expected) in test_data.into_iter() { + assert_eq!(expected, lhs.sub(rhs).unwrap()); + } + } + #[test] + fn timestamp_op_random_tests() { + // timestamp1 + (or -) interval = timestamp2 + // timestamp2 - timestamp1 (or timestamp1 - timestamp2) = interval ? + let sample_size = 1000000; + let timestamps1 = get_random_timestamps(sample_size); + let intervals = get_random_intervals(sample_size); + // ts(sec) + interval(ns) = ts(sec); however, + // ts(sec) - ts(sec) cannot be = interval(ns). Therefore, + // timestamps are more precise than intervals in tests. + for (idx, ts1) in timestamps1.iter().enumerate() { + if idx % 2 == 0 { + let timestamp2 = ts1.add(intervals[idx].clone()).unwrap(); + assert_eq!( + intervals[idx], + timestamp2.sub(ts1).unwrap(), + "index:{}, operands: {:?} (-) {:?}", + idx, + timestamp2, + ts1 + ); + } else { + let timestamp2 = ts1.sub(intervals[idx].clone()).unwrap(); + assert_eq!( + intervals[idx], + ts1.sub(timestamp2.clone()).unwrap(), + "index:{}, operands: {:?} (-) {:?}", + idx, + ts1, + timestamp2 + ); + }; + } + } + + fn get_timestamp_test_data( + sign: i32, + ) -> Vec<(ScalarValue, ScalarValue, ScalarValue)> { + vec![ + ( + // 1st test case, having the same time but different with timezones + // Since they are timestamps with nanosecond precision, expected type is + // [`IntervalMonthDayNanoType`] + ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_nano_opt(12, 0, 0, 000_000_000) + .unwrap() + .timestamp_nanos(), + ), + Some("+12:00".to_string()), + ), + ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_nano_opt(0, 0, 0, 000_000_000) + .unwrap() + .timestamp_nanos(), + ), + Some("+00:00".to_string()), + ), + ScalarValue::new_interval_mdn(0, 0, 0), + ), + // 2nd test case, january with 31 days plus february with 28 days, with timezone + ( + ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(2023, 3, 1) + .unwrap() + .and_hms_micro_opt(2, 0, 0, 000_000) + .unwrap() + .timestamp_micros(), + ), + Some("+01:00".to_string()), + ), + ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_micro_opt(0, 0, 0, 000_000) + .unwrap() + .timestamp_micros(), + ), + Some("-01:00".to_string()), + ), + ScalarValue::new_interval_mdn(0, sign * 59, 0), + ), + // 3rd test case, 29-days long february minus previous, year with timezone + ( + ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2024, 2, 29) + .unwrap() + .and_hms_milli_opt(10, 10, 0, 000) + .unwrap() + .timestamp_millis(), + ), + Some("+10:10".to_string()), + ), + ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2023, 12, 31) + .unwrap() + .and_hms_milli_opt(1, 0, 0, 000) + .unwrap() + .timestamp_millis(), + ), + Some("+01:00".to_string()), + ), + ScalarValue::new_interval_dt(sign * 60, 0), + ), + // 4th test case, leap years occur mostly every 4 years, but every 100 years + // we skip a leap year unless the year is divisible by 400, so 31 + 28 = 59 + ( + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2100, 3, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + Some("-11:59".to_string()), + ), + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2100, 1, 1) + .unwrap() + .and_hms_opt(23, 58, 0) + .unwrap() + .timestamp(), + ), + Some("+11:59".to_string()), + ), + ScalarValue::new_interval_dt(sign * 59, 0), + ), + // 5th test case, without timezone positively seemed, but with timezone, + // negative resulting interval + ( + ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(6, 00, 0, 000) + .unwrap() + .timestamp_millis(), + ), + Some("+06:00".to_string()), + ), + ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(0, 0, 0, 000) + .unwrap() + .timestamp_millis(), + ), + Some("-12:00".to_string()), + ), + ScalarValue::new_interval_dt(0, sign * -43_200_000), + ), + // 6th test case, no problem before unix epoch beginning + ( + ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(1970, 1, 1) + .unwrap() + .and_hms_micro_opt(1, 2, 3, 15) + .unwrap() + .timestamp_micros(), + ), + None, + ), + ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(1969, 1, 1) + .unwrap() + .and_hms_micro_opt(0, 0, 0, 000_000) + .unwrap() + .timestamp_micros(), + ), + None, + ), + ScalarValue::new_interval_mdn( + 0, + 365 * sign, + sign as i64 * 3_723_000_015_000, + ), + ), + // 7th test case, no problem with big intervals + ( + ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2100, 1, 1) + .unwrap() + .and_hms_nano_opt(0, 0, 0, 0) + .unwrap() + .timestamp_nanos(), + ), + None, + ), + ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2000, 1, 1) + .unwrap() + .and_hms_nano_opt(0, 0, 0, 000_000_000) + .unwrap() + .timestamp_nanos(), + ), + None, + ), + ScalarValue::new_interval_mdn(0, sign * 36525, 0), + ), + // 8th test case, no problem detecting 366-days long years + ( + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2041, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + None, + ), + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2040, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + None, + ), + ScalarValue::new_interval_dt(sign * 366, 0), + ), + // 9th test case, no problem with unrealistic timezones + ( + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 3) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + Some("+23:59".to_string()), + ), + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_opt(0, 2, 0) + .unwrap() + .timestamp(), + ), + Some("-23:59".to_string()), + ), + ScalarValue::new_interval_dt(0, 0), + ), + // 10th test case, parsing different types of timezone input + ( + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 3, 17) + .unwrap() + .and_hms_opt(14, 10, 0) + .unwrap() + .timestamp(), + ), + Some("Europe/Istanbul".to_string()), + ), + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 3, 17) + .unwrap() + .and_hms_opt(4, 10, 0) + .unwrap() + .timestamp(), + ), + Some("America/Los_Angeles".to_string()), + ), + ScalarValue::new_interval_dt(0, 0), + ), + ] + } + + fn get_random_timestamps(sample_size: u64) -> Vec { + let vector_size = sample_size; + let mut timestamp = vec![]; + let mut rng = rand::thread_rng(); + for i in 0..vector_size { + let year = rng.gen_range(1995..=2050); + let month = rng.gen_range(1..=12); + let day = rng.gen_range(1..=28); // to exclude invalid dates + let hour = rng.gen_range(0..=23); + let minute = rng.gen_range(0..=59); + let second = rng.gen_range(0..=59); + if i % 4 == 0 { + timestamp.push(ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_opt(hour, minute, second) + .unwrap() + .timestamp(), + ), + None, + )) + } else if i % 4 == 1 { + let millisec = rng.gen_range(0..=999); + timestamp.push(ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_milli_opt(hour, minute, second, millisec) + .unwrap() + .timestamp_millis(), + ), + None, + )) + } else if i % 4 == 2 { + let microsec = rng.gen_range(0..=999_999); + timestamp.push(ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_micro_opt(hour, minute, second, microsec) + .unwrap() + .timestamp_micros(), + ), + None, + )) + } else if i % 4 == 3 { + let nanosec = rng.gen_range(0..=999_999_999); + timestamp.push(ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_nano_opt(hour, minute, second, nanosec) + .unwrap() + .timestamp_nanos(), + ), + None, + )) + } + } + timestamp + } + + fn get_random_intervals(sample_size: u64) -> Vec { + let vector_size = sample_size; + let mut intervals = vec![]; + let mut rng = rand::thread_rng(); + const SECS_IN_ONE_DAY: i32 = 86_400; + const MICROSECS_IN_ONE_DAY: i64 = 86_400_000_000; + for i in 0..vector_size { + if i % 4 == 0 { + let days = rng.gen_range(0..5000); + // to not break second precision + let millis = rng.gen_range(0..SECS_IN_ONE_DAY) * 1000; + intervals.push(ScalarValue::new_interval_dt(days, millis)); + } else if i % 4 == 1 { + let days = rng.gen_range(0..5000); + let millisec = rng.gen_range(0..(MILLISECS_IN_ONE_DAY as i32)); + intervals.push(ScalarValue::new_interval_dt(days, millisec)); + } else if i % 4 == 2 { + let days = rng.gen_range(0..5000); + // to not break microsec precision + let nanosec = rng.gen_range(0..MICROSECS_IN_ONE_DAY) * 1000; + intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); + } else { + let days = rng.gen_range(0..5000); + let nanosec = rng.gen_range(0..NANOSECS_IN_ONE_DAY); + intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); + } + } + intervals + } } diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index aad66473aec3..d0829702d07f 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -35,7 +35,38 @@ impl<'a> std::fmt::Display for ResolvedTableReference<'a> { } } -/// Represents a path to a table that may require further resolution +/// [`TableReference`]s represent a multi part identifier (path) to a +/// table that may require further resolution. +/// +/// # Creating [`TableReference`] +/// +/// When converting strings to [`TableReference`]s, the string is +/// parsed as though it were a SQL identifier, normalizing (convert to +/// lowercase) any unquoted identifiers. +/// +/// See [`TableReference::bare`] to create references without applying +/// normalization semantics +/// +/// # Examples +/// ``` +/// # use datafusion_common::TableReference; +/// // Get a table reference to 'mytable' +/// let table_reference = TableReference::from("mytable"); +/// assert_eq!(table_reference, TableReference::bare("mytable")); +/// +/// // Get a table reference to 'mytable' (note the capitalization) +/// let table_reference = TableReference::from("MyTable"); +/// assert_eq!(table_reference, TableReference::bare("mytable")); +/// +/// // Get a table reference to 'MyTable' (note the capitalization) using double quotes +/// // (programatically it is better to use `TableReference::bare` for this) +/// let table_reference = TableReference::from(r#""MyTable""#); +/// assert_eq!(table_reference, TableReference::bare("MyTable")); +/// +/// // Get a table reference to 'myschema.mytable' (note the capitalization) +/// let table_reference = TableReference::from("MySchema.MyTable"); +/// assert_eq!(table_reference, TableReference::partial("myschema", "mytable")); +///``` #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum TableReference<'a> { /// An unqualified table reference, e.g. "table" @@ -61,6 +92,16 @@ pub enum TableReference<'a> { }, } +/// This is a [`TableReference`] that has 'static lifetime (aka it +/// owns the underlying string) +/// +/// To convert a [`TableReference`] to an [`OwnedTableReference`], use +/// +/// ``` +/// # use datafusion_common::{OwnedTableReference, TableReference}; +/// let table_reference = TableReference::from("mytable"); +/// let owned_reference = table_reference.to_owned_reference(); +/// ``` pub type OwnedTableReference = TableReference<'static>; impl std::fmt::Display for TableReference<'_> { @@ -85,14 +126,20 @@ impl<'a> TableReference<'a> { None } - /// Convenience method for creating a `Bare` variant of `TableReference` + /// Convenience method for creating a [`TableReference::Bare`] + /// + /// As described on [`TableReference`] this does *NO* parsing at + /// all, so "Foo.Bar" stays as a reference to the table named + /// "Foo.Bar" (rather than "foo"."bar") pub fn bare(table: impl Into>) -> TableReference<'a> { TableReference::Bare { table: table.into(), } } - /// Convenience method for creating a `Partial` variant of `TableReference` + /// Convenience method for creating a [`TableReference::Partial`]. + /// + /// As described on [`TableReference`] this does *NO* parsing at all. pub fn partial( schema: impl Into>, table: impl Into>, @@ -103,7 +150,9 @@ impl<'a> TableReference<'a> { } } - /// Convenience method for creating a `Full` variant of `TableReference` + /// Convenience method for creating a [`TableReference::Full`] + /// + /// As described on [`TableReference`] this does *NO* parsing at all. pub fn full( catalog: impl Into>, schema: impl Into>, @@ -141,12 +190,12 @@ impl<'a> TableReference<'a> { } } - /// Compare with another `TableReference` as if both are resolved. + /// Compare with another [`TableReference`] as if both are resolved. /// This allows comparing across variants, where if a field is not present /// in both variants being compared then it is ignored in the comparison. /// - /// e.g. this allows a `TableReference::Bare` to be considered equal to a - /// fully qualified `TableReference::Full` if the table names match. + /// e.g. this allows a [`TableReference::Bare`] to be considered equal to a + /// fully qualified [`TableReference::Full`] if the table names match. pub fn resolved_eq(&self, other: &Self) -> bool { match self { TableReference::Bare { table } => table == other.table(), @@ -194,7 +243,8 @@ impl<'a> TableReference<'a> { } } - /// Converts directly into an [`OwnedTableReference`] + /// Converts directly into an [`OwnedTableReference`] by cloning + /// the underlying data. pub fn to_owned_reference(&self) -> OwnedTableReference { match self { Self::Full { @@ -217,6 +267,16 @@ impl<'a> TableReference<'a> { } /// Forms a string where the identifiers are quoted + /// + /// # Example + /// ``` + /// # use datafusion_common::TableReference; + /// let table_reference = TableReference::partial("myschema", "mytable"); + /// assert_eq!(table_reference.to_quoted_string(), r#""myschema"."mytable""#); + /// + /// let table_reference = TableReference::partial("MySchema", "MyTable"); + /// assert_eq!(table_reference.to_quoted_string(), r#""MySchema"."MyTable""#); + /// ``` pub fn to_quoted_string(&self) -> String { match self { TableReference::Bare { table } => quote_identifier(table), @@ -236,14 +296,8 @@ impl<'a> TableReference<'a> { } } - /// Forms a [`TableReference`] by attempting to parse `s` as a multipart identifier, - /// failing that then taking the entire unnormalized input as the identifier itself. - /// - /// Will normalize (convert to lowercase) any unquoted identifiers. - /// - /// e.g. `Foo` will be parsed as `foo`, and `"Foo"".bar"` will be parsed as - /// `Foo".bar` (note the preserved case and requiring two double quotes to represent - /// a single double quote in the identifier) + /// Forms a [`TableReference`] by parsing `s` as a multipart SQL + /// identifier. See docs on [`TableReference`] for more details. pub fn parse_str(s: &'a str) -> Self { let mut parts = parse_identifiers_normalized(s); @@ -265,7 +319,7 @@ impl<'a> TableReference<'a> { } } -/// Parse a `String` into a OwnedTableReference +/// Parse a `String` into a OwnedTableReference as a multipart SQL identifier. impl From for OwnedTableReference { fn from(s: String) -> Self { TableReference::parse_str(&s).to_owned_reference() diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 1b8b742c2ae1..e3e2987d3dfe 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -344,91 +344,70 @@ impl DataFrame { //collect recordBatch let describe_record_batch = vec![ // count aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .map(|f| count(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .map(|f| count(col(f.name())).alias(f.name())) + .collect::>(), + ), // null_count aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .map(|f| count(is_null(col(f.name()))).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .map(|f| count(is_null(col(f.name()))).alias(f.name())) + .collect::>(), + ), // mean aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| f.data_type().is_numeric()) - .map(|f| avg(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| f.data_type().is_numeric()) + .map(|f| avg(col(f.name())).alias(f.name())) + .collect::>(), + ), // std aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| f.data_type().is_numeric()) - .map(|f| stddev(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| f.data_type().is_numeric()) + .map(|f| stddev(col(f.name())).alias(f.name())) + .collect::>(), + ), // min aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| { - !matches!(f.data_type(), DataType::Binary | DataType::Boolean) - }) - .map(|f| min(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| { + !matches!(f.data_type(), DataType::Binary | DataType::Boolean) + }) + .map(|f| min(col(f.name())).alias(f.name())) + .collect::>(), + ), // max aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| { - !matches!(f.data_type(), DataType::Binary | DataType::Boolean) - }) - .map(|f| max(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| { + !matches!(f.data_type(), DataType::Binary | DataType::Boolean) + }) + .map(|f| max(col(f.name())).alias(f.name())) + .collect::>(), + ), // median aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| f.data_type().is_numeric()) - .map(|f| median(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| f.data_type().is_numeric()) + .map(|f| median(col(f.name())).alias(f.name())) + .collect::>(), + ), ]; // first column with function names @@ -437,24 +416,44 @@ impl DataFrame { ))]; for field in original_schema_fields { let mut array_datas = vec![]; - for record_batch in describe_record_batch.iter() { - // safe unwrap since aggregate record batches should have at least 1 record - let column = record_batch.get(0).unwrap().column_by_name(field.name()); - match column { - Some(c) => { - if field.data_type().is_numeric() { - array_datas.push(cast(c, &DataType::Float64)?); - } else { - array_datas.push(cast(c, &DataType::Utf8)?); + for result in describe_record_batch.iter() { + let array_ref = match result { + Ok(df) => { + let batchs = df.clone().collect().await; + match batchs { + Ok(batchs) + if batchs.len() == 1 + && batchs[0] + .column_by_name(field.name()) + .is_some() => + { + let column = + batchs[0].column_by_name(field.name()).unwrap(); + if field.data_type().is_numeric() { + cast(column, &DataType::Float64)? + } else { + cast(column, &DataType::Utf8)? + } + } + _ => Arc::new(StringArray::from_slice(["null"])), } } - //if None mean the column cannot be min/max aggregation - None => { - array_datas.push(Arc::new(StringArray::from_slice(["null"]))); + //Handling error when only boolean/binary column, and in other cases + Err(err) + if err.to_string().contains( + "Error during planning: \ + Aggregate requires at least one grouping \ + or aggregate expression", + ) => + { + Arc::new(StringArray::from_slice(["null"])) } - } + Err(other_err) => { + panic!("{other_err}") + } + }; + array_datas.push(array_ref); } - array_ref_vec.push(concat( array_datas .iter() diff --git a/datafusion/core/src/datasource/file_format/file_type.rs b/datafusion/core/src/datasource/file_format/file_type.rs index 59c95962a992..e07eb8a3d7a6 100644 --- a/datafusion/core/src/datasource/file_format/file_type.rs +++ b/datafusion/core/src/datasource/file_format/file_type.rs @@ -30,10 +30,10 @@ use async_compression::tokio::bufread::{ }; use bytes::Bytes; #[cfg(feature = "compression")] -use bzip2::read::BzDecoder; +use bzip2::read::MultiBzDecoder; use datafusion_common::parsers::CompressionTypeVariant; #[cfg(feature = "compression")] -use flate2::read::GzDecoder; +use flate2::read::MultiGzDecoder; use futures::Stream; #[cfg(feature = "compression")] use futures::TryStreamExt; @@ -168,11 +168,11 @@ impl FileCompressionType { ) -> Result> { Ok(match self.variant { #[cfg(feature = "compression")] - GZIP => Box::new(GzDecoder::new(r)), + GZIP => Box::new(MultiGzDecoder::new(r)), #[cfg(feature = "compression")] - BZIP2 => Box::new(BzDecoder::new(r)), + BZIP2 => Box::new(MultiBzDecoder::new(r)), #[cfg(feature = "compression")] - XZ => Box::new(XzDecoder::new(r)), + XZ => Box::new(XzDecoder::new_multi_decoder(r)), #[cfg(feature = "compression")] ZSTD => match ZstdDecoder::new(r) { Ok(decoder) => Box::new(decoder), diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 0ac836cf28ed..fa276c423879 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -353,12 +353,12 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { // For instance, if `n_out` number of rows are calculated, we can remove // first `n_out` rows from `self.input_buffer_record_batch`. fn prune_state(&mut self, n_out: usize) -> Result<()> { + // Prune `self.window_agg_states`: + self.prune_out_columns(n_out)?; // Prune `self.partition_batches`: self.prune_partition_batches()?; // Prune `self.input_buffer_record_batch`: self.prune_input_batch(n_out)?; - // Prune `self.window_agg_states`: - self.prune_out_columns(n_out)?; Ok(()) } @@ -548,9 +548,9 @@ impl SortedPartitionByBoundedWindowStream { for (partition_row, WindowState { state: value, .. }) in window_agg_state { let n_prune = min(value.window_frame_range.start, value.last_calculated_index); - if let Some(state) = n_prune_each_partition.get_mut(partition_row) { - if n_prune < *state { - *state = n_prune; + if let Some(current) = n_prune_each_partition.get_mut(partition_row) { + if n_prune < *current { + *current = n_prune; } } else { n_prune_each_partition.insert(partition_row.clone(), n_prune); @@ -571,15 +571,7 @@ impl SortedPartitionByBoundedWindowStream { // Update state indices since we have pruned some rows from the beginning: for window_agg_state in self.window_agg_states.iter_mut() { - let window_state = - window_agg_state.get_mut(partition_row).ok_or_else(err)?; - let mut state = &mut window_state.state; - state.window_frame_range = Range { - start: state.window_frame_range.start - n_prune, - end: state.window_frame_range.end - n_prune, - }; - state.last_calculated_index -= n_prune; - state.offset_pruned_rows += n_prune; + window_agg_state[partition_row].state.prune_state(*n_prune); } } Ok(()) diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 82bd0d8443f9..a09f458c8cd9 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -39,12 +39,15 @@ use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable}; async fn describe() -> Result<()> { let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_tiny_pages", + &format!("{testdata}/alltypes_tiny_pages.parquet"), + ParquetReadOptions::default(), + ) + .await?; let describe_record_batch = ctx - .read_parquet( - &format!("{testdata}/alltypes_tiny_pages.parquet"), - ParquetReadOptions::default(), - ) + .table("alltypes_tiny_pages") .await? .describe() .await? @@ -67,6 +70,30 @@ async fn describe() -> Result<()> { ]; assert_batches_eq!(expected, &describe_record_batch); + //add test case for only boolean boolean/binary column + let result = ctx + .sql("select 'a' as a,true as b") + .await? + .describe() + .await? + .collect() + .await?; + #[rustfmt::skip] + let expected = vec![ + "+------------+------+------+", + "| describe | a | b |", + "+------------+------+------+", + "| count | 1 | 1 |", + "| null_count | 1 | 1 |", + "| mean | null | null |", + "| std | null | null |", + "| min | a | null |", + "| max | a | null |", + "| median | null | null |", + "+------------+------+------+", + ]; + assert_batches_eq!(expected, &result); + Ok(()) } diff --git a/datafusion/core/tests/sql/order.rs b/datafusion/core/tests/sql/order.rs index 2388eebef931..e29904c21466 100644 --- a/datafusion/core/tests/sql/order.rs +++ b/datafusion/core/tests/sql/order.rs @@ -39,3 +39,86 @@ async fn sort_with_lots_of_repetition_values() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn sort_with_duplicate_sort_exprs() -> Result<()> { + let ctx = SessionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ])); + + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![2, 4, 9, 3, 4])), + Arc::new(StringArray::from_slice(["a", "b", "c", "d", "e"])), + ], + )?; + ctx.register_batch("t1", t1_data)?; + + let sql = "select * from t1 order by id desc, id, name, id asc"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + let expected = vec![ + "Sort: t1.id DESC NULLS FIRST, t1.name ASC NULLS LAST [id:Int32;N, name:Utf8;N]", + " TableScan: t1 projection=[id, name] [id:Int32;N, name:Utf8;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+----+------+", + "| id | name |", + "+----+------+", + "| 9 | c |", + "| 4 | b |", + "| 4 | e |", + "| 3 | d |", + "| 2 | a |", + "+----+------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_eq!(expected, &results); + + let sql = "select * from t1 order by id asc, id, name, id desc;"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + let expected = vec![ + "Sort: t1.id ASC NULLS LAST, t1.name ASC NULLS LAST [id:Int32;N, name:Utf8;N]", + " TableScan: t1 projection=[id, name] [id:Int32;N, name:Utf8;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+----+------+", + "| id | name |", + "+----+------+", + "| 2 | a |", + "| 3 | d |", + "| 4 | b |", + "| 4 | e |", + "| 9 | c |", + "+----+------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 8f95eba572ba..535e1d89170f 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -61,7 +61,20 @@ async fn window_frame_creation_type_checking() -> Result<()> { ).await } -fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { +fn split_record_batch(batch: RecordBatch, n_split: usize) -> Vec { + let n_chunk = batch.num_rows() / n_split; + let mut res = vec![]; + for i in 0..n_split - 1 { + let chunk = batch.slice(i * n_chunk, n_chunk); + res.push(chunk); + } + let start = (n_split - 1) * n_chunk; + let len = batch.num_rows() - start; + res.push(batch.slice(start, len)); + res +} + +fn get_test_data(n_split: usize) -> Result> { let ts_field = Field::new("ts", DataType::Int32, false); let inc_field = Field::new("inc_col", DataType::Int32, false); let desc_field = Field::new("desc_col", DataType::Int32, false); @@ -100,19 +113,19 @@ fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { ])), ], )?; - let n_chunk = batch.num_rows() / n_file; - for i in 0..n_file { + Ok(split_record_batch(batch, n_split)) +} + +fn write_test_data_to_parquet(tmpdir: &TempDir, n_split: usize) -> Result<()> { + let batches = get_test_data(n_split)?; + for (i, batch) in batches.into_iter().enumerate() { let target_file = tmpdir.path().join(format!("{i}.parquet")); let file = File::create(target_file).unwrap(); // Default writer properties let props = WriterProperties::builder().build(); - let chunks_start = i * n_chunk; - let cur_batch = batch.slice(chunks_start, n_chunk); - // let chunks_end = chunks_start + n_chunk; - let mut writer = - ArrowWriter::try_new(file, cur_batch.schema(), Some(props)).unwrap(); + let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props)).unwrap(); - writer.write(&cur_batch).expect("Writing batch"); + writer.write(&batch).expect("Writing batch"); // writer must be closed to write footer writer.close().unwrap(); @@ -120,12 +133,11 @@ fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { Ok(()) } -async fn get_test_context(tmpdir: &TempDir) -> Result { +async fn get_test_context(tmpdir: &TempDir, n_batch: usize) -> Result { let session_config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::with_config(session_config); let parquet_read_options = ParquetReadOptions::default(); - // The sort order is specified (not actually correct in this case) let file_sort_order = [col("ts")] .into_iter() .map(|e| { @@ -139,7 +151,7 @@ async fn get_test_context(tmpdir: &TempDir) -> Result { .to_listing_options(&ctx.copied_config()) .with_file_sort_order(Some(file_sort_order)); - write_test_data_to_parquet(tmpdir, 1)?; + write_test_data_to_parquet(tmpdir, n_batch)?; let provided_schema = None; let sql_definition = None; ctx.register_listing_table( @@ -160,7 +172,7 @@ mod tests { #[tokio::test] async fn test_source_sorted_aggregate() -> Result<()> { let tmpdir = TempDir::new().unwrap(); - let ctx = get_test_context(&tmpdir).await?; + let ctx = get_test_context(&tmpdir, 1).await?; let sql = "SELECT SUM(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as sum1, @@ -235,7 +247,7 @@ mod tests { #[tokio::test] async fn test_source_sorted_builtin() -> Result<()> { let tmpdir = TempDir::new().unwrap(); - let ctx = get_test_context(&tmpdir).await?; + let ctx = get_test_context(&tmpdir, 1).await?; let sql = "SELECT FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv1, @@ -309,7 +321,7 @@ mod tests { #[tokio::test] async fn test_source_sorted_unbounded_preceding() -> Result<()> { let tmpdir = TempDir::new().unwrap(); - let ctx = get_test_context(&tmpdir).await?; + let ctx = get_test_context(&tmpdir, 1).await?; let sql = "SELECT SUM(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as sum1, @@ -368,7 +380,7 @@ mod tests { #[tokio::test] async fn test_source_sorted_unbounded_preceding_builtin() -> Result<()> { let tmpdir = TempDir::new().unwrap(); - let ctx = get_test_context(&tmpdir).await?; + let ctx = get_test_context(&tmpdir, 1).await?; let sql = "SELECT FIRST_VALUE(inc_col) OVER(ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as first_value1, diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt b/datafusion/core/tests/sqllogictests/test_files/window.slt index 148c7a4fd0b9..85a5bc18d4e2 100644 --- a/datafusion/core/tests/sqllogictests/test_files/window.slt +++ b/datafusion/core/tests/sqllogictests/test_files/window.slt @@ -2038,3 +2038,39 @@ SELECT statement ok set datafusion.execution.target_partitions = 2; + +# test_window_agg_with_bounded_group +query TT +EXPLAIN SELECT SUM(c12) OVER(ORDER BY c1, c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum1, + SUM(c12) OVER(ORDER BY c1 GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING) as sum2 + FROM aggregate_test_100 ORDER BY c9 LIMIT 5 +---- +logical_plan +Projection: sum1, sum2 + Limit: skip=0, fetch=5 + Sort: aggregate_test_100.c9 ASC NULLS LAST, fetch=5 + Projection: SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS sum1, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING AS sum2, aggregate_test_100.c9 + WindowAggr: windowExpr=[[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING]] + WindowAggr: windowExpr=[[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] + TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] +physical_plan +ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] + GlobalLimitExec: skip=0, fetch=5 + SortExec: fetch=5, expr=[c9@2 ASC NULLS LAST] + ProjectionExec: expr=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@13 as sum1, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@14 as sum2, c9@8 as c9] + BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12): Ok(Field { name: "SUM(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)) }] + BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12): Ok(Field { name: "SUM(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }] + SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] + CsvExec: files={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, has_header=true, limit=None, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] + +query RR +SELECT SUM(c12) OVER(ORDER BY c1, c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum1, + SUM(c12) OVER(ORDER BY c1 GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING) as sum2 + FROM aggregate_test_100 ORDER BY c9 LIMIT 5 +---- +4.561269874379 18.036183428008 +6.808931568966 10.238448667883 +2.994840293343 NULL +9.674390599321 NULL +7.728066219895 NULL + diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs index e03758600938..a71aab280ea1 100644 --- a/datafusion/core/tests/window_fuzz.rs +++ b/datafusion/core/tests/window_fuzz.rs @@ -227,9 +227,7 @@ fn get_random_window_frame(rng: &mut StdRng) -> WindowFrame { } else if rand_num < 2 { WindowFrameUnits::Rows } else { - // For now we do not support GROUPS in BoundedWindowAggExec implementation - // TODO: once GROUPS handling is available, use WindowFrameUnits::GROUPS in randomized tests also. - WindowFrameUnits::Range + WindowFrameUnits::Groups }; match units { // In range queries window frame boundaries should match column type @@ -256,8 +254,8 @@ fn get_random_window_frame(rng: &mut StdRng) -> WindowFrame { } window_frame } - // In window queries, window frame boundary should be Uint64 - WindowFrameUnits::Rows => { + // Window frame boundary should be UInt64 for both ROWS and GROUPS frames: + WindowFrameUnits::Rows | WindowFrameUnits::Groups => { let start_bound = if start_bound.is_preceding { WindowFrameBound::Preceding(ScalarValue::UInt64(Some( start_bound.val as u64, @@ -286,10 +284,10 @@ fn get_random_window_frame(rng: &mut StdRng) -> WindowFrame { window_frame.start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); } + // We never use UNBOUNDED FOLLOWING in test. Because that case is not prunable and + // should work only with WindowAggExec window_frame } - // Once GROUPS support is added construct window frame for this case also - _ => todo!(), } } @@ -401,7 +399,7 @@ async fn run_window_test( assert_eq!( (i, usual_line), (i, running_line), - "Inconsistent result for window_fn: {window_fn:?}, args:{args:?}" + "Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}" ); } } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 5a882108e0fd..15f3d8e1d851 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -18,8 +18,9 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; +use datafusion_expr::expr::Sort as ExprSort; use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::Sort; +use datafusion_expr::{Expr, Sort}; use hashbrown::HashSet; /// Optimization rule that eliminate duplicated expr. @@ -41,15 +42,28 @@ impl OptimizerRule for EliminateDuplicatedExpr { ) -> Result> { match plan { LogicalPlan::Sort(sort) => { + let normalized_sort_keys = sort + .expr + .iter() + .map(|e| match e { + Expr::Sort(ExprSort { expr, .. }) => { + Expr::Sort(ExprSort::new(expr.clone(), true, false)) + } + _ => e.clone(), + }) + .collect::>(); + // dedup sort.expr and keep order let mut dedup_expr = Vec::new(); let mut dedup_set = HashSet::new(); - for expr in &sort.expr { - if !dedup_set.contains(expr) { - dedup_expr.push(expr); - dedup_set.insert(expr.clone()); - } - } + sort.expr.iter().zip(normalized_sort_keys.iter()).for_each( + |(expr, normalized_expr)| { + if !dedup_set.contains(normalized_expr) { + dedup_expr.push(expr); + dedup_set.insert(normalized_expr); + } + }, + ); if dedup_expr.len() == sort.expr.len() { Ok(None) } else { @@ -100,4 +114,23 @@ mod tests { \n TableScan: test"; assert_optimized_plan_eq(&plan, expected) } + + #[test] + fn eliminate_sort_exprs_with_options() -> Result<()> { + let table_scan = test_table_scan().unwrap(); + let sort_exprs = vec![ + col("a").sort(true, true), + col("b").sort(true, false), + col("a").sort(false, false), + col("b").sort(false, true), + ]; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(sort_exprs)? + .limit(5, Some(10))? + .build()?; + let expected = "Limit: skip=5, fetch=10\ + \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ + \n TableScan: test"; + assert_optimized_plan_eq(&plan, expected) + } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index e9b43cd070cd..95fd86148ac2 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -27,7 +27,7 @@ use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits}; +use datafusion_expr::{Accumulator, WindowFrame}; use crate::window::window_expr::{reverse_order_bys, AggregateWindowExpr}; use crate::window::{ @@ -115,11 +115,8 @@ impl WindowExpr for PlainAggregateWindowExpr { })?; let mut state = &mut window_state.state; if self.window_frame.start_bound.is_unbounded() { - state.window_frame_range.start = if state.window_frame_range.end >= 1 { - state.window_frame_range.end - 1 - } else { - 0 - }; + state.window_frame_range.start = + state.window_frame_range.end.saturating_sub(1); } } Ok(()) @@ -159,10 +156,8 @@ impl WindowExpr for PlainAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - // NOTE: Currently, groups queries do not support the bounded memory variant. self.aggregate.supports_bounded_execution() && !self.window_frame.end_bound.is_unbounded() - && !matches!(self.window_frame.units, WindowFrameUnits::Groups) } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 70ddb2c7671a..329eac333460 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -32,11 +32,11 @@ use crate::window::{ }; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{WindowFrame, WindowFrameUnits}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::WindowFrame; /// A window expr that takes the form of a built in window function #[derive(Debug)] @@ -104,17 +104,20 @@ impl WindowExpr for BuiltInWindowExpr { let mut row_wise_results = vec![]; let (values, order_bys) = self.get_values_orderbys(batch)?; - let mut window_frame_ctx = WindowFrameContext::new( - &self.window_frame, - sort_options, - Range { start: 0, end: 0 }, - ); + let mut window_frame_ctx = + WindowFrameContext::new(self.window_frame.clone(), sort_options); + let mut last_range = Range { start: 0, end: 0 }; // We iterate on each row to calculate window frame range and and window function result for idx in 0..num_rows { - let range = - window_frame_ctx.calculate_range(&order_bys, num_rows, idx)?; + let range = window_frame_ctx.calculate_range( + &order_bys, + &last_range, + num_rows, + idx, + )?; let value = evaluator.evaluate_inside_range(&values, &range)?; row_wise_results.push(value); + last_range = range; } ScalarValue::iter_to_array(row_wise_results.into_iter()) } else if evaluator.include_rank() { @@ -139,26 +142,23 @@ impl WindowExpr for BuiltInWindowExpr { let out_type = field.data_type(); let sort_options = self.order_by.iter().map(|o| o.options).collect::>(); for (partition_row, partition_batch_state) in partition_batches.iter() { - if !window_agg_state.contains_key(partition_row) { - let evaluator = self.expr.create_evaluator()?; - window_agg_state.insert( - partition_row.clone(), - WindowState { - state: WindowAggState::new(out_type)?, - window_fn: WindowFn::Builtin(evaluator), - }, - ); - }; let window_state = - window_agg_state.get_mut(partition_row).ok_or_else(|| { - DataFusionError::Execution("Cannot find state".to_string()) - })?; + if let Some(window_state) = window_agg_state.get_mut(partition_row) { + window_state + } else { + let evaluator = self.expr.create_evaluator()?; + window_agg_state + .entry(partition_row.clone()) + .or_insert(WindowState { + state: WindowAggState::new(out_type)?, + window_fn: WindowFn::Builtin(evaluator), + }) + }; let evaluator = match &mut window_state.window_fn { WindowFn::Builtin(evaluator) => evaluator, _ => unreachable!(), }; let mut state = &mut window_state.state; - state.is_end = partition_batch_state.is_end; let (values, order_bys) = self.get_values_orderbys(&partition_batch_state.record_batch)?; @@ -166,13 +166,6 @@ impl WindowExpr for BuiltInWindowExpr { // We iterate on each row to perform a running calculation. let record_batch = &partition_batch_state.record_batch; let num_rows = record_batch.num_rows(); - let last_range = state.window_frame_range.clone(); - let mut window_frame_ctx = WindowFrameContext::new( - &self.window_frame, - sort_options.clone(), - // Start search from the last range - last_range, - ); let sort_partition_points = if evaluator.include_rank() { let columns = self.sort_columns(record_batch)?; self.evaluate_partition_points(num_rows, &columns)? @@ -180,33 +173,43 @@ impl WindowExpr for BuiltInWindowExpr { vec![] }; let mut row_wise_results: Vec = vec![]; - let mut last_range = state.window_frame_range.clone(); for idx in state.last_calculated_index..num_rows { - state.window_frame_range = if self.expr.uses_window_frame() { - window_frame_ctx.calculate_range(&order_bys, num_rows, idx) + let frame_range = if self.expr.uses_window_frame() { + state + .window_frame_ctx + .get_or_insert_with(|| { + WindowFrameContext::new( + self.window_frame.clone(), + sort_options.clone(), + ) + }) + .calculate_range( + &order_bys, + // Start search from the last range + &state.window_frame_range, + num_rows, + idx, + ) } else { - evaluator.get_range(state, num_rows) + evaluator.get_range(idx, num_rows) }?; - evaluator.update_state(state, &order_bys, &sort_partition_points)?; - let frame_range = &state.window_frame_range; // Exit if the range extends all the way: - if frame_range.end == num_rows && !state.is_end { + if frame_range.end == num_rows && !partition_batch_state.is_end { break; } + // Update last range + state.window_frame_range = frame_range; + evaluator.update_state(state, idx, &order_bys, &sort_partition_points)?; row_wise_results.push(evaluator.evaluate_stateful(&values)?); - last_range.clone_from(frame_range); - state.last_calculated_index += 1; } - state.window_frame_range = last_range; let out_col = if row_wise_results.is_empty() { new_empty_array(out_type) } else { ScalarValue::iter_to_array(row_wise_results.into_iter())? }; - state.out_col = concat(&[&state.out_col, &out_col])?; - state.n_row_result_missing = num_rows - state.last_calculated_index; + state.update(&out_col, partition_batch_state)?; if self.window_frame.start_bound.is_unbounded() { let mut evaluator_state = evaluator.state()?; if let BuiltinWindowState::NthValue(nth_value_state) = @@ -236,11 +239,9 @@ impl WindowExpr for BuiltInWindowExpr { } fn uses_bounded_memory(&self) -> bool { - // NOTE: Currently, groups queries do not support the bounded memory variant. self.expr.supports_bounded_execution() && (!self.expr.uses_window_frame() - || !(self.window_frame.end_bound.is_unbounded() - || matches!(self.window_frame.units, WindowFrameUnits::Groups))) + || !self.window_frame.end_bound.is_unbounded()) } } @@ -271,9 +272,7 @@ fn memoize_nth_value( let result = ScalarValue::try_from_array(out, size - 1)?; nth_value_state.finalized_result = Some(result); } - if state.window_frame_range.end > 0 { - state.window_frame_range.start = state.window_frame_range.end - 1; - } + state.window_frame_range.start = state.window_frame_range.end.saturating_sub(1); } Ok(()) } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 794c7bd39249..e2dfd52daf71 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -189,33 +189,25 @@ impl PartitionEvaluator for WindowShiftEvaluator { fn update_state( &mut self, - state: &WindowAggState, + _state: &WindowAggState, + idx: usize, _range_columns: &[ArrayRef], _sort_partition_points: &[Range], ) -> Result<()> { - self.state.idx = state.last_calculated_index; + self.state.idx = idx; Ok(()) } - fn get_range(&self, state: &WindowAggState, n_rows: usize) -> Result> { + fn get_range(&self, idx: usize, n_rows: usize) -> Result> { if self.shift_offset > 0 { let offset = self.shift_offset as usize; - let start = if state.last_calculated_index > offset { - state.last_calculated_index - offset - } else { - 0 - }; - Ok(Range { - start, - end: state.last_calculated_index + 1, - }) + let start = idx.saturating_sub(offset); + let end = idx + 1; + Ok(Range { start, end }) } else { let offset = (-self.shift_offset) as usize; - let end = min(state.last_calculated_index + offset, n_rows); - Ok(Range { - start: state.last_calculated_index, - end, - }) + let end = min(idx + offset, n_rows); + Ok(Range { start: idx, end }) } } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index ef6e3c6d016d..4da91e75ef20 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -160,6 +160,7 @@ impl PartitionEvaluator for NthValueEvaluator { fn update_state( &mut self, state: &WindowAggState, + _idx: usize, _range_columns: &[ArrayRef], _sort_partition_points: &[Range], ) -> Result<()> { diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 7887d1412b98..758f7c3b1b23 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -38,9 +38,15 @@ pub trait PartitionEvaluator: Debug + Send { Ok(BuiltinWindowState::Default) } + /// Updates the internal state for Built-in window function + // state is useful to update internal state for Built-in window function. + // idx is the index of last row for which result is calculated. + // range_columns is the result of order by column values. It is used to calculate rank boundaries + // sort_partition_points is the boundaries of each rank in the range_column. It is used to update rank. fn update_state( &mut self, _state: &WindowAggState, + _idx: usize, _range_columns: &[ArrayRef], _sort_partition_points: &[Range], ) -> Result<()> { @@ -54,20 +60,23 @@ pub trait PartitionEvaluator: Debug + Send { )) } - fn get_range(&self, _state: &WindowAggState, _n_rows: usize) -> Result> { + /// Gets the range where Built-in window function result is calculated. + // idx is the index of last row for which result is calculated. + // n_rows is the number of rows of the input record batch (Used during bound check) + fn get_range(&self, _idx: usize, _n_rows: usize) -> Result> { Err(DataFusionError::NotImplemented( "get_range is not implemented for this window function".to_string(), )) } - /// evaluate the partition evaluator against the partition + /// Evaluate the partition evaluator against the partition fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { Err(DataFusionError::NotImplemented( "evaluate is not implemented by default".into(), )) } - /// evaluate window function result inside given range + /// Evaluate window function result inside given range fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { Err(DataFusionError::NotImplemented( "evaluate_stateful is not implemented by default".into(), diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index ead9d44535ba..5f016739cfa0 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -25,6 +25,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::{Float64Array, UInt64Array}; use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::get_row_at_idx; use datafusion_common::{DataFusionError, Result, ScalarValue}; use std::any::Any; use std::iter; @@ -118,11 +119,10 @@ pub(crate) struct RankEvaluator { } impl PartitionEvaluator for RankEvaluator { - fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { - Ok(Range { - start: state.last_calculated_index, - end: state.last_calculated_index + 1, - }) + fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { + let start = idx; + let end = idx + 1; + Ok(Range { start, end }) } fn state(&self) -> Result { @@ -132,22 +132,21 @@ impl PartitionEvaluator for RankEvaluator { fn update_state( &mut self, state: &WindowAggState, + idx: usize, range_columns: &[ArrayRef], sort_partition_points: &[Range], ) -> Result<()> { - // find range inside `sort_partition_points` containing `state.last_calculated_index` + // find range inside `sort_partition_points` containing `idx` let chunk_idx = sort_partition_points .iter() - .position(|elem| { - elem.start <= state.last_calculated_index - && state.last_calculated_index < elem.end - }) - .ok_or_else(|| DataFusionError::Execution("Expects sort_partition_points to contain state.last_calculated_index".to_string()))?; + .position(|elem| elem.start <= idx && idx < elem.end) + .ok_or_else(|| { + DataFusionError::Execution( + "Expects sort_partition_points to contain idx".to_string(), + ) + })?; let chunk = &sort_partition_points[chunk_idx]; - let last_rank_data = range_columns - .iter() - .map(|c| ScalarValue::try_from_array(c, chunk.end - 1)) - .collect::>>()?; + let last_rank_data = get_row_at_idx(range_columns, chunk.end - 1)?; let empty = self.state.last_rank_data.is_empty(); if empty || self.state.last_rank_data != last_rank_data { self.state.last_rank_data = last_rank_data; diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index c858a5724a20..8961b277e7de 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -19,7 +19,7 @@ use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, NumRowsState}; -use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; +use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; @@ -81,11 +81,10 @@ impl PartitionEvaluator for NumRowsEvaluator { Ok(BuiltinWindowState::NumRows(self.state.clone())) } - fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { - Ok(Range { - start: state.last_calculated_index, - end: state.last_calculated_index + 1, - }) + fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { + let start = idx; + let end = idx + 1; + Ok(Range { start, end }) } /// evaluate window function result inside given range diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 0723f05c598f..7fa33d71ca44 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -26,7 +26,7 @@ use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits}; +use datafusion_expr::{Accumulator, WindowFrame}; use crate::window::window_expr::{reverse_order_bys, AggregateWindowExpr}; use crate::window::{ @@ -138,10 +138,8 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - // NOTE: Currently, groups queries do not support the bounded memory variant. self.aggregate.supports_bounded_execution() && !self.window_frame.end_bound.is_unbounded() - && !matches!(self.window_frame.units, WindowFrameUnits::Groups) } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 96e22976b3d8..7568fa3b2b58 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -18,7 +18,7 @@ use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_frame_state::WindowFrameContext; use crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::array::{new_empty_array, ArrayRef}; +use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::partition::lexicographical_partition_ranges; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::{concat, SortOptions}; @@ -164,8 +164,18 @@ pub trait AggregateWindowExpr: WindowExpr { fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result { let mut accumulator = self.get_accumulator()?; let mut last_range = Range { start: 0, end: 0 }; - let mut idx = 0; - self.get_result_column(&mut accumulator, batch, &mut last_range, &mut idx, false) + let sort_options: Vec = + self.order_by().iter().map(|o| o.options).collect(); + let mut window_frame_ctx = + WindowFrameContext::new(self.get_window_frame().clone(), sort_options); + self.get_result_column( + &mut accumulator, + batch, + &mut last_range, + &mut window_frame_ctx, + 0, + false, + ) } /// Statefully evaluates the window function against the batch. Maintains @@ -196,20 +206,25 @@ pub trait AggregateWindowExpr: WindowExpr { WindowFn::Aggregate(accumulator) => accumulator, _ => unreachable!(), }; - let mut state = &mut window_state.state; - + let state = &mut window_state.state; let record_batch = &partition_batch_state.record_batch; + + // If there is no window state context, initialize it. + let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { + let sort_options: Vec = + self.order_by().iter().map(|o| o.options).collect(); + WindowFrameContext::new(self.get_window_frame().clone(), sort_options) + }); let out_col = self.get_result_column( accumulator, record_batch, + // Start search from the last range &mut state.window_frame_range, - &mut state.last_calculated_index, + window_frame_ctx, + state.last_calculated_index, !partition_batch_state.is_end, )?; - state.is_end = partition_batch_state.is_end; - state.out_col = concat(&[&state.out_col, &out_col])?; - state.n_row_result_missing = - record_batch.num_rows() - state.last_calculated_index; + state.update(&out_col, partition_batch_state)?; } Ok(()) } @@ -221,23 +236,18 @@ pub trait AggregateWindowExpr: WindowExpr { accumulator: &mut Box, record_batch: &RecordBatch, last_range: &mut Range, - idx: &mut usize, + window_frame_ctx: &mut WindowFrameContext, + mut idx: usize, not_end: bool, ) -> Result { let (values, order_bys) = self.get_values_orderbys(record_batch)?; // We iterate on each row to perform a running calculation. let length = values[0].len(); - let sort_options: Vec = - self.order_by().iter().map(|o| o.options).collect(); - let mut window_frame_ctx = WindowFrameContext::new( - self.get_window_frame(), - sort_options, - // Start search from the last range - last_range.clone(), - ); let mut row_wise_results: Vec = vec![]; - while *idx < length { - let cur_range = window_frame_ctx.calculate_range(&order_bys, length, *idx)?; + while idx < length { + // Start search from the last_range. This squeezes searched range. + let cur_range = + window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?; // Exit if the range extends all the way: if cur_range.end == length && not_end { break; @@ -248,9 +258,10 @@ pub trait AggregateWindowExpr: WindowExpr { &values, accumulator, )?; - last_range.clone_from(&cur_range); + // Update last range + *last_range = cur_range; row_wise_results.push(value); - *idx += 1; + idx += 1; } if row_wise_results.is_empty() { let field = self.field()?; @@ -340,6 +351,7 @@ pub enum BuiltinWindowState { pub struct WindowAggState { /// The range that we calculate the window function pub window_frame_range: Range, + pub window_frame_ctx: Option, /// The index of the last row that its result is calculated inside the partition record batch buffer. pub last_calculated_index: usize, /// The offset of the deleted row number @@ -353,6 +365,54 @@ pub struct WindowAggState { pub is_end: bool, } +impl WindowAggState { + pub fn prune_state(&mut self, n_prune: usize) { + self.window_frame_range = Range { + start: self.window_frame_range.start - n_prune, + end: self.window_frame_range.end - n_prune, + }; + self.last_calculated_index -= n_prune; + self.offset_pruned_rows += n_prune; + + match self.window_frame_ctx.as_mut() { + // Rows have no state do nothing + Some(WindowFrameContext::Rows(_)) => {} + Some(WindowFrameContext::Range { .. }) => {} + Some(WindowFrameContext::Groups { state, .. }) => { + let mut n_group_to_del = 0; + for (_, end_idx) in &state.group_end_indices { + if n_prune < *end_idx { + break; + } + n_group_to_del += 1; + } + state.group_end_indices.drain(0..n_group_to_del); + state + .group_end_indices + .iter_mut() + .for_each(|(_, start_idx)| *start_idx -= n_prune); + state.current_group_idx -= n_group_to_del; + } + None => {} + }; + } +} + +impl WindowAggState { + pub fn update( + &mut self, + out_col: &ArrayRef, + partition_batch_state: &PartitionBatchState, + ) -> Result<()> { + self.last_calculated_index += out_col.len(); + self.out_col = concat(&[&self.out_col, &out_col])?; + self.n_row_result_missing = + partition_batch_state.record_batch.num_rows() - self.last_calculated_index; + self.is_end = partition_batch_state.is_end; + Ok(()) + } +} + /// State for each unique partition determined according to PARTITION BY column(s) #[derive(Debug)] pub struct PartitionBatchState { @@ -383,6 +443,7 @@ impl WindowAggState { let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); Ok(Self { window_frame_range: Range { start: 0, end: 0 }, + window_frame_ctx: None, last_calculated_index: 0, offset_pruned_rows: 0, out_col: empty_out_col, diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 64abacde49c1..01a4f9ad71a8 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -31,37 +31,33 @@ use std::sync::Arc; /// This object stores the window frame state for use in incremental calculations. #[derive(Debug)] -pub enum WindowFrameContext<'a> { +pub enum WindowFrameContext { /// ROWS frames are inherently stateless. - Rows(&'a Arc), + Rows(Arc), /// RANGE frames are stateful, they store indices specifying where the /// previous search left off. This amortizes the overall cost to O(n) /// where n denotes the row count. Range { - window_frame: &'a Arc, + window_frame: Arc, state: WindowFrameStateRange, }, /// GROUPS frames are stateful, they store group boundaries and indices /// specifying where the previous search left off. This amortizes the /// overall cost to O(n) where n denotes the row count. Groups { - window_frame: &'a Arc, + window_frame: Arc, state: WindowFrameStateGroups, }, } -impl<'a> WindowFrameContext<'a> { +impl WindowFrameContext { /// Create a new state object for the given window frame. - pub fn new( - window_frame: &'a Arc, - sort_options: Vec, - last_range: Range, - ) -> Self { + pub fn new(window_frame: Arc, sort_options: Vec) -> Self { match window_frame.units { WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame), WindowFrameUnits::Range => WindowFrameContext::Range { window_frame, - state: WindowFrameStateRange::new(sort_options, last_range), + state: WindowFrameStateRange::new(sort_options), }, WindowFrameUnits::Groups => WindowFrameContext::Groups { window_frame, @@ -74,10 +70,11 @@ impl<'a> WindowFrameContext<'a> { pub fn calculate_range( &mut self, range_columns: &[ArrayRef], + last_range: &Range, length: usize, idx: usize, ) -> Result> { - match *self { + match self { WindowFrameContext::Rows(window_frame) => { Self::calculate_range_rows(window_frame, length, idx) } @@ -87,7 +84,13 @@ impl<'a> WindowFrameContext<'a> { WindowFrameContext::Range { window_frame, ref mut state, - } => state.calculate_range(window_frame, range_columns, length, idx), + } => state.calculate_range( + window_frame, + last_range, + range_columns, + length, + idx, + ), // Sort options is not used in GROUPS mode calculations as the // inequality of two rows indicates a group change, and ordering // or position of NULLs do not impact inequality. @@ -159,33 +162,29 @@ impl<'a> WindowFrameContext<'a> { } /// This structure encapsulates all the state information we require as we scan -/// ranges of data while processing RANGE frames. Attribute `last_range` stores -/// the resulting indices from the previous search. Since the indices only -/// advance forward, we start from `last_range` subsequently. Thus, the overall -/// time complexity of linear search amortizes to O(n) where n denotes the total -/// row count. +/// ranges of data while processing RANGE frames. /// Attribute `sort_options` stores the column ordering specified by the ORDER /// BY clause. This information is used to calculate the range. #[derive(Debug, Default)] pub struct WindowFrameStateRange { - last_range: Range, sort_options: Vec, } impl WindowFrameStateRange { /// Create a new object to store the search state. - fn new(sort_options: Vec, last_range: Range) -> Self { - Self { - // Stores the search range we calculate for future use. - last_range, - sort_options, - } + fn new(sort_options: Vec) -> Self { + Self { sort_options } } /// This function calculates beginning/ending indices for the frame of the current row. + // Argument `last_range` stores the resulting indices from the previous search. Since the indices only + // advance forward, we start from `last_range` subsequently. Thus, the overall + // time complexity of linear search amortizes to O(n) where n denotes the total + // row count. fn calculate_range( &mut self, window_frame: &Arc, + last_range: &Range, range_columns: &[ArrayRef], length: usize, idx: usize, @@ -198,6 +197,7 @@ impl WindowFrameStateRange { } else { self.calculate_index_of_row::( range_columns, + last_range, idx, Some(n), length, @@ -206,6 +206,7 @@ impl WindowFrameStateRange { } WindowFrameBound::CurrentRow => self.calculate_index_of_row::( range_columns, + last_range, idx, None, length, @@ -213,6 +214,7 @@ impl WindowFrameStateRange { WindowFrameBound::Following(ref n) => self .calculate_index_of_row::( range_columns, + last_range, idx, Some(n), length, @@ -222,12 +224,14 @@ impl WindowFrameStateRange { WindowFrameBound::Preceding(ref n) => self .calculate_index_of_row::( range_columns, + last_range, idx, Some(n), length, )?, WindowFrameBound::CurrentRow => self.calculate_index_of_row::( range_columns, + last_range, idx, None, length, @@ -239,6 +243,7 @@ impl WindowFrameStateRange { } else { self.calculate_index_of_row::( range_columns, + last_range, idx, Some(n), length, @@ -246,9 +251,6 @@ impl WindowFrameStateRange { } } }; - // Store the resulting range so we can start from here subsequently: - self.last_range.start = start; - self.last_range.end = end; Ok(Range { start, end }) } @@ -258,6 +260,7 @@ impl WindowFrameStateRange { fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], + last_range: &Range, idx: usize, delta: Option<&ScalarValue>, length: usize, @@ -298,9 +301,9 @@ impl WindowFrameStateRange { current_row_values }; let search_start = if SIDE { - self.last_range.start + last_range.start } else { - self.last_range.end + last_range.end }; let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare_rows(current, target, &self.sort_options)?; @@ -332,16 +335,16 @@ impl WindowFrameStateRange { // last row of the group that comes "offset" groups after the current group. // - UNBOUNDED FOLLOWING: End with the last row of the partition. Possible only in frame_end. -// This structure encapsulates all the state information we require as we -// scan groups of data while processing window frames. +/// This structure encapsulates all the state information we require as we +/// scan groups of data while processing window frames. #[derive(Debug, Default)] pub struct WindowFrameStateGroups { /// A tuple containing group values and the row index where the group ends. /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to /// [([1, 1], 2), ([2, 1], 4), ...]. - group_start_indices: VecDeque<(Vec, usize)>, + pub group_end_indices: VecDeque<(Vec, usize)>, /// The group index to which the row index belongs. - current_group_idx: usize, + pub current_group_idx: usize, } impl WindowFrameStateGroups { @@ -435,14 +438,28 @@ impl WindowFrameStateGroups { 0 }; let mut group_start = 0; - let last_group = self.group_start_indices.back(); - if let Some((_, group_end)) = last_group { + let last_group = self.group_end_indices.back_mut(); + if let Some((group_row, group_end)) = last_group { + if *group_end < length { + let new_group_row = get_row_at_idx(range_columns, *group_end)?; + // If last/current group keys are the same, we extend the last group: + if new_group_row.eq(group_row) { + // Update the end boundary of the group (search right boundary): + *group_end = search_in_slice( + range_columns, + group_row, + check_equality, + *group_end, + length, + )?; + } + } // Start searching from the last group boundary: group_start = *group_end; } // Advance groups until `idx` is inside a group: - while idx > group_start { + while idx >= group_start { let group_row = get_row_at_idx(range_columns, group_start)?; // Find end boundary of the group (search right boundary): let group_end = search_in_slice( @@ -452,13 +469,13 @@ impl WindowFrameStateGroups { group_start, length, )?; - self.group_start_indices.push_back((group_row, group_end)); + self.group_end_indices.push_back((group_row, group_end)); group_start = group_end; } // Update the group index `idx` belongs to: - while self.current_group_idx < self.group_start_indices.len() - && idx >= self.group_start_indices[self.current_group_idx].1 + while self.current_group_idx < self.group_end_indices.len() + && idx >= self.group_end_indices[self.current_group_idx].1 { self.current_group_idx += 1; } @@ -475,7 +492,7 @@ impl WindowFrameStateGroups { }; // Extend `group_start_indices` until it includes at least `group_idx`: - while self.group_start_indices.len() <= group_idx && group_start < length { + while self.group_end_indices.len() <= group_idx && group_start < length { let group_row = get_row_at_idx(range_columns, group_start)?; // Find end boundary of the group (search right boundary): let group_end = search_in_slice( @@ -485,7 +502,7 @@ impl WindowFrameStateGroups { group_start, length, )?; - self.group_start_indices.push_back((group_row, group_end)); + self.group_end_indices.push_back((group_row, group_end)); group_start = group_end; } @@ -493,10 +510,10 @@ impl WindowFrameStateGroups { Ok(match (SIDE, SEARCH_SIDE) { // Window frame start: (true, _) => { - let group_idx = min(group_idx, self.group_start_indices.len()); + let group_idx = min(group_idx, self.group_end_indices.len()); if group_idx > 0 { // Normally, start at the boundary of the previous group. - self.group_start_indices[group_idx - 1].1 + self.group_end_indices[group_idx - 1].1 } else { // If previous group is out of the table, start at zero. 0 @@ -506,7 +523,7 @@ impl WindowFrameStateGroups { (false, true) => { if self.current_group_idx >= delta { let group_idx = self.current_group_idx - delta; - self.group_start_indices[group_idx].1 + self.group_end_indices[group_idx].1 } else { // Group is out of the table, therefore end at zero. 0 @@ -516,9 +533,9 @@ impl WindowFrameStateGroups { (false, false) => { let group_idx = min( self.current_group_idx + delta, - self.group_start_indices.len() - 1, + self.group_end_indices.len() - 1, ); - self.group_start_indices[group_idx].1 + self.group_end_indices[group_idx].1 } }) }