diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml deleted file mode 100644 index 7b32c9ebf..000000000 --- a/.github/workflows/codeql-analysis.yml +++ /dev/null @@ -1,58 +0,0 @@ -# For most projects, this workflow file will not need changing; you simply need -# to commit it to your repository. -# -# You may wish to alter this file to override the set of languages analyzed, -# or to provide custom queries or build logic. -# -# ******** NOTE ******** -# We have attempted to detect the languages in your repository. Please check -# the `language` matrix defined below to confirm you have the correct set of -# supported CodeQL languages. -# -name: "CodeQL" - -on: - push: - branches: [main] - schedule: - - cron: "0 9 * * 1" # Every Monday at 09:00 (9:00 AM) - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - - strategy: - fail-fast: false - matrix: - language: ["python"] - # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] - # Learn more about CodeQL language support at https://git.io/codeql-language-support - - steps: - # the following step is required to avoid running out of space - - name: Maximize build space - run: | - df -h - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf "/usr/local/share/boost" - sudo rm -rf "$AGENT_TOOLSDIRECTORY" - echo "Check space..." - df -h - - - name: Checkout repository - uses: actions/checkout@v3 - - name: Get composite run steps repository - uses: actions/checkout@v3 - with: - repository: mosaicml/ci-testing - ref: v0.0.2 - path: ./ci-testing - - uses: ./ci-testing/.github/actions/codeql-analysis - with: - language: ${{ matrix.language }} diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml index 4e5bd0930..69b3ea6bc 100644 --- a/.github/workflows/linting.yaml +++ b/.github/workflows/linting.yaml @@ -32,7 +32,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.0.2 + ref: v0.0.9 path: ./ci-testing - uses: ./ci-testing/.github/actions/code-quality with: diff --git a/docs/source/how_to_guides/configure_cloud_storage_credentials.md b/docs/source/how_to_guides/configure_cloud_storage_credentials.md index 8431e5a9e..6c46679c3 100644 --- a/docs/source/how_to_guides/configure_cloud_storage_credentials.md +++ b/docs/source/how_to_guides/configure_cloud_storage_credentials.md @@ -7,6 +7,7 @@ Streaming dataset supports the following cloud storage providers to stream your - [Oracle Cloud Storage](#oracle-cloud-storage) - [Azure Blob Storage](#azure-blob-storage-and-azure-datalake) - [Databricks](#databricks) +- [Huggingface Datasets](#huggingface-datasets) ## Amazon S3 @@ -251,6 +252,23 @@ export AZURE_ACCOUNT_ACCESS_KEY='NN1KHxKKkj20ZO92EMiDQjx3wp2kZG4UUvfAGlgGWRn6sPR ``` ```` +## Huggingface Datasets + +To authenticate Huggingface Hub access, users must set their HuggingFace token ([HF_TOKEN](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#hftoken)) in the run environment. See the [HF's documentation](https://huggingface.co/docs/huggingface_hub/guides/hf_file_system) on the URL format. 
+ +Set the Huggingface token in the run environment as shown below + +````{tabs} +```{code-tab} py +import os +os.environ['HF_TOKEN'] = 'EXAMPLEFODNN7EXAMPLE' +``` + +```{code-tab} sh +export HF_TOKEN='EXAMPLEFODNN7EXAMPLE' +``` +```` + ## Databricks To authenticate Databricks access for both Unity Catalog and Databricks File System (DBFS), users must set their Databricks host (`DATABRICKS_HOST`) and access token (`DATABRICKS_TOKEN`) in the run environment. diff --git a/pyproject.toml b/pyproject.toml index e648880ac..9878eba0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ reportUnusedCoroutine = "error" # Pytest [tool.pytest.ini_options] # By default, do not run remote tests -addopts = "--cov=streaming --cov-fail-under=50 --codeblocks --strict-markers -m 'not daily and not remote' -ra --tb=native" +addopts = "--cov=streaming --cov-fail-under=50 --codeblocks --strict-markers -m 'not daily and not remote' -ra --tb=native --color=yes" markers = [ # For distributed testing diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 000000000..bafac98ce --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,1633 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "addr2line" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "const-random", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "arrow" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04a8801ebb147ad240b2d978d3ab9f73c9ccd4557ba6a03e7800496770ed10e0" +dependencies = [ + "ahash", + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-csv", + "arrow-data", + "arrow-ipc", + "arrow-json", + "arrow-ord", + "arrow-row", + "arrow-schema", + 
"arrow-select", + "arrow-string", + "pyo3", +] + +[[package]] +name = "arrow-arith" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "895263144bd4a69751cbe6a34a53f26626e19770b313a9fa792c415cd0e78f11" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "num", +] + +[[package]] +name = "arrow-array" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "226fdc6c3a4ae154a74c24091d36a90b514f0ed7112f5b8322c1d8f354d8e20d" +dependencies = [ + "ahash", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "hashbrown", + "num", +] + +[[package]] +name = "arrow-buffer" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc4843af4dd679c2f35b69c572874da8fde33be53eb549a5fb128e7a4b763510" +dependencies = [ + "bytes", + "half", + "num", +] + +[[package]] +name = "arrow-cast" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e8b9990733a9b635f656efda3c9b8308c7a19695c9ec2c7046dd154f9b144b" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "chrono", + "half", + "lexical-core", + "num", +] + +[[package]] +name = "arrow-csv" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "646fbb4e11dd0afb8083e883f53117713b8caadb4413b3c9e63e3f535da3683c" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "csv", + "csv-core", + "lazy_static", + "lexical-core", + "regex", +] + +[[package]] +name = "arrow-data" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da900f31ff01a0a84da0572209be72b2b6f980f3ea58803635de47913191c188" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "half", + "num", +] + +[[package]] +name = "arrow-ipc" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2707a8d7ee2d345d045283ece3ae43416175873483e5d96319c929da542a0b1f" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "flatbuffers", +] + +[[package]] +name = "arrow-json" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d1b91a63c356d14eedc778b76d66a88f35ac8498426bb0799a769a49a74a8b4" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "indexmap", + "lexical-core", + "num", + "serde", + "serde_json", +] + +[[package]] +name = "arrow-ord" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "584325c91293abbca7aaaabf8da9fe303245d641f5f4a18a6058dc68009c7ebf" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "half", + "num", +] + +[[package]] +name = "arrow-row" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e32afc1329f7b372463b21c6ca502b07cf237e1ed420d87706c1770bb0ebd38" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "half", + "hashbrown", +] + +[[package]] +name = "arrow-schema" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b104f5daa730f00fde22adc03a12aa5a2ae9ccbbf99cbd53d284119ddc90e03d" 
+dependencies = [ + "bitflags 2.5.0", +] + +[[package]] +name = "arrow-select" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73b3ca55356d1eae07cf48808d8c462cea674393ae6ad1e0b120f40b422eb2b4" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "num", +] + +[[package]] +name = "arrow-string" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1433ce02590cae68da0a18ed3a3ed868ffac2c6f24c533ddd2067f7ee04b4a" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "num", + "regex", + "regex-syntax 0.7.5", +] + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "backtrace" +version = "0.3.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17c6a35df3749d2e8bb1b7b21a976d82b15548788d2735b9d82f329268f71a11" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" + +[[package]] +name = "brotli" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" + +[[package]] +name = "cc" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +dependencies = [ + "jobserver", + "libc", + "once_cell", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "windows-targets 0.52.5", +] + +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom", + "once_cell", + "tiny-keccak", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "flatbuffers" +version = "23.5.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640" +dependencies = [ + "bitflags 1.3.2", + "rustc_version", +] + +[[package]] +name = "flate2" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "gimli" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "indoc" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "integer-encoding" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-parse-integer" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = 
"0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "lz4" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e9e2dd86df36ce760a60f6ff6ad526f7ba1f14ba0356f8254fb6905e6494df1" +dependencies = [ + "libc", + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "memchr" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "miniz_oxide" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.48.0", +] + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "object" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8ec7ab813848ba4522158d5517a6093db1ded27575b070f4177b8d12b41db5e" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.5", +] + +[[package]] +name = "parquet" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad2cba786ae07da4d73371a88b9e0f9d3ffac1a9badc83922e0e15814f5c5fa" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-schema", + "arrow-select", + "base64", + "brotli", + "bytes", + "chrono", + "flate2", + "futures", + "hashbrown", + "lz4", + "num", + "num-bigint", + "paste", + "seq-macro", + "snap", + "thrift", + "tokio", + "twox-hash", + "zstd", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "proc-macro2" +version = "1.0.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +dependencies = [ + "once_cell", + 
"target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +dependencies = [ + "bitflags 2.5.0", +] + +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax 0.8.3", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.3", +] + +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + +[[package]] +name = "rust" +version = "0.1.0" +dependencies = [ + "arrow", + "bytes", + "futures", + "parquet", + "pyo3", + "tokio", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + +[[package]] +name = "serde" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "serde_json" +version = "1.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + +[[package]] +name = "socket2" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" + +[[package]] +name = "thrift" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" +dependencies = [ + "byteorder", + "integer-encoding", + "ordered-float", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "tokio" +version = "1.38.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-macros" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" 
+dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + 
+[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "zerocopy" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "zstd" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "6.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.10+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/setup.py b/setup.py index 2d7419f09..517a4af82 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,8 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'databricks-connect>=14.3.0', + 'pyarrow>=17,<18', ] extra_deps = {} @@ -68,16 +70,16 @@ 'docformatter>=1.4', 'jupyter==1.0.0', 'pre-commit>=2.18.1,<4', - 'pytest==8.2.1', + 'pytest==8.3.2', 'pytest_codeblocks==0.17.0', 'pytest-cov>=4,<6', 'toml==0.10.2', 'yamllint==1.35.1', 'moto>=4.0,<6', - 'fastapi==0.111.0', - 'pydantic==2.7.1', - 'uvicorn==0.29.0', - 'pytest-split==0.8.2', + 'fastapi==0.111.1', + 'pydantic==2.8.2', + 'uvicorn==0.30.3', + 'pytest-split==0.9.0', ] extra_deps['docs'] = [ @@ -116,13 +118,17 @@ ] extra_deps['databricks'] = [ - 'databricks-sdk==0.27.1', + 'databricks-sdk==0.29.0', ] extra_deps['alipan'] = [ 'AliPCS-Py>=0.8,<1', ] +extra_deps['hf'] = [ + 'huggingface_hub>=0.23.4,<0.24', +] + extra_deps['testing'] = [ 'mosaicml-cli>=0.5.25,<0.7', ] diff --git a/streaming/_version.py b/streaming/_version.py index 6999045b3..43c88d2a3 100644 --- a/streaming/_version.py +++ b/streaming/_version.py @@ -3,4 +3,4 @@ """The Streaming Version.""" -__version__ = '0.7.6' +__version__ = '0.8.0' diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index e43e212e8..cef35ae6f 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -34,7 +34,7 @@ from streaming.base.shared import (SharedArray, 
SharedBarrier, SharedMemory, SharedScalar, _get_path, get_shm_prefix) from streaming.base.spanner import Spanner -from streaming.base.stream import Stream +from streaming.base.stream import Stream, DeltaDBSQLStream, DeltaSCStream from streaming.base.util import bytes_to_int, number_abbrev_to_int from streaming.base.world import World @@ -331,7 +331,8 @@ def __init__(self, shuffle_block_size: Optional[int] = None, batching_method: str = 'random', allow_unsafe_types: bool = False, - replication: Optional[int] = None) -> None: + replication: Optional[int] = None, + **kwargs: Any) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload self.cache_limit = cache_limit @@ -348,6 +349,8 @@ def __init__(self, self.allow_unsafe_types = allow_unsafe_types self.replication = replication + logger.warning('Using StreamingX:heterogeneous branch') + # Initialize the World context. # * This information is for the per-rank or per-worker process. # * DataLoader worker processes may get a different worker ID and worker count than rank. @@ -443,6 +446,27 @@ def __init__(self, } for stream in streams: stream.apply_default(default) + elif remote is not None and remote.startswith('SELECT'): + cluster_id = kwargs.get('cluster_id', None) + if not cluster_id: + default = DeltaDBSQLStream(remote=remote, + local=local, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip, + **kwargs) + else: + default = DeltaSCStream(cluster_id, + remote=remote, + local=local, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + streams = [default] else: default = Stream(remote=remote, local=local, @@ -922,14 +946,18 @@ def resample_streams( sample_ids = np.concatenate(sample_ids).astype(np.int64) return shuffle_units, sample_ids - def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, SharedMemory]: + def _share_work( + self, + sample_ids: NDArray[np.int64], + ) -> Tuple[SharedMemory, Optional[SharedMemory]]: """Put an epoch's sample ordering into shared memory. Args: sample_ids (NDArray[np.int64]): Sample IDs. Returns: - Tuple[SharedMemory, SharedMemory]: Shared memory arrays containing shape and data. + Tuple[SharedMemory, Optional[SharedMemory]]: Shared memory arrays containing shape and + data, if present. """ ndim = 5 @@ -945,19 +973,26 @@ def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, Shar shape_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False) shape_shm.buf[:size] = np.array(sample_ids.shape, np.int64).tobytes() - # Save the generated epoch data to shared memory. - name = _get_path(self._shm_prefix_int, EPOCH_DATA) - size = sample_ids.size * np.int64().nbytes - data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False) - data_shm.buf[:size] = sample_ids.tobytes() + if sample_ids.size > 0: + # Save the generated epoch data to shared memory, but only if the sample partition is + # non-empty. Otherwise, the end of the epoch has been reached. 
+ name = _get_path(self._shm_prefix_int, EPOCH_DATA) + size = sample_ids.size * np.int64().nbytes + data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False) + data_shm.buf[:size] = sample_ids.tobytes() + + return shape_shm, data_shm + + else: - return shape_shm, data_shm + return shape_shm, None - def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]: + def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]: """Get an epoch's sample ordering from shared memory. Returns: - NDArray[np.int64]: Sample IDs. + Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]: Sample IDs, shared + memory array for shape, and shared memory array for data, if present. """ ndim = 5 @@ -967,13 +1002,22 @@ def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]: shape_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False) shape = tuple(np.ndarray(5, buffer=shape_shm.buf, dtype=np.int64)) - # Attach to the generated epoch data in shared memory. - name = _get_path(self._shm_prefix_int, EPOCH_DATA) - size = int(np.prod(shape)) * np.int64().nbytes - data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False) - sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64) + num_elements = int(np.prod(shape)) + + if num_elements > 0: + # Attach to the generated epoch data in shared memory, but only if the sample partition + # is non-empty. Otherwise, the end of the epoch has been reached. + name = _get_path(self._shm_prefix_int, EPOCH_DATA) + size = num_elements * np.int64().nbytes + data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False) + sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64) + + return sample_ids, shape_shm, data_shm + + else: - return sample_ids, shape_shm, data_shm + sample_ids = np.empty(shape=shape, dtype=np.int64) + return sample_ids, shape_shm, None def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]: """Get this worker's partition of this epoch's sample space. @@ -1025,7 +1069,9 @@ def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]: # Now clean up after ourselves. shape_shm.cleanup() - data_shm.cleanup() + # Can be None if the sample partition was empty. + if data_shm is not None: + data_shm.cleanup() return worker_sample_ids diff --git a/streaming/base/format/base/writer.py b/streaming/base/format/base/writer.py index 152d61ae7..ed3407b86 100644 --- a/streaming/base/format/base/writer.py +++ b/streaming/base/format/base/writer.py @@ -124,7 +124,7 @@ def __init__(self, self.shards = [] # Remove local directory if requested prior to creating writer - local = out if isinstance(out, str) else out[0] + local = os.path.expanduser(out) if isinstance(out, str) else os.path.expanduser(out[0]) if os.path.exists(local) and kwargs.get('exist_ok', False): logger.warning( f'Directory {local} exists and is not empty; exist_ok is set to True so will remove contents.' 
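With the `StreamingDataset` changes above, a `remote` that begins with `SELECT` is routed to a Delta stream: `DeltaDBSQLStream` by default, or `DeltaSCStream` when a `cluster_id` keyword argument is supplied. Below is a minimal sketch of that routing, assuming a Databricks environment with `DATABRICKS_HOST`/`DATABRICKS_TOKEN` configured; the query, cache directories, and cluster ID are placeholders, not values from this diff.

```python
from streaming import StreamingDataset

# Remote is a SQL query, so without cluster_id the default stream is a DeltaDBSQLStream.
dbsql_dataset = StreamingDataset(
    remote='SELECT id, text FROM main.my_schema.my_table',  # placeholder query
    local='/tmp/delta_dbsql_cache',                         # placeholder local cache
    batch_size=32,
)

# Passing cluster_id instead selects a DeltaSCStream (Spark Connect) for the same query.
sc_dataset = StreamingDataset(
    remote='SELECT id, text FROM main.my_schema.my_table',
    local='/tmp/delta_sc_cache',
    batch_size=32,
    cluster_id='0123-456789-abcdefgh',  # placeholder cluster ID
)
```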
diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py
index 71c58be46..3a96b9c76 100644
--- a/streaming/base/format/mds/encodings.py
+++ b/streaming/base/format/mds/encodings.py
@@ -16,6 +16,8 @@ from PIL.JpegImagePlugin import JpegImageFile
 from typing_extensions import Self
 
+import struct
+
 __all__ = [
     'get_mds_encoded_size', 'get_mds_encodings', 'is_mds_encoding', 'mds_decode', 'mds_encode',
     'is_mds_encoding_safe'
 ]
@@ -201,6 +203,11 @@ def _rightsize_shape_dtype(cls, shape: npt.NDArray[np.int64]) -> str:
         Returns:
             str: The smallest acceptable uint* dtype.
         """
+        if len(shape) == 0:
+            raise ValueError(
+                'Attempting to encode a scalar with NDArray encoding. Please use a scalar encoding.'
+            )
+
         if shape.min() <= 0:
             raise ValueError('All dimensions must be greater than zero.')
         x = shape.max()
@@ -235,6 +242,9 @@ def encode(self, obj: npt.NDArray) -> bytes:
         if obj.dtype != self.dtype:
             raise ValueError(f'Wrong dtype: expected {self.dtype}, got {obj.dtype.name}.')
 
+        if obj.size == 0:
+            raise ValueError('Attempting to encode a numpy array with 0 elements.')
+
         # Encode shape, if not given in header.
         if self.shape is None:
             ndim = len(obj.shape)
@@ -516,6 +526,77 @@ def _is_valid(self, original: Any, converted: Any) -> None:
             e.msg = f'Invalid JSON data: {original}'
             raise
 
 
+class StrArray(Encoding):
+    """Store a list of strings."""
+
+    def encode(self, strings: Any) -> bytes:
+        encoded_parts = []
+
+        # Encode the length of the list of strings.
+        list_length = len(strings)
+        encoded_parts.append(list_length.to_bytes(4, byteorder='little'))
+
+        for s in strings:
+            # Encode each string as UTF-8 bytes, prefixed with its 4-byte length.
+            encoded_str = s.encode('utf-8')
+            length_prefix = len(encoded_str).to_bytes(4, byteorder='little')
+            encoded_parts.append(length_prefix + encoded_str)
+
+        # Return the concatenated byte sequence.
+        return b''.join(encoded_parts)
+
+    def decode(self, encoded_bytes: bytes) -> Any:
+        index = 0
+        decoded_strings = []
+
+        # Decode the length of the list of strings.
+        list_length = int.from_bytes(encoded_bytes[index:index + 4], byteorder='little')
+        index += 4
+
+        for _ in range(list_length):
+            # Decode the length of the next string.
+            length = int.from_bytes(encoded_bytes[index:index + 4], byteorder='little')
+            index += 4
+
+            # Extract and decode the string.
+            encoded_str = encoded_bytes[index:index + length]
+            decoded_str = encoded_str.decode('utf-8')
+            decoded_strings.append(decoded_str)
+
+            index += length
+
+        return decoded_strings
+
+
+class IntArray(Encoding):
+    """Store a list of int32 integers efficiently."""
+
+    def encode(self, integers: Any) -> bytes:
+        # Pack the length of the list as an unsigned 4-byte little-endian integer.
+        list_length = len(integers)
+        encoded = struct.pack('<I', list_length)
+
+        # Pack the values themselves as little-endian int32s.
+        encoded += struct.pack(f'<{list_length}i', *integers)
+
+        return encoded
+
+    def decode(self, encoded_bytes: bytes) -> npt.NDArray:
+        index = 0
+
+        # Unpack the length of the list.
+        list_length = struct.unpack_from('<I', encoded_bytes, index)[0]
+        index += 4
+
+        # Unpack the int32 payload, if any.
+        integers = []
+        if list_length > 0:
+            integers = list(struct.unpack_from(f'<{list_length}i', encoded_bytes, index))
+            index += 4 * list_length
+
+        return np.array(integers, dtype=np.int32)
+
+
 # Encodings (name -> class).
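The two encodings above are registered as `str_array` and `int_array` in the `_encodings` table just below. Here is a minimal usage sketch, assuming this diff is applied; the output directory, column names, and sample values are hypothetical.

```python
import numpy as np

from streaming import MDSWriter, StreamingDataset

# Hypothetical columns using the new variable-length encodings.
columns = {
    'tokens': 'int_array',  # list/array of int32 values
    'tags': 'str_array',    # list of strings
}

with MDSWriter(out='/tmp/example_mds', columns=columns) as writer:
    writer.write({
        'tokens': np.array([7, 11, 13], dtype=np.int32),
        'tags': ['greeting', 'short'],
    })

# Reading back: each sample decodes to an int32 ndarray and a list of str.
dataset = StreamingDataset(local='/tmp/example_mds', shuffle=False)
sample = dataset[0]
```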
_encodings = { @@ -537,6 +618,8 @@ def _is_valid(self, original: Any, converted: Any) -> None: 'str_int': StrInt, 'str_float': StrFloat, 'str_decimal': StrDecimal, + 'str_array': StrArray, + 'int_array': IntArray, 'pil': PIL, 'jpeg': JPEG, 'png': PNG, diff --git a/streaming/base/partition/__init__.py b/streaming/base/partition/__init__.py index 28e908cb1..65271d8e2 100644 --- a/streaming/base/partition/__init__.py +++ b/streaming/base/partition/__init__.py @@ -3,6 +3,7 @@ """Apportion shards/samples to nodes/ranks/workers for elastically deterministic sample order.""" +import logging from typing import Optional import numpy as np @@ -11,6 +12,8 @@ from streaming.base.partition.orig import get_partitions_orig from streaming.base.partition.relaxed import get_partitions_relaxed +logger = logging.getLogger(__name__) + algos = { 'orig': get_partitions_orig, 'relaxed': get_partitions_relaxed, @@ -51,6 +54,17 @@ def get_partitions(algo: str, NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size). """ + world_size = ranks_per_node * num_physical_nodes + num_repeated_samples = world_size - (num_samples % world_size) + if num_samples + num_repeated_samples < drop_first: + raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' + + f'({num_samples})') + + if num_repeated_samples > 0: + logger.debug(f'Using {num_repeated_samples} repeated samples to ensure that the epoch ' + + f'size is divisible by the number of total devices. This ensures that each ' + + f'device contributes the same number of samples per global batch. ') + get = algos[algo] return get(num_samples, num_canonical_nodes, num_physical_nodes, ranks_per_node, workers_per_rank, batch_size, drop_first, initial_physical_nodes) diff --git a/streaming/base/partition/orig.py b/streaming/base/partition/orig.py index dff6d7878..cda16ac1d 100644 --- a/streaming/base/partition/orig.py +++ b/streaming/base/partition/orig.py @@ -46,10 +46,6 @@ def get_partitions_orig(num_samples: int, NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size). """ - if num_samples <= drop_first: - raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' + - f'({num_samples})') - if num_canonical_nodes < num_physical_nodes: if num_physical_nodes % num_canonical_nodes: raise ValueError('Either canonical or physical nodes must be evenly divisible by ' + @@ -81,7 +77,7 @@ def get_partitions_orig(num_samples: int, # For samples to be properly split across canonical nodes, there must be more samples than nodes. # The edge case is when the number of samples is equal to the number of canonical nodes, but this only works when - # there is an equal or greater number of canonical nodes than physical nodes. + # there is an equal or greater number of canonical nodes than physical nodes. # If these conditions are not met, an alternative sampling approach is used that leads to many repeats. if num_samples > num_canonical_nodes or (num_samples == num_canonical_nodes and num_canonical_nodes >= num_physical_nodes): @@ -141,8 +137,7 @@ def get_partitions_orig(num_samples: int, ids = ids.reshape(-1, num_physical_nodes) ids = ids.transpose() - # Interleave the node sample ranges over each node's ranks, padding by repeating the last - # sample. + # Interleave the node sample ranges over each node's ranks, padding with -1 for reshaping. 
# # ids: (physical nodes, samples per rank, ranks per node). overflow = ids.shape[1] % ranks_per_node diff --git a/streaming/base/partition/relaxed.py b/streaming/base/partition/relaxed.py index e84bb7efc..6baa0a48c 100644 --- a/streaming/base/partition/relaxed.py +++ b/streaming/base/partition/relaxed.py @@ -49,10 +49,6 @@ def get_partitions_relaxed(num_samples: int, NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size). """ - if num_samples <= drop_first: - raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' + - f'({num_samples})') - if initial_physical_nodes is None or (num_physical_nodes <= num_canonical_nodes and num_canonical_nodes % num_physical_nodes == 0) or \ (num_physical_nodes > num_canonical_nodes and diff --git a/streaming/base/shared/prefix.py b/streaming/base/shared/prefix.py index 7e8936086..64585381c 100644 --- a/streaming/base/shared/prefix.py +++ b/streaming/base/shared/prefix.py @@ -118,7 +118,7 @@ def _check_and_find(streams_local: List[str], streams_remote: List[Union[str, No if any(streams_remote): # Get the indices of the local directories which matches with the current # shared memory. - matching_index = np.where(np.in1d(streams_local, their_locals))[0] + matching_index = np.where(np.isin(streams_local, their_locals))[0] if matching_index.size > 0: for idx in matching_index: # If there is a conflicting local directory for a non-None remote directory, diff --git a/streaming/base/spanner.py b/streaming/base/spanner.py index 10cd72639..18426af71 100644 --- a/streaming/base/spanner.py +++ b/streaming/base/spanner.py @@ -49,13 +49,13 @@ def __getitem__(self, index: int) -> Tuple[int, int]: Tuple[int, int]: Shard and relative sample index. 
""" if not (0 <= index < self.num_samples): - raise ValueError(f'Invalid sample index `{index}`: 0 <= {index} < {self.num_samples}') + raise IndexError(f'Invalid sample index `{index}`: 0 <= {index} < {self.num_samples}') span = index // self.span_size for shard in self.spans[span]: shard_start = self.shard_bounds[shard] shard_stop = self.shard_bounds[shard + 1] if shard_start <= index < shard_stop: - return shard, int(index - shard_start) + return shard, int(index - shard_start.item()) raise RuntimeError('Internal error: shards were indexed incorrectly') diff --git a/streaming/base/storage/__init__.py b/streaming/base/storage/__init__.py index e9653db9d..bfe9ce6f5 100644 --- a/streaming/base/storage/__init__.py +++ b/streaming/base/storage/__init__.py @@ -2,15 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 """Base module for downloading/uploading files from/to cloud storage.""" - -from streaming.base.storage.download import (download_file, download_from_alipan, - download_from_azure, download_from_azure_datalake, - download_from_databricks_unity_catalog, - download_from_dbfs, download_from_gcs, - download_from_local, download_from_oci, - download_from_s3, download_from_sftp) +# isort: off +from streaming.base.storage.download import ( + download_file, download_from_alipan, download_from_azure, download_from_azure_datalake, + download_from_databricks_unity_catalog, download_from_dbfs, download_from_gcs, + download_from_hf, download_from_local, download_from_oci, download_from_s3, download_from_sftp) from streaming.base.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, - GCSUploader, LocalUploader, OCIUploader, S3Uploader) + GCSUploader, HFUploader, LocalUploader, OCIUploader, + S3Uploader) __all__ = [ 'download_file', @@ -21,6 +20,7 @@ 'LocalUploader', 'AzureUploader', 'AzureDataLakeUploader', + 'HFUploader', 'download_from_s3', 'download_from_sftp', 'download_from_gcs', @@ -31,4 +31,5 @@ 'download_from_dbfs', 'download_from_alipan', 'download_from_local', + 'download_from_hf', ] diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index cdcf3d489..bf4f0e33b 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -19,6 +19,7 @@ 'download_from_oci', 'download_from_azure', 'download_from_azure_datalake', + 'download_from_hf', 'download_from_databricks_unity_catalog', 'download_from_dbfs', 'download_from_alipan', @@ -53,12 +54,15 @@ def _download_file(unsigned: bool = False, extra_args (Dict[str, Any], optional): Extra arguments supported by boto3. Defaults to ``None``. """ + retries = { + 'mode': 'adaptive', + } if unsigned: # Client will be using unsigned mode in which public # resources can be accessed without credentials - config = Config(read_timeout=timeout, signature_version=UNSIGNED) + config = Config(read_timeout=timeout, signature_version=UNSIGNED, retries=retries) else: - config = Config(read_timeout=timeout) + config = Config(read_timeout=timeout, retries=retries) if extra_args is None: extra_args = {} @@ -272,6 +276,30 @@ def download_from_oci(remote: str, local: str) -> None: os.rename(local_tmp, local) +def download_from_hf(remote: str, local: str) -> None: + """Download a file from remote Hugging Face to local. + + Args: + remote (str): Remote path (Hugging Face). + local (str): Local path (local filesystem). 
+ """ + from huggingface_hub import hf_hub_download + + obj = urllib.parse.urlparse(remote) + if obj.scheme != 'hf': + raise ValueError(f'Expected remote path to start with `hf://`, got {remote}.') + + _, _, _, repo_org, repo_name, path = remote.split('/', 5) + local_dirname = os.path.dirname(local) + hf_hub_download(repo_id=f'{repo_org}/{repo_name}', + filename=path, + repo_type='dataset', + local_dir=local_dirname) + + downloaded_name = os.path.join(local_dirname, path) + os.rename(downloaded_name, local) + + def download_from_azure(remote: str, local: str) -> None: """Download a file from remote Microsoft Azure to local. @@ -292,7 +320,10 @@ def download_from_azure(remote: str, local: str) -> None: account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net", credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY']) try: - blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip('/')) + file_path = obj.path.lstrip('/').split('/') + container_name = file_path[0] + blob_name = os.path.join(*file_path[1:]) + blob_client = service.get_blob_client(container=container_name, blob=blob_name) local_tmp = local + '.tmp' with open(local_tmp, 'wb') as my_blob: blob_data = blob_client.download_blob() @@ -378,7 +409,7 @@ def download_from_databricks_unity_catalog(remote: str, local: str) -> None: f'operations. Increase the `download_retry` value to retry downloading ' + f'a file.',) if e.error_code == 'NOT_FOUND': - raise FileNotFoundError(f'Object dbfs:{remote} not found.') from e + raise FileNotFoundError(f'Object {remote} not found.') from e raise e os.rename(local_tmp, local) @@ -511,6 +542,8 @@ def download_file(remote: Optional[str], local: str, timeout: float): download_from_gcs(remote, local) elif remote.startswith('oci://'): download_from_oci(remote, local) + elif remote.startswith('hf://'): + download_from_hf(remote, local) elif remote.startswith('azure://'): download_from_azure(remote, local) elif remote.startswith('azure-dl://'): diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 6a8c67e3c..1c296bb89 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -24,6 +24,7 @@ 'S3Uploader', 'GCSUploader', 'OCIUploader', + 'HFUploader', 'AzureUploader', 'DatabricksUnityCatalogUploader', 'DBFSUploader', @@ -37,6 +38,7 @@ 's3': 'S3Uploader', 'gs': 'GCSUploader', 'oci': 'OCIUploader', + 'hf': 'HFUploader', 'azure': 'AzureUploader', 'azure-dl': 'AzureDataLakeUploader', 'dbfs:/Volumes': 'DatabricksUnityCatalogUploader', @@ -616,6 +618,80 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: return [] +class HFUploader(CloudUploader): + """Upload file from local machine to a Huggingface Dataset. + + Args: + out (str): Output dataset directory to save shard files. + + 1. If ``out`` is a local directory, shard files are saved locally. + 2. If ``out`` is a remote directory then the shard files are uploaded to the + remote location. + keep_local (bool): If the dataset is uploaded, whether to keep the local dataset + shard file or remove it after uploading. Defaults to ``False``. + progress_bar (bool): Display TQDM progress bars for uploading output dataset files to + a remote location. Default to ``False``. + retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. 
+ """ + + def __init__(self, + out: str, + keep_local: bool = False, + progress_bar: bool = False, + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) + + import huggingface_hub + self.api = huggingface_hub.HfApi() + self.fs = huggingface_hub.HfFileSystem(token=os.environ.get('HF_TOKEN', None)) + + obj = urllib.parse.urlparse(out) + if obj.scheme != 'hf': + raise ValueError(f'Expected remote path to start with `hf://`, got {out}.') + + _, _, _, self.repo_org, self.repo_name, self.path = out.split('/', 5) + self.dataset_id = os.path.join(self.repo_org, self.repo_name) + self.check_dataset_exists() # pyright: ignore + + def upload_file(self, filename: str): + """Upload file from local instance to HF. + + Args: + filename (str): File to upload. + """ + + @retry(num_attempts=self.retry) + def _upload_file(): + local_filename = filename + local_filename = local_filename.replace('\\', '/') + remote_filename = os.path.join('datasets', self.dataset_id, filename) + remote_filename = remote_filename.replace('\\', '/') + logger.debug(f'Uploading to {remote_filename}') + + with self.fs.open(remote_filename, 'wb') as f: + with open(local_filename, 'rb') as data: + f.write(data.read()) + + _upload_file() + + def check_dataset_exists(self): + """Raise an exception if the dataset does not exist. + + Raises: + error: Dataset does not exist. + """ + import huggingface_hub + try: + _ = list(huggingface_hub.list_repo_tree(self.dataset_id, repo_type='dataset')) + except Exception: + raise FileNotFoundError( + f'The HF dataset {self.dataset_id} could not be found. Please make sure ' + + f'that the dataset exists and you have the correct access permissions.') + + class AzureUploader(CloudUploader): """Upload file from local machine to Microsoft Azure bucket. diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 7287a5c71..a56955ad6 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -5,9 +5,10 @@ import hashlib import json +import time import os import tempfile -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Any import numpy as np from numpy.typing import NDArray @@ -15,13 +16,20 @@ from streaming.base.compression import decompress from streaming.base.constant import TICK +import torch.distributed as dist from streaming.base.distributed import barrier, get_local_rank from streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json from streaming.base.hashing import get_hash from streaming.base.storage import download_file -from streaming.base.util import retry, wait_for_file_to_exist +from streaming.base.util import retry, wait_for_file_to_exist, wait_for_json_to_exist from streaming.base.world import World +import re +import random +import pyarrow as pa +import requests +from tempfile import TemporaryDirectory + class Stream: """A dataset, or sub-dataset if mixing, from which we stream/cache samples. @@ -288,7 +296,7 @@ def apply_weights(cls, streams: Sequence[Self], samples_per_stream: NDArray[np.i stream.repeat = repeat stream.choose = choose - return choose_per_epoch + return int(choose_per_epoch) def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: """Safely download a file from remote to local cache. 
@@ -454,8 +462,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: wait_for_file_to_exist( filename, TICK, self.download_timeout, f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + - f'-> {filename} took too long to download. Either increase the ' + - f'`download_timeout` value or check the other traceback.') + f'-> {filename} took too long to download or failed to download. Either increase the ' + + f'`download_timeout` value or check the local rank 0 traceback.') # Load the index. try: @@ -505,3 +513,767 @@ def get_index_size(self) -> int: """ filename = os.path.join(self.local, self.split, get_index_basename()) return os.stat(filename).st_size + +def save_dict_to_file(directory, filename, dictionary): + """Save a dictionary to a file in the specified directory.""" + if not os.path.exists(directory): + os.makedirs(directory) + + file_path = os.path.join(directory, filename) + with open(file_path, 'w') as file: + json.dump(dictionary, file, indent=4) + print(f"Dictionary saved to {file_path}") + +def load_dict_from_file(directory, filename): + """Load a dictionary from a file in the specified directory.""" + file_path = os.path.join(directory, filename) + if not os.path.exists(file_path): + raise FileNotFoundError(f"No such file: '{file_path}'") + + with open(file_path, 'r') as file: + dictionary = json.load(file) + print(f"Dictionary loaded from {file_path}") + return dictionary + + +class DeltaSCStream(Stream): + + def __init__(self, + cluster_id: str, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None) -> None: + super().__init__(remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + + self.url_to_basename= {} + self.basename_to_url={} + self.cluster_id = cluster_id + + def generate_unique_basename(self, url: str, index: int) -> str: + """Generate a unique basename for the file path from the URL.""" + hash_object = hashlib.md5(url.encode()) + hex_dig = hash_object.hexdigest() + basename = '.'.join(['shard', f'{index:05}', 'mds']) + self.url_to_basename[url] = basename + self.basename_to_url[basename] = url + + return basename + + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + """Load this Stream's index, retrieving its shard readers. + + Args: + world (World): Distributed context. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error. + + Returns: + `List[Reader]: Shard readers. 
+ """ + # Prepare cloudfetch + from databricks.connect import DatabricksSession + from databricks.sdk import WorkspaceClient + from streaming.base.converters import infer_dataframe_schema + + w = WorkspaceClient() + + sparkSession = DatabricksSession.builder.remote( + host=w.config.host, + token=w.config.token, + cluster_id=self.cluster_id).getOrCreate() + + df = sparkSession.sql(self.remote) + query = df._plan.to_proto(df._session.client) # pyright: ignore + schema, cloudfetch_results = df._session.client.experimental_to_cloudfetch(query, "arrow", compression=False) # pyright: ignore + + # Local leader prepares the index file based on cloudfetch results + basename = get_index_basename() + filename = os.path.join(self.local, self.split, basename) + + self.columns = infer_dataframe_schema(df, None) + + column_names = [] + column_encodings = [] + column_sizes = [] + for k, v in self.columns.items(): + column_names.append(k) + column_encodings.append(v) + column_sizes.append(None) + + if world.is_local_leader: + + metadata = { + "version": 2, + "shards": [] + } + + for index, result in enumerate(cloudfetch_results): + shard = { + "column_encodings": column_encodings, + "column_names": column_names, + "column_sizes": column_sizes, + "compression": None, + "format": "mds", + "hashes": ["sha1"], + "raw_data": { + "basename": self.generate_unique_basename(result.url, index), + "bytes": result.uncompressed_size, + "hashes": {} + }, + "samples": result.row_count, + "size_limit": 67108864, + "version": 2, + "zip_data": None + } + metadata["shards"].append(shard) + + with open(filename, 'w') as f: + json.dump(metadata, f, indent=4) + + else: + wait_for_file_to_exist( + filename, TICK, self.download_timeout, + f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + + f'-> {filename} took too long to download. Either increase the ' + + f'`download_timeout` value or check the other traceback.') + + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + + # Initialize shard readers according to the loaded info. + shards = [] + for info in obj['shards']: + shard = reader_from_json(self.local, self.split, info) + shard.validate(allow_unsafe_types) + shards.append(shard) + + save_dict_to_file('./', 'basename_to_url.json', self.basename_to_url) + return shards + + def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: + """Safely download a file from remote to local cache. + + Args: + from_basename (str): Source basename. + to_basename (str, optional): Destination basename, if different. + + Returns: + str: Local cache filename. 
+ """ + from streaming import MDSWriter + + def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + + cloud_fetch_url = self.basename_to_url[from_basename] + local = os.path.join(self.local, self.split, from_basename) + + # Attempt to download, possibly repeating on failure. + retry(num_attempts=self.download_retry)( + lambda: fetch_and_convert(cloud_fetch_url, local))() + + print('download to local is done = ', local) + return local + + +class DeltaDBSQLStream(Stream): + + def __init__(self, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None, + **kwargs: Any) -> None: + super().__init__(remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + host = w.config.host.lstrip('https://') + token = w.config.token + #host = kwargs.get('host', os.environ['DATABRICKS_HOST']).lstrip('https://') + #token = kwargs.get('token', os.environ['DATABRICKS_TOKEN']) + + warehouse_id = kwargs.get('warehouse_id', None) + catalog = kwargs.get('catalog', None) + schema = kwargs.get('schema', None) + + if any([not warehouse_id, not host, not token, not catalog, not schema]): + raise TypeError(f"Need to specify warehouse_id, host, token catalog, schema, during initialization, but got {warehouse_id}, {host}, {token}, {catalog}, {schema}") + + self.base_url = f"https://{host}/api/2.0/sql/statements/" + self.headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + self.data = { + "warehouse_id": warehouse_id, + "format": "ARROW_STREAM", + "disposition": "EXTERNAL_LINKS", + "statement": remote, + "wait_timeout": "5s", # cannot be less than 5 otherwise throws bad request error + "parameters": [], + # "byte_limit": 10000000000000, + } + + # From dbsql dtyps (lower case) to MDS encoded types + # https://docs.databricks.com/en/dev-tools/python-sql-connector.html + self.dtypes_mapping = { + 'string' : 'str', + 'bigint' : 'int64', + 'array': 'ndarray', + 'array': 'str_array', + 'array': 'int_array', + 'binary': 'bytes', + 'boolean': 'uint32', + 'date': 'str', + 'datetime.date': 'str', + 'decimal': 'str_decimal', + 'double' : 'float64', + 'int': 'int', + 'map': 'json', + 'smallint': 'int16', + 'struct': 'json', + 'tinyint': 'int8', + 'long': 'int8', + 'array>': 'json', # special for messages + } + + def generate_statement_id_and_sync(self, world: World): + if dist.is_available() and dist.is_initialized(): + barrier() + + if world.is_leader: # is_local_leader: + response = requests.post(self.base_url, headers=self.headers, json=self.data) + response.raise_for_status() + response_data = response.json() + self.statement_id = response_data['statement_id'] + data = 
self.statement_id + else: + data = None + + + obj_list = [data] + dist.broadcast_object_list(obj_list, src=0) + self.statement_id = obj_list[0] + return + + world_size = world.num_ranks + if world_size > 1: + raise RuntimeError(''.join([ + f'The world_size({world_size}) > 1, but the distributed package is not available ', + 'or has not been initialized. Please check you have initialized the distributed ', + 'runtime and that PyTorch has been built with distributed support.' + ])) + + response = requests.post(self.base_url, headers=self.headers, json=self.data) + response.raise_for_status() + response_data = response.json() + self.statement_id = response_data['statement_id'] + + def wait_for_query_result(self, timeout=3600): + if not self.statement_id: + raise ValueError(f"statement id is not set yet") + + total_time = 0 + while total_time <= timeout: + response = requests.get(f"{self.base_url}/{self.statement_id}", headers=self.headers) + response.raise_for_status() + response_data = response.json() + query_status = response_data['status']['state'] + + if query_status == "SUCCEEDED": + #self.statement_id = response_data['statement_id'] + save_dict_to_file(self.local, f'response_{int(time.time())}', response_data) + return response_data + + print(f"Query status: {query_status}") + time.sleep(3) + total_time += 3 + raise TimeoutError(f"Query execution failed with status: {query_status}") + + def get_encode_format(self, sql_fmt: str): + mds_fmt = self.dtypes_mapping.get(sql_fmt.lower(), None) + if not mds_fmt: + raise TypeError(f"{sql_fmt} is not supported by MDSWrite.") + return mds_fmt + + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + """Load this Stream's index, retrieving its shard readers. + + Args: + world (World): Distributed context. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error. + + Returns: + `List[Reader]: Shard readers. + """ + from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, + is_mds_encoding, mds_encode) + self.generate_statement_id_and_sync(world) + + sql_response = self.wait_for_query_result() + + # Local leader prepares the index file based on cloudfetch results + basename = get_index_basename() + filename = os.path.join(self.local, self.split, basename) + + self.columns = { c['name']: self.get_encode_format(c['type_text']) for c in sql_response['manifest']['schema']['columns'] } + + column_names = [] + column_encodings = [] + column_sizes = [] + for name in sorted(self.columns): + encoding = self.columns[name] + if not is_mds_encoding(encoding): + raise TypeError(f'MDSWriter passed column `{name}` with encoding `{encoding}` ' + + f'is unsupported. 
Supported encodings are {get_mds_encodings()}') + size = get_mds_encoded_size(encoding) + column_names.append(name) + column_encodings.append(encoding) + column_sizes.append(size) + + print(f'self.columns = {self.columns}') + + total_shard_count = sql_response['manifest']['total_chunk_count'] + + if world.is_local_leader: + + metadata = { + "version": 2, + "shards": [] + } + + for shard_id, shard_meta in enumerate(sql_response['manifest']['chunks']): + shard = { + "column_encodings": column_encodings, + "column_names": column_names, + "column_sizes": column_sizes, + "compression": None, + "format": "mds", + "hashes": ["sha1"], + "raw_data": { + "basename": f'shard.{shard_id:05}.mds', + "bytes": shard_meta['byte_count'], + "hashes": {} + }, + "samples": shard_meta['row_count'], + "size_limit": 67108864, + "version": 2, + "zip_data": None + } + metadata["shards"].append(shard) + + with open(filename, 'w') as f: + json.dump(metadata, f, indent=4) + else: + wait_for_json_to_exist( + filename, TICK, self.download_timeout, + f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + + f'-> {filename} took too long to download. Either increase the ' + + f'`download_timeout` value or check the other traceback.') + + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + + # Initialize shard readers according to the loaded info. + shards = [] + for info in obj['shards']: + shard = reader_from_json(self.local, self.split, info) + shard.validate(allow_unsafe_types) + shards.append(shard) + + return shards + + def _make_request(self, url: str) -> requests.Response: + if random.random() < 0.0: # make rhs > 0.0 for testing, so x% of the time return HTTPError + response = requests.Response() + response.status_code = 404 + response.url = url + raise requests.exceptions.HTTPError(f"Manually raised HTTPError for testing purposes: {int(time.time())}", response=response) + else: + response = requests.get(url, headers=self.headers) + response.raise_for_status() + return response + + def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: + """Safely download a file from remote to local cache. + + Args: + from_basename (str): Source basename. + to_basename (str, optional): Destination basename, if different. + + Returns: + str: Local cache filename. 
+ """ + from streaming import MDSWriter + def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + + chunk_index = int(re.search(r'\d+', from_basename).group()) + print('from_basename = ', from_basename) + print('chunk_index = ', chunk_index) + + + try: + url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + response = self._make_request(url) + except Exception as e: # requests.exceptions.HTTPError as e: + print('Failed to download, I cannot refresh statement id and try again') + print('url = ', url) + print(e) + raise TimeoutError('Check if the query results retention period of your workspace and make sure it is longer than the expected training period. For multi-node, we do not want to refresh and communicate statement id from worker processes.') from e + # self.refresh_statement_id() + #url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + #response = self._make_request(url) + + cloud_fetch_url = response.json()['external_links'][0]['external_link'] + local = os.path.join(self.local, self.split, from_basename) + retry(num_attempts=self.download_retry)(lambda: _fetch_and_convert(cloud_fetch_url, local))() + + print('Download to local is done = ', local) + return local + + +class DeltaDBSQLStreamSession(Stream): + + def __init__(self, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None, + **kwargs: Any) -> None: + super().__init__(remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + + warehouse_id = kwargs.get('warehouse_id', None) + host = kwargs.get('host', os.environ['DATABRICKS_HOST']) + token = kwargs.get('token', os.environ['DATABRICKS_TOKEN']) + catalog = kwargs.get('catalog', None) + schema = kwargs.get('schema', None) + self.use_cached_result = kwargs.get('use_cached_result', False) + + if any([not warehouse_id, not host, not token, not catalog, not schema]): + raise TypeError(f"Need to specify warehouse_id, host, token catalog, schema, during initialization") + + self.base_url = f"https://{host}/api/2.0/sql/statements/" + self.session_url = f"https://{host}/api/2.0/sql/sessions/" + + self.headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + self.session_payload = { + "warehouse_id": warehouse_id, + "catalog": catalog, + "schema": schema, + "session_confs": {"use_cached_result": "false"} + } + + self.payload = { + "warehouse_id": warehouse_id, + "format": "ARROW_STREAM", + "disposition": "EXTERNAL_LINKS", + "statement": remote, + "wait_timeout": "5s", # cannot be less than 5 otherwise throws bad request error + "parameters": [], + } + + # From dbsql dtyps (lower case) to MDS encoded types + # 
https://docs.databricks.com/en/dev-tools/python-sql-connector.html + self.dtypes_mapping = { + 'string' : 'str', + 'bigint' : 'int64', + 'array': 'ndarray', + 'array': 'str_array', + 'binary': 'bytes', + 'boolean': 'uint32', + 'date': 'str', + 'datetime.date': 'str', + 'decimal': 'str_decimal', + 'double' : 'float64', + 'int': 'int', + 'map': 'json', + 'smallint': 'int16', + 'struct': 'json', + 'tinyint': 'int8', + 'long': 'int8', + } + + self.refresh_statement_id(self.use_cached_result) + + def polling(self, timeout: int = 3600): + total_time = 0 + while total_time <= timeout: + response = requests.get(f"{self.base_url}/{self.statement_id}", headers=self.headers) + response.raise_for_status() + response_data = response.json() + query_status = response_data['status']['state'] + + if query_status == "SUCCEEDED": + save_dict_to_file(self.local, f'response_{int(time.time())}', response_data) + return response_data + + print(f"Query status: {query_status}") + time.sleep(3) + total_time += 3 + raise TimeoutError(f"Query execution failed with status: {query_status}") + + + def refresh_statement_id(self, use_cached_result:bool=False): + + boolean_string = "true" if use_cached_result else "false" + self.session_payload['session_confs']['use_cached_result'] = boolean_string + + print(f"Set the session data to be {self.session_payload}") + + # Create a session id + # Use session id in payload + # Fetch result via get status api + response = requests.post(self.session_url, headers=self.headers, json=self.session_payload) + self.payload['session_id'] = response.json()['session_id'] + + print(f"Set the payload to be {self.payload}") + + response = requests.post(self.base_url, headers=self.headers, json=self.payload) + response.raise_for_status() + response_data = response.json() + self.statement_id = response_data['statement_id'] + + return self.polling() + + def get_encode_format(self, sql_fmt: str): + mds_fmt = self.dtypes_mapping.get(sql_fmt.lower(), None) + if not mds_fmt: + raise TypeError(f"{sql_fmt} is not supported by MDSWrite.") + return mds_fmt + + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + """Load this Stream's index, retrieving its shard readers. + + Args: + world (World): Distributed context. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error. + + Returns: + `List[Reader]: Shard readers. + """ + from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, + is_mds_encoding, mds_encode) + + sql_response = self.refresh_statement_id(True) + + # Local leader prepares the index file based on cloudfetch results + basename = get_index_basename() + filename = os.path.join(self.local, self.split, basename) + + self.columns = { c['name']: self.get_encode_format(c['type_text']) for c in sql_response['manifest']['schema']['columns'] } + + column_names = [] + column_encodings = [] + column_sizes = [] + for name in sorted(self.columns): + encoding = self.columns[name] + if not is_mds_encoding(encoding): + raise TypeError(f'MDSWriter passed column `{name}` with encoding `{encoding}` ' + + f'is unsupported. 
Supported encodings are {get_mds_encodings()}') + size = get_mds_encoded_size(encoding) + column_names.append(name) + column_encodings.append(encoding) + column_sizes.append(size) + + print(f'self.columns = {self.columns}') + + total_shard_count = sql_response['manifest']['total_chunk_count'] + + if world.is_local_leader: + + metadata = { + "version": 2, + "shards": [] + } + + for shard_id, shard_meta in enumerate(sql_response['manifest']['chunks']): + shard = { + "column_encodings": column_encodings, + "column_names": column_names, + "column_sizes": column_sizes, + "compression": None, + "format": "mds", + "hashes": ["sha1"], + "raw_data": { + "basename": f'shard.{shard_id:05}.mds', + "bytes": shard_meta['byte_count'], + "hashes": {} + }, + "samples": shard_meta['row_count'], + "size_limit": 67108864, + "version": 2, + "zip_data": None + } + metadata["shards"].append(shard) + + with open(filename, 'w') as f: + json.dump(metadata, f, indent=4) + else: + wait_for_json_to_exist( + filename, TICK, self.download_timeout, + f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + + f'-> {filename} took too long to download. Either increase the ' + + f'`download_timeout` value or check the other traceback.') + + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + + # Initialize shard readers according to the loaded info. + shards = [] + for info in obj['shards']: + shard = reader_from_json(self.local, self.split, info) + shard.validate(allow_unsafe_types) + shards.append(shard) + + return shards + + def _make_request(self, url: str) -> requests.Response: + if random.random() < 0.0: # make rhs > 0.0 for testing, so x% of the time return HTTPError + response = requests.Response() + response.status_code = 404 + response.url = url + raise requests.exceptions.HTTPError(f"Manually raised HTTPError for testing purposes: {int(time.time())}", response=response) + else: + response = requests.get(url, headers=self.headers) + response.raise_for_status() + return response + + def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: + """Safely download a file from remote to local cache. + + Args: + from_basename (str): Source basename. + to_basename (str, optional): Destination basename, if different. + + Returns: + str: Local cache filename. 
+ """ + from streaming import MDSWriter + def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + + chunk_index = int(re.search(r'\d+', from_basename).group()) + print('from_basename = ', from_basename) + print('chunk_index = ', chunk_index) + + try: + url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + response = self._make_request(url) + except Exception as e: # requests.exceptions.HTTPError as e: + print('Failed to download, refresh statement id and try again') + print('url = ', url) + print(e) + self.refresh_statement_id(True) + url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + response = self._make_request(url) + + cloud_fetch_url = response.json()['external_links'][0]['external_link'] + local = os.path.join(self.local, self.split, from_basename) + retry(num_attempts=self.download_retry)(lambda: _fetch_and_convert(cloud_fetch_url, local))() + + print('Download to local is done = ', local) + return local + diff --git a/streaming/base/util.py b/streaming/base/util.py index 3be5b729a..e8b721a36 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -27,6 +27,9 @@ logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format='%(asctime)s [Process %(process)d, Thread %(thread)d] %(message)s') + + TCallable = TypeVar('TCallable', bound=Callable) __all__ = [ @@ -46,6 +49,38 @@ def get_list_arg(text: str) -> List[str]: """ return text.split(',') if text else [] +def wait_for_json_to_exist(filename: str, poll_interval: float, timeout: float, + err_msg: str) -> None: + """Wait for a json to exist till timeout seconds. Raise an Exception after that. + + Difference from wait_for_file_to_exist is that we load json and validate. 
+ + Args: + filename (str): A file name of a json + poll_interval (float): Number of seconds to wait before next polling + timeout (float): Number of seconds to wait for a file to exist before raising an exception + err_msg (str): Error message description for an exception + + Raises: + RuntimeError: Raise an Exception if file does not exist after timeout + """ + def is_valid_json(filename): + try: + obj = json.load(open(filename)) + return True + except json.decoder.JSONDecodeError as error: + return False + + start_time = time() + while True: + sleep(poll_interval) + if os.path.exists(filename) and is_valid_json(filename): + logging.warning('json has read in') + sleep(poll_interval) + break + dt = time() - start_time + if dt > timeout: + raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, err_msg: str) -> None: diff --git a/streaming/vision/convert/imagenet.py b/streaming/vision/convert/imagenet.py index f98883527..cdb84e7af 100644 --- a/streaming/vision/convert/imagenet.py +++ b/streaming/vision/convert/imagenet.py @@ -159,7 +159,7 @@ def main(args: Namespace) -> None: x = open(filenames[i], 'rb').read() y = classes[i] out.write({ - 'i': i, + 'i': int(i), 'x': x, 'y': y, }) diff --git a/tests/conftest.py b/tests/conftest.py index ac7844539..3b8a416c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,6 +51,12 @@ def aws_credentials(): os.environ['AWS_SESSION_TOKEN'] = 'testing' +@pytest.fixture(scope='class', autouse=True) +def hf_credentials(): + """Mocked HF Credentials.""" + os.environ['HF_TOKEN'] = 'testing' + + @pytest.fixture() def s3_client(aws_credentials: Any): with mock_aws(): diff --git a/tests/test_download.py b/tests/test_download.py index 8d9a0c1b2..50bee57d1 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -14,7 +14,8 @@ download_from_azure_datalake, download_from_databricks_unity_catalog, download_from_dbfs, download_from_gcs, - download_from_local, download_from_s3) + download_from_hf, download_from_local, + download_from_s3) from tests.conftest import GCS_URL, MY_BUCKET, R2_URL MY_PREFIX = 'train' @@ -47,6 +48,15 @@ def test_invalid_cloud_prefix(self, remote_local_file: Any): download_from_azure_datalake(mock_remote_filepath, mock_local_filepath) +class TestHFClient: + + @pytest.mark.usefixtures('remote_local_file') + def test_invalid_cloud_prefix(self, remote_local_file: Any): + with pytest.raises(ValueError): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='hf://') + download_from_hf(mock_remote_filepath, mock_local_filepath) + + class TestS3Client: @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') @@ -183,6 +193,14 @@ def test_download_from_gcs_gets_called(self, mocked_requests: Mock, remote_local mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) + @patch('streaming.base.storage.download.download_from_hf') + @pytest.mark.usefixtures('remote_local_file') + def test_download_from_hf_gets_called(self, mocked_requests: Mock, remote_local_file: Any): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='hf://') + download_file(mock_remote_filepath, mock_local_filepath, 60) + mocked_requests.assert_called_once() + mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) + @patch('streaming.base.storage.download.download_from_azure') @pytest.mark.usefixtures('remote_local_file') def 
test_download_from_azure_gets_called(self, mocked_requests: Mock, remote_local_file: Any): diff --git a/tests/test_encodings.py b/tests/test_encodings.py index bc3aac670..36374545a 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -8,6 +8,7 @@ import numpy as np import pytest +from numpy.typing import NDArray from PIL import Image import streaming.base.format.json.encodings as jsonEnc @@ -132,6 +133,19 @@ def test_ndarray_encode_decode(self, dtype_str: str, shape: Tuple[int]): assert b3_len < b2_len < b1_len assert b3_len == np.prod(shape) * dtype().nbytes + def test_error_no_elements_ndarray(self): + encoding = 'ndarray' + with pytest.raises(ValueError, + match='Attempting to encode a numpy array with 0 elements.*'): + _ = mdsEnc.mds_encode(encoding, np.array([])) + + @pytest.mark.parametrize('array', [np.array(0.5), np.empty(()), np.array(1)]) + def test_error_scalar_ndarray(self, array: NDArray): + encoding = 'ndarray' + with pytest.raises(ValueError, + match='Attempting to encode a scalar with NDArray encoding.*'): + _ = mdsEnc.mds_encode(encoding, array) + @pytest.mark.parametrize('mode', ['I', 'L', 'RGB']) def test_pil_encode_decode(self, mode: str): pil_enc = mdsEnc.PIL() @@ -187,7 +201,7 @@ def test_jpegfile_encode_decode(self, mode: str): # Creating the (32 x 32) NumPy Array with random values size = {'RGB': (224, 224, 3), 'L': (28, 28)}[mode] - np_data = np.random.randint(255, size=size, dtype=np.uint8) + np_data = np.array(np.random.randint(255, size=size, dtype=np.uint8)) # Default image mode of PIL Image is 'I' img = Image.fromarray(np_data).convert(mode) diff --git a/tests/test_partition.py b/tests/test_partition.py index 68d4ba8e1..42cfaa1f6 100644 --- a/tests/test_partition.py +++ b/tests/test_partition.py @@ -38,6 +38,68 @@ def test_partition_walk(partition_algo: str): assert x.shape == (22, 8, 8, 1, 10) +@pytest.mark.parametrize('num_samples', [405, 812, 1111]) +@pytest.mark.parametrize('num_canonical_nodes', [1, 2]) +@pytest.mark.parametrize('num_physical_nodes', [2, 8]) +@pytest.mark.parametrize('ranks_per_node', [1, 8]) +@pytest.mark.parametrize('workers_per_rank', [1, 8]) +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed']) +def test_partition_drop_all( + num_samples: int, + num_canonical_nodes: int, + num_physical_nodes: int, + ranks_per_node: int, + workers_per_rank: int, + batch_size: int, + partition_algo: str, +): + initial_physical_nodes = None + if partition_algo == 'relaxed' and num_canonical_nodes == 4 and ranks_per_node == 8: + num_canonical_nodes = 3 + initial_physical_nodes = 3 + batch_size = batch_size * 3 + num_samples = 3 * num_samples + + # Partitioning should repeat samples so that the epoch size is divisible by the world size. + # To drop all samples, we need to drop all repeated samples as well. + world_size = num_physical_nodes * ranks_per_node + num_repeated_samples = world_size - (num_samples % world_size) + drop_first = num_samples + num_repeated_samples + + x = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes, + ranks_per_node, workers_per_rank, batch_size, drop_first, + initial_physical_nodes) + # Partition should still have the appropriate shape, but without any samples in it. 
+ assert x.shape == (num_physical_nodes, ranks_per_node, workers_per_rank, 0, batch_size) + assert x.size == 0 + + +@pytest.mark.parametrize('num_samples', [400, 1000]) +@pytest.mark.parametrize('drop_additional', [1, 400]) +@pytest.mark.parametrize('num_canonical_nodes', [4]) +@pytest.mark.parametrize('num_physical_nodes', [4]) +@pytest.mark.parametrize('ranks_per_node', [8]) +@pytest.mark.parametrize('workers_per_rank', [8]) +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed']) +def test_partition_invalid_drop_first(num_samples: int, drop_additional: int, + num_canonical_nodes: int, num_physical_nodes: int, + ranks_per_node: int, workers_per_rank: int, batch_size: int, + partition_algo: str): + + # Partitioning should repeat samples so that the epoch size is divisible by the world size. + # For `drop_first` to be invalid, we need to exceed the number of unique samples plus the + # number of repeated samples. + world_size = num_physical_nodes * ranks_per_node + num_repeated_samples = world_size - (num_samples % world_size) + drop_first = num_samples + num_repeated_samples + drop_additional + + with pytest.raises(ValueError, match=f'Resuming further into the dataset*'): + _ = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes, + ranks_per_node, workers_per_rank, batch_size, drop_first) + + @pytest.mark.parametrize('num_samples', [1, 4]) @pytest.mark.parametrize('num_canonical_nodes', [1, 4]) @pytest.mark.parametrize('num_physical_nodes', [1, 4]) diff --git a/tests/test_spanner.py b/tests/test_spanner.py index f971813f3..ad8e01a74 100644 --- a/tests/test_spanner.py +++ b/tests/test_spanner.py @@ -24,6 +24,6 @@ def test_spanner_success(): def test_spanner_invalid_index(index: int): shard_sizes = np.arange(5, 100, 5) span_size = 7 - with pytest.raises(ValueError, match='Invalid sample index.*'): + with pytest.raises(IndexError, match='Invalid sample index.*'): spanner = Spanner(shard_sizes, span_size) spanner[index] diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 1c1f7e10c..cbad5a9e9 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -3,87 +3,224 @@ import pathlib import time +import itertools from typing import Any, Dict, Optional, Tuple -import pytest +#import pytest -from streaming.base import StreamingDataset +from streaming.base import StreamingDataset, StreamingDataLoader from streaming.text import StreamingC4 from streaming.vision import StreamingADE20K, StreamingCIFAR10, StreamingCOCO, StreamingImageNet +from composer.utils import dist as dist +from composer.utils import get_device +from composer.utils.dist import get_world_size +import torch def get_dataset(name: str, - local: str, split: str, shuffle: bool, batch_size: Optional[int], other_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[int, StreamingDataset]: other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { - 'ade20k': { - 'remote': 's3://mosaicml-internal-dataset-ade20k/mds/2/', - 'num_samples': { - 'train': 20206, - 'val': 2000, - }, - 'class': StreamingADE20K, + 'refinedweb': { + 'local': f'/tmp/test_refinedweb_05May1029', + 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', + 'num_samples': 20206, + 'class': StreamingDataset, 'kwargs': {}, }, - 'imagenet1k': { - 'remote': 's3://mosaicml-internal-dataset-imagenet1k/mds/2/', - 'num_samples': { - 'train': 1281167, - 'val': 50000, + 'dummy_table_dbsql': { + 'local': 
f'/tmp/test_dummy_table_05May1029', + 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', + 'num_samples': 5, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "7e083095329f3ca5", + 'catalog': 'main', + 'schema': 'streaming', }, - 'class': StreamingImageNet, - 'kwargs': {}, }, - 'coco': { - 'remote': 's3://mosaicml-internal-dataset-coco/mds/2/', - 'num_samples': { - 'train': 117266, - 'val': 4952, + 'random_cpt_table_dbsql': { + 'local': f'/tmp/test_random_cpt_table_05May1029', + 'remote': 'SELECT text FROM main.streaming.random_cpt_table', + 'num_samples': 100000, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "7e083095329f3ca5", + 'catalog': 'main', + 'schema': 'streaming', + 'use_cached_result': False, }, - 'class': StreamingCOCO, - 'kwargs': {}, }, - 'c4': { - 'remote': 's3://mosaicml-internal-dataset-c4/mds/2/', - 'num_samples': { - 'train': 364868892, - 'val': 364608, + 'prompt_response_table_dbsql': { + 'local': f'/tmp/test_prompt_response_table_05May1029', + 'remote': 'SELECT * FROM main.streaming.prompt_response_table_normal_1000000_20000', + 'num_samples': 1000000, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "7e083095329f3ca5", + 'catalog': 'main', + 'schema': 'streaming', }, - 'class': StreamingC4, + }, + 'random_large_table': { + 'local': f'/tmp/test_random_large_table_05May1029', + 'remote': 'SELECT * FROM main.streaming.random_large_table', + 'num_samples': 100000, + 'class': StreamingDataset, 'kwargs': { - 'tokenizer_name': 'bert-base-uncased', - 'max_seq_len': 512, - 'group_method': 'truncate' + 'cluster_id': "0201-234512-tcp9nfat" }, }, - 'cifar10': { - 'remote': 's3://mosaicml-internal-dataset-cifar10/mds/2/', - 'num_samples': { - 'train': 50000, - 'val': 10000, + 'large_liquid_test_table_08_07_dbsql': { + 'local': f'/tmp/test_liquid_test_table_05May1029', + 'remote': 'SELECT * FROM auto_maintenance_bugbash.stella.large_liquid_test_table_08_07', + 'num_samples': 89279077339, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "7e083095329f3ca5", + 'catalog': 'auto_maintenance_bugbash', + 'schema': 'stella', }, - 'class': StreamingCIFAR10, - 'kwargs': {}, }, - 'test_streaming_upload': { - 'remote': 's3://streaming-upload-test-bucket/', - 'num_samples': { - 'all': 0, + 'reddit_table_sparkconnect': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': 'SELECT text, added FROM main.reddit.data', + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': { + 'cluster_id': "0523-224100-tid6mais" }, + }, + 'reddit_table_dbsql': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': 'SELECT text, added FROM main.reddit.data', + 'num_samples': 378156152, 'class': StreamingDataset, - 'kwargs': {}, - } + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'reddit', + }, + }, + 'reddit_table_dbsql_cachelimit': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': 'SELECT text, added FROM main.reddit.data', + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'reddit', + 'cache_limit': '10gb', + }, + }, + 'wiki_table_dbsql_cachelimit': { + 'local': f'/tmp/test_wiki_table_05May1029', + 'remote': 'SELECT id, text FROM main.streaming.wiki_table', + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'streaming', + 'cache_limit': '100mb', + }, + 'shuffle': True, + }, + 
'main_streaming_wiki_table_mds': { + 'local': f'/tmp/test_wiki_table_volume_05May1029', + 'remote': 'dbfs:/Volumes/main/streaming/xiaohan_zhang/delta-streaming-benchmarks-mds/wiki_table', + 'num_samples': 5823210, + 'class': StreamingDataset, + 'kwargs': { + 'cache_limit': '100gb', + }, + 'shuffle': True, + }, + 'main_streaming_wiki_table_dbsql': { + 'local': f'/tmp/test_wiki_table_volume_05May1029', + 'remote': 'SELECT text FROM main.streaming.wiki_table', + 'num_samples': 5823210, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'streaming', + 'cache_limit': '100gb', + }, + 'shuffle': True, + }, + 'coco_table_dbsql': { + 'local': f'/tmp/test_coco_table_05May1029', + 'remote': 'SELECT data, captions FROM main.streaming.coco_with_meta_and_captions', + 'num_samples': 26688, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'streaming', + # 'cache_limit': '100mb', + }, + 'shuffle': False, + }, + 'evesize_level1_filter_dbsql': { + 'local': f'/tmp/test_evesize_05May1029', + 'remote': "SELECT prompt, response, class FROM datasets.cody.evesize_level1_evolve_respond WHERE class = \'CODE\'", + 'num_samples': 68784, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'datasets', + 'schema': 'cody', + # 'cache_limit': '100mb', + }, + 'shuffle': False, + }, + 'evesize_level1_version_dbsql': { + 'local': f'/tmp/test_evesize_05May1029', + 'remote': "SELECT * FROM main.streaming.evesize_level1_evolve_response_sub VERSION AS OF 0", + 'num_samples': 273044, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'streaming', + # 'cache_limit': '100mb', + }, + 'shuffle': False, + }, + 'finance_more_like_dbsql': { + 'local': f'/tmp/test_finance_more_like_05May1029', + 'remote': "SELECT llama_3_1_tokens AS tokens FROM main.seanowen.finance_more_like", + 'num_samples': 210463508, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'seanowen', + # 'cache_limit': '100mb', + }, + 'shuffle': False, + }, + 'debug_local': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': None, + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': {} + }, } - if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: - raise ValueError('Could not load dataset with name={name} and split={split}') + #if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: + # raise ValueError('Could not load dataset with name={name} and split={split}') d = dataset_map[name] - expected_samples = d['num_samples'][split] + expected_samples = d['num_samples'] + local = d['local'] remote = d['remote'] + shuffle = d.get('shuffle', False) or shuffle kwargs = {**d['kwargs'], **other_kwargs} dataset = d['class'](local=local, remote=remote, @@ -94,23 +231,14 @@ def get_dataset(name: str, return (expected_samples, dataset) -@pytest.mark.remote -@pytest.mark.parametrize('name', [ - 'ade20k', - 'imagenet1k', - 'coco', - 'cifar10', - 'c4', -]) -@pytest.mark.parametrize('split', ['val']) -def test_streaming_remote_dataset(tmp_path: pathlib.Path, name: str, split: str) -> None: +def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() + batch_size = 1024 expected_samples, dataset = get_dataset(name=name, - local=str(tmp_path), 
                                             split=split,
                                             shuffle=False,
-                                            batch_size=None)
+                                            batch_size=batch_size)
     build_end = time.time()
     build_dur = build_end - build_start
     print('Built dataset')
@@ -121,7 +249,7 @@ def test_streaming_remote_dataset(tmp_path: pathlib.Path, name: str, split: str)
     for _ in dataset:
         rcvd_samples += 1

-        if (rcvd_samples % 1000 == 0):
+        if (rcvd_samples % 10000 == 0):
             print(f'samples read: {rcvd_samples}')

     iter_end = time.time()
@@ -129,8 +257,103 @@
     samples_per_sec = rcvd_samples / iter_dur

     # Print debug info
+    print(f'received {rcvd_samples} samples')
+    print(f'build_dur={build_dur:.2f}s, iter_dur={iter_dur:.2f}, ' +
+          f'samples_per_sec={samples_per_sec:.2f}')
+
+    # Test all samples arrived
+    if dist.is_available() and dist.is_initialized() and get_world_size() > 1:
+        rcvd_samples = torch.tensor(rcvd_samples, dtype=torch.int64).cuda()
+        dist.all_reduce(rcvd_samples, reduce_operation='SUM')
+        assert rcvd_samples.cpu() >= expected_samples
+        return
+
+    assert rcvd_samples >= expected_samples
+
+def test_streaming_remote_dataloader(name: str, split: str) -> None:
+    # Build StreamingDataset
+    build_start = time.time()
+    batch_size = 1024
+    expected_samples, dataset = get_dataset(name=name,
+                                            split=split,
+                                            shuffle=False,
+                                            batch_size=batch_size)
+
+    data_loader = StreamingDataLoader(dataset,
+                                      batch_size=batch_size,
+                                      num_workers=8,
+                                      prefetch_factor=None,
+                                      persistent_workers=False,
+                                      pin_memory=True,
+                                      drop_last=True)
+    build_end = time.time()
+    build_dur = build_end - build_start
+    print('Built dataset')
+
+    # Test basic iteration
+    rcvd_samples = 0
+    iter_start = time.time()
+
+    for epoch in range(1):
+        skip_batches = 5
+        for batch_idx, data_dict in enumerate(itertools.islice(data_loader, skip_batches, None)):
+            #for batch_idx, data_dict in enumerate(data_loader):
+            rcvd_samples += batch_size
+
+            if (rcvd_samples % (10 * batch_size) == 0):
+                print(f'samples read: {rcvd_samples}')
+
+    iter_end = time.time()
+    iter_dur = iter_end - iter_start
+    samples_per_sec = rcvd_samples / iter_dur
+
+    # Print debug info
+    print(f'received {rcvd_samples} samples')
     print(f'build_dur={build_dur:.2f}s, iter_dur={iter_dur:.2f}, ' +
           f'samples_per_sec={samples_per_sec:.2f}')

     # Test all samples arrived
-    assert rcvd_samples == expected_samples
+    if dist.is_available() and dist.is_initialized() and get_world_size() > 1:
+        rcvd_samples = torch.tensor(rcvd_samples, dtype=torch.int64).cuda()
+        dist.all_reduce(rcvd_samples, reduce_operation='SUM')
+        assert rcvd_samples.cpu() >= expected_samples
+        return
+
+    # Test all samples arrived
+    assert rcvd_samples >= expected_samples
+
+
+if __name__ == "__main__":
+    dist.initialize_dist(get_device(None))
+
+    from streaming.base.util import clean_stale_shared_memory
+    clean_stale_shared_memory()
+
+    #test_streaming_remote_dataset(name = 'refinedweb', split=None)
+    #test_streaming_remote_dataset(name = 'dummy_table_dbsql', split=None)
+    #test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None)
+    #test_streaming_remote_dataset(name = 'random_large_table', split=None)
+    #test_streaming_remote_dataset(name = 'reddit_table', split=None)
+    #test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None)
+    #test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None)
+    #test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None)
+    #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07_dbsql', split=None)
+    #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None)
+    #test_streaming_remote_dataset(name = 'debug_local', split=None)
+    #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None)
+    #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None)
+    #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None)
+    #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None)
+    test_streaming_remote_dataset(name = 'finance_more_like_dbsql', split=None)
+
+    #test_streaming_remote_dataloader(name = 'refinedweb', split=None)
+    #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None)
+    #test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None)
+    #test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None)
+    #test_streaming_remote_dataloader(name = 'coco_table_dbsql', split=None)
+    #test_streaming_remote_dataloader(name = 'evesize_level1_version_dbsql', split=None)
+    #test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None)
+    #test_streaming_remote_dataloader(name = 'main_streaming_wiki_table_mds', split=None)
+    #test_streaming_remote_dataloader(name = 'main_streaming_wiki_table_dbsql', split=None)
+
diff --git a/tests/test_upload.py b/tests/test_upload.py
index 455b6b8c4..b280ac968 100644
--- a/tests/test_upload.py
+++ b/tests/test_upload.py
@@ -14,7 +14,7 @@ from streaming.base.storage.upload import (AlipanUploader,
                                             AzureDataLakeUploader, AzureUploader, CloudUploader,
                                             DatabricksUnityCatalogUploader, DBFSUploader,
                                             GCSAuthentication, GCSUploader,
-                                            LocalUploader, S3Uploader)
+                                            HFUploader, LocalUploader, S3Uploader)
 from tests.conftest import MY_BUCKET, R2_URL

 MY_PREFIX = 'train'
@@ -425,6 +425,33 @@ def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]):
+class TestHFUploader:
+
+    @patch('streaming.base.storage.upload.HFUploader.check_dataset_exists')
+    @pytest.mark.usefixtures('hf_credentials')
+    @pytest.mark.parametrize('out', ['hf://datasets/org_name/repo_name/path'])
+    def test_instantiation(self, mocked_requests: Mock, out: Any):
+        mocked_requests.side_effect = None
+        _ = HFUploader(out=out)
+        if not isinstance(out, str):
+            shutil.rmtree(out[0], ignore_errors=True)
+
+    @pytest.mark.parametrize('out', ['ss4://container/dir'])
+    def test_invalid_remote_str(self, out: str):
+        with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'):
+            _ = HFUploader(out=out)
+
+    def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]):
+        with pytest.raises(FileExistsError, match=f'Directory is not empty.*'):
+            local, _ = local_remote_dir
+            os.makedirs(local, exist_ok=True)
+            local_file_path = os.path.join(local, 'file.txt')
+            # Creating an empty file at specified location
+            with open(local_file_path, 'w') as _:
+                pass
+            _ = HFUploader(out=local)
+
+
 class TestDatabricksUnityCatalogUploader:

     @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client')