From dc536af29ccc60938b82fd8d3bd780873fcc8997 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 8 Oct 2024 12:59:33 +0200 Subject: [PATCH] Remote IO: http support (#464) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Support read directly from a http server like: ```python import kvikio import cupy with kvikio.RemoteFile.from_http_url("http://127.0.0.1:9000/myfile") as f: ary = cupy.empty(f.nbytes, dtype="uint8") f.read(ary) ``` This PR is the first step to support S3 using libcurl instead of [aws-s3-sdk](https://github.com/rapidsai/kvikio/pull/426), which has some pros and cons: * Pros * The [global conda pinning issue](https://github.com/rapidsai/kvikio/pull/426#discussion_r1752688778) is less of a problem. * We can support other protocols such as http, ftp, and Azure’s storage, without much work. * We avoid the [free-after-main issue in aws-s3-sdk](https://github.com/rapidsai/kvikio/blob/000126516db430988ab9af5ee1576ca3fe6afe27/cpp/include/kvikio/remote_handle.hpp#L87-L94). This is huge since we would otherwise have to pass around a `S3Context` in libcudf and cudf to handle shutdown correctly. This is not a problem in libcurl, see https://curl.se/libcurl/c/libcurl.html under `Global constants`. * Cons * Hard to support the AWS configuration file. We will require the user to either specify the options programmatically or through environment variables like `AWS_ACCESS_KEY_ID ` and `AWS_SECRET_ACCESS_KEY `. Authors: - Mads R. B. Kristensen (https://github.com/madsbk) Approvers: - Kyle Edwards (https://github.com/KyleFromNVIDIA) - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/kvikio/pull/464 --- .../all_cuda-118_arch-aarch64.yaml | 2 + .../all_cuda-118_arch-x86_64.yaml | 2 + .../all_cuda-125_arch-aarch64.yaml | 2 + .../all_cuda-125_arch-x86_64.yaml | 2 + conda/recipes/kvikio/meta.yaml | 1 + conda/recipes/libkvikio/meta.yaml | 2 + cpp/CMakeLists.txt | 15 + cpp/cmake/thirdparty/get_libcurl.cmake | 32 ++ cpp/include/kvikio/remote_handle.hpp | 294 ++++++++++++++++++ cpp/include/kvikio/shim/libcurl.hpp | 260 ++++++++++++++++ dependencies.yaml | 2 + docs/source/api.rst | 7 + docs/source/index.rst | 1 + docs/source/remote_file.rst | 11 + python/kvikio/examples/http_io.py | 37 +++ python/kvikio/kvikio/__init__.py | 9 +- python/kvikio/kvikio/_lib/CMakeLists.txt | 10 + python/kvikio/kvikio/_lib/remote_handle.pyx | 89 ++++++ python/kvikio/kvikio/benchmarks/http_io.py | 174 +++++++++++ python/kvikio/kvikio/remote_file.py | 127 ++++++++ python/kvikio/kvikio/utils.py | 81 +++++ python/kvikio/pyproject.toml | 1 + python/kvikio/tests/test_benchmarks.py | 31 ++ python/kvikio/tests/test_examples.py | 17 +- python/kvikio/tests/test_http_io.py | 98 ++++++ 25 files changed, 1305 insertions(+), 2 deletions(-) create mode 100644 cpp/cmake/thirdparty/get_libcurl.cmake create mode 100644 cpp/include/kvikio/remote_handle.hpp create mode 100644 cpp/include/kvikio/shim/libcurl.hpp create mode 100644 docs/source/remote_file.rst create mode 100644 python/kvikio/examples/http_io.py create mode 100644 python/kvikio/kvikio/_lib/remote_handle.pyx create mode 100644 python/kvikio/kvikio/benchmarks/http_io.py create mode 100644 python/kvikio/kvikio/remote_file.py create mode 100644 python/kvikio/kvikio/utils.py create mode 100644 python/kvikio/tests/test_http_io.py diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index 65ca39fa34..0e7f4b3e21 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -17,6 +17,7 @@ dependencies: - dask>=2022.05.2 - doxygen=1.9.1 - gcc_linux-aarch64=11.* +- libcurl>=7.87.0 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 @@ -28,6 +29,7 @@ dependencies: - pytest - pytest-cov - python>=3.10,<3.13 +- rangehttpserver - rapids-build-backend>=0.3.0,<0.4.0.dev0 - scikit-build-core>=0.10.0 - sphinx diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index a020690e64..293085e8f7 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -19,6 +19,7 @@ dependencies: - gcc_linux-64=11.* - libcufile-dev=1.4.0.31 - libcufile=1.4.0.31 +- libcurl>=7.87.0 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 @@ -30,6 +31,7 @@ dependencies: - pytest - pytest-cov - python>=3.10,<3.13 +- rangehttpserver - rapids-build-backend>=0.3.0,<0.4.0.dev0 - scikit-build-core>=0.10.0 - sphinx diff --git a/conda/environments/all_cuda-125_arch-aarch64.yaml b/conda/environments/all_cuda-125_arch-aarch64.yaml index 31145241d7..1e4a370ff6 100644 --- a/conda/environments/all_cuda-125_arch-aarch64.yaml +++ b/conda/environments/all_cuda-125_arch-aarch64.yaml @@ -18,6 +18,7 @@ dependencies: - doxygen=1.9.1 - gcc_linux-aarch64=11.* - libcufile-dev +- libcurl>=7.87.0 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 @@ -28,6 +29,7 @@ dependencies: - pytest - pytest-cov - python>=3.10,<3.13 +- rangehttpserver - rapids-build-backend>=0.3.0,<0.4.0.dev0 - scikit-build-core>=0.10.0 - sphinx diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 4d7d0be7c6..44d8772a71 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -18,6 +18,7 @@ dependencies: - doxygen=1.9.1 - gcc_linux-64=11.* - libcufile-dev +- libcurl>=7.87.0 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 @@ -28,6 +29,7 @@ dependencies: - pytest - pytest-cov - python>=3.10,<3.13 +- rangehttpserver - rapids-build-backend>=0.3.0,<0.4.0.dev0 - scikit-build-core>=0.10.0 - sphinx diff --git a/conda/recipes/kvikio/meta.yaml b/conda/recipes/kvikio/meta.yaml index 4a352012e3..3c41af3310 100644 --- a/conda/recipes/kvikio/meta.yaml +++ b/conda/recipes/kvikio/meta.yaml @@ -64,6 +64,7 @@ requirements: - rapids-build-backend >=0.3.0,<0.4.0.dev0 - scikit-build-core >=0.10.0 - libkvikio ={{ version }} + - libcurl==7.87.0 run: - python - numpy >=1.23,<3.0a0 diff --git a/conda/recipes/libkvikio/meta.yaml b/conda/recipes/libkvikio/meta.yaml index 186c373f56..999b9fc2c1 100644 --- a/conda/recipes/libkvikio/meta.yaml +++ b/conda/recipes/libkvikio/meta.yaml @@ -52,6 +52,7 @@ requirements: {% else %} - libcufile-dev # [linux] {% endif %} + - libcurl==7.87.0 outputs: - name: libkvikio @@ -74,6 +75,7 @@ outputs: - cmake {{ cmake_version }} host: - cuda-version ={{ cuda_version }} + - libcurl==7.87.0 run: - {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }} {% if cuda_major == "11" %} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 5990405b1c..786ccb9266 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -36,6 +36,7 @@ rapids_cmake_write_version_file(include/kvikio/version_config.hpp) rapids_cmake_build_type(Release) # build options +option(KvikIO_REMOTE_SUPPORT "Configure CMake to build with remote IO support" ON) option(KvikIO_BUILD_EXAMPLES "Configure CMake to build examples" ON) option(KvikIO_BUILD_TESTS "Configure CMake to build tests" ON) @@ -50,6 +51,10 @@ rapids_find_package( INSTALL_EXPORT_SET kvikio-exports ) +if(KvikIO_REMOTE_SUPPORT) + include(cmake/thirdparty/get_libcurl.cmake) +endif() + rapids_find_package( CUDAToolkit BUILD_EXPORT_SET kvikio-exports @@ -138,6 +143,10 @@ target_link_libraries( kvikio INTERFACE Threads::Threads BS::thread_pool ${CMAKE_DL_LIBS} $ ) +if(TARGET CURL::libcurl) + target_link_libraries(kvikio INTERFACE $) + target_compile_definitions(kvikio INTERFACE $) +endif() target_compile_features(kvikio INTERFACE cxx_std_17) # optionally build examples @@ -231,6 +240,12 @@ if(NOT already_set_kvikio) target_compile_definitions(kvikio::kvikio INTERFACE KVIKIO_CUFILE_STREAM_API_FOUND) endif() endif() + + if(TARGET CURL::libcurl) + target_link_libraries(kvikio::kvikio INTERFACE CURL::libcurl) + target_compile_definitions(kvikio::kvikio INTERFACE KVIKIO_LIBCURL_FOUND) + endif() + endif() ]=] ) diff --git a/cpp/cmake/thirdparty/get_libcurl.cmake b/cpp/cmake/thirdparty/get_libcurl.cmake new file mode 100644 index 0000000000..7695592737 --- /dev/null +++ b/cpp/cmake/thirdparty/get_libcurl.cmake @@ -0,0 +1,32 @@ +# ============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +# This function finds libcurl and sets any additional necessary environment variables. +function(find_and_configure_libcurl) + include(${rapids-cmake-dir}/cpm/find.cmake) + + rapids_cpm_find( + CURL 7.87.0 + GLOBAL_TARGETS libcurl + BUILD_EXPORT_SET kvikio-exports + INSTALL_EXPORT_SET kvikio-exports + CPM_ARGS + GIT_REPOSITORY https://github.com/curl/curl + GIT_TAG curl-7_87_0 + OPTIONS "BUILD_CURL_EXE OFF" "BUILD_SHARED_LIBS OFF" "BUILD_TESTING OFF" "CURL_USE_LIBPSL OFF" + "CURL_DISABLE_LDAP ON" "CMAKE_POSITION_INDEPENDENT_CODE ON" + ) +endfunction() + +find_and_configure_libcurl() diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp new file mode 100644 index 0000000000..e036ebcb37 --- /dev/null +++ b/cpp/include/kvikio/remote_handle.hpp @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace kvikio { +namespace detail { + +/** + * @brief Context used by the "CURLOPT_WRITEFUNCTION" callbacks. + */ +struct CallbackContext { + char* buf; // Output buffer to read into. + std::size_t size; // Total number of bytes to read. + std::ptrdiff_t offset; // Offset into `buf` to start reading. + bool overflow_error; // Flag to indicate overflow. + CallbackContext(void* buf, std::size_t size) + : buf{static_cast(buf)}, size{size}, offset{0}, overflow_error{0} + { + } +}; + +/** + * @brief A "CURLOPT_WRITEFUNCTION" to copy downloaded data to the output host buffer. + * + * See . + * + * @param data Data downloaded by libcurl that is ready for consumption. + * @param size Size of each element in `nmemb`; size is always 1. + * @param nmemb Size of the data in `nmemb`. + * @param context A pointer to an instance of `CallbackContext`. + */ +inline std::size_t callback_host_memory(char* data, + std::size_t size, + std::size_t nmemb, + void* context) +{ + auto ctx = reinterpret_cast(context); + std::size_t const nbytes = size * nmemb; + if (ctx->size < ctx->offset + nbytes) { + ctx->overflow_error = true; + return CURL_WRITEFUNC_ERROR; + } + KVIKIO_NVTX_FUNC_RANGE("RemoteHandle - callback_host_memory()", nbytes); + std::memcpy(ctx->buf + ctx->offset, data, nbytes); + ctx->offset += nbytes; + return nbytes; +} + +/** + * @brief A "CURLOPT_WRITEFUNCTION" to copy downloaded data to the output device buffer. + * + * See . + * + * @param data Data downloaded by libcurl that is ready for consumption. + * @param size Size of each element in `nmemb`; size is always 1. + * @param nmemb Size of the data in `nmemb`. + * @param context A pointer to an instance of `CallbackContext`. + */ +inline std::size_t callback_device_memory(char* data, + std::size_t size, + std::size_t nmemb, + void* context) +{ + auto ctx = reinterpret_cast(context); + const std::size_t nbytes = size * nmemb; + if (ctx->size < ctx->offset + nbytes) { + ctx->overflow_error = true; + return CURL_WRITEFUNC_ERROR; + } + KVIKIO_NVTX_FUNC_RANGE("RemoteHandle - callback_device_memory()", nbytes); + + CUstream stream = detail::StreamsByThread::get(); + CUDA_DRIVER_TRY(cudaAPI::instance().MemcpyHtoDAsync( + convert_void2deviceptr(ctx->buf + ctx->offset), data, nbytes, stream)); + // We have to sync since curl might overwrite or free `data`. + CUDA_DRIVER_TRY(cudaAPI::instance().StreamSynchronize(stream)); + + ctx->offset += nbytes; + return nbytes; +} + +} // namespace detail + +/** + * @brief Abstract base class for remote endpoints. + * + * In this context, an endpoint refers to a remote file using a specific communication protocol. + * + * Each communication protocol, such as HTTP or S3, needs to implement this ABC and implement + * its own ctor that takes communication protocol specific arguments. + */ +class RemoteEndpoint { + public: + /** + * @brief Set needed connection options on a curl handle. + * + * Subsequently, a call to `curl.perform()` should connect to the endpoint. + * + * @param curl The curl handle. + */ + virtual void setopt(CurlHandle& curl) = 0; + + /** + * @brief Get a description of this remote point instance. + * + * @returns A string description. + */ + virtual std::string str() = 0; + + virtual ~RemoteEndpoint() = default; +}; + +/** + * @brief A remote endpoint using http. + */ +class HttpEndpoint : public RemoteEndpoint { + private: + std::string _url; + + public: + HttpEndpoint(std::string url) : _url{std::move(url)} {} + void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); } + std::string str() override { return _url; } + ~HttpEndpoint() override = default; +}; + +/** + * @brief Handle of remote file. + */ +class RemoteHandle { + private: + std::unique_ptr _endpoint; + std::size_t _nbytes; + + public: + /** + * @brief Create a new remote handle from an endpoint and a file size. + * + * @param endpoint Remote endpoint used for subsequent IO. + * @param nbytes The size of the remote file (in bytes). + */ + RemoteHandle(std::unique_ptr endpoint, std::size_t nbytes) + : _endpoint{std::move(endpoint)}, _nbytes{nbytes} + { + } + + /** + * @brief Create a new remote handle from an endpoint (infers the file size). + * + * The file size is received from the remote server using `endpoint`. + * + * @param endpoint Remote endpoint used for subsequently IO. + */ + RemoteHandle(std::unique_ptr endpoint) + { + auto curl = create_curl_handle(); + + endpoint->setopt(curl); + curl.setopt(CURLOPT_NOBODY, 1L); + curl.setopt(CURLOPT_FOLLOWLOCATION, 1L); + curl.perform(); + curl_off_t cl; + curl.getinfo(CURLINFO_CONTENT_LENGTH_DOWNLOAD_T, &cl); + if (cl < 0) { + throw std::runtime_error("cannot get size of " + endpoint->str() + + ", content-length not provided by the server"); + } + _nbytes = cl; + _endpoint = std::move(endpoint); + } + + // A remote handle is moveable but not copyable. + RemoteHandle(RemoteHandle&& o) = default; + RemoteHandle& operator=(RemoteHandle&& o) = default; + RemoteHandle(RemoteHandle const&) = delete; + RemoteHandle& operator=(RemoteHandle const&) = delete; + + /** + * @brief Get the file size. + * + * Note, this is very fast, no communication needed. + * + * @return The number of bytes. + */ + [[nodiscard]] std::size_t nbytes() const noexcept { return _nbytes; } + + /** + * @brief Read from remote source into buffer (host or device memory). + * + * @param buf Pointer to host or device memory. + * @param size Number of bytes to read. + * @param file_offset File offset in bytes. + * @return Number of bytes read, which is always `size`. + */ + std::size_t read(void* buf, std::size_t size, std::size_t file_offset = 0) + { + KVIKIO_NVTX_FUNC_RANGE("RemoteHandle::read()", size); + + if (file_offset + size > _nbytes) { + std::stringstream ss; + ss << "cannot read " << file_offset << "+" << size << " bytes into a " << _nbytes + << " bytes file (" << _endpoint->str() << ")"; + throw std::invalid_argument(ss.str()); + } + const bool is_host_mem = is_host_memory(buf); + auto curl = create_curl_handle(); + _endpoint->setopt(curl); + + std::string const byte_range = + std::to_string(file_offset) + "-" + std::to_string(file_offset + size - 1); + curl.setopt(CURLOPT_RANGE, byte_range.c_str()); + + if (is_host_mem) { + curl.setopt(CURLOPT_WRITEFUNCTION, detail::callback_host_memory); + } else { + curl.setopt(CURLOPT_WRITEFUNCTION, detail::callback_device_memory); + } + detail::CallbackContext ctx{buf, size}; + curl.setopt(CURLOPT_WRITEDATA, &ctx); + + try { + if (is_host_mem) { + curl.perform(); + } else { + PushAndPopContext c(get_context_from_pointer(buf)); + curl.perform(); + } + } catch (std::runtime_error const& e) { + if (ctx.overflow_error) { + std::stringstream ss; + ss << "maybe the server doesn't support file ranges? [" << e.what() << "]"; + throw std::overflow_error(ss.str()); + } + throw; + } + return size; + } + + /** + * @brief Read from remote source into buffer (host or device memory) in parallel. + * + * This API is a parallel async version of `.read()` that partitions the operation + * into tasks of size `task_size` for execution in the default thread pool. + * + * @param buf Pointer to host or device memory. + * @param size Number of bytes to read. + * @param file_offset File offset in bytes. + * @param task_size Size of each task in bytes. + * @return Future that on completion returns the size of bytes read, which is always `size`. + */ + std::future pread(void* buf, + std::size_t size, + std::size_t file_offset = 0, + std::size_t task_size = defaults::task_size()) + { + KVIKIO_NVTX_FUNC_RANGE("RemoteHandle::pread()", size); + auto task = [this](void* devPtr_base, + std::size_t size, + std::size_t file_offset, + std::size_t devPtr_offset) -> std::size_t { + return read(static_cast(devPtr_base) + devPtr_offset, size, file_offset); + }; + return parallel_io(task, buf, size, file_offset, task_size, 0); + } +}; + +} // namespace kvikio diff --git a/cpp/include/kvikio/shim/libcurl.hpp b/cpp/include/kvikio/shim/libcurl.hpp new file mode 100644 index 0000000000..cee50c5947 --- /dev/null +++ b/cpp/include/kvikio/shim/libcurl.hpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef KVIKIO_LIBCURL_FOUND +#error \ + "cannot include the remote IO API, please build KvikIO with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)" +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace kvikio { + +/** + * @brief Singleton class to initialize and cleanup the global state of libcurl + * + * Notice, libcurl allows the use of a singleton class: + * + * In a C++ module, it is common to deal with the global constant situation by defining a special + * class that represents the global constant environment of the module. A program always has exactly + * one object of the class, in static storage. That way, the program automatically calls the + * constructor of the object as the program starts up and the destructor as it terminates. As the + * author of this libcurl-using module, you can make the constructor call curl_global_init and the + * destructor call curl_global_cleanup and satisfy libcurl's requirements without your user having + * to think about it. (Caveat: If you are initializing libcurl from a Windows DLL you should not + * initialize it from DllMain or a static initializer because Windows holds the loader lock during + * that time and it could cause a deadlock.) + * + * Source . + */ +class LibCurl { + public: + // We hold a unique pointer to the raw curl handle and set `curl_easy_cleanup` as its Deleter. + using UniqueHandlePtr = std::unique_ptr>; + + private: + std::mutex _mutex{}; + // Curl handles free to be used. + std::vector _free_curl_handles{}; + + LibCurl() + { + CURLcode err = curl_global_init(CURL_GLOBAL_DEFAULT); + if (err != CURLE_OK) { + throw std::runtime_error("cannot initialize libcurl - errorcode: " + std::to_string(err)); + } + curl_version_info_data* ver = curl_version_info(::CURLVERSION_NOW); + if ((ver->features & CURL_VERSION_THREADSAFE) == 0) { + throw std::runtime_error("cannot initialize libcurl - built with thread safety disabled"); + } + } + ~LibCurl() noexcept + { + _free_curl_handles.clear(); + curl_global_cleanup(); + } + + public: + static LibCurl& instance() + { + static LibCurl _instance; + return _instance; + } + + /** + * @brief Returns a free curl handle if available. + */ + UniqueHandlePtr get_free_handle() + { + UniqueHandlePtr ret; + std::lock_guard const lock(_mutex); + if (!_free_curl_handles.empty()) { + ret = std::move(_free_curl_handles.back()); + _free_curl_handles.pop_back(); + } + return ret; + } + + /** + * @brief Returns a curl handle, create a new handle if none is available. + */ + UniqueHandlePtr get_handle() + { + // Check if we have a free handle available. + UniqueHandlePtr ret = get_free_handle(); + if (ret) { + curl_easy_reset(ret.get()); + } else { + // If not, we create a new handle. + CURL* raw_handle = curl_easy_init(); + if (raw_handle == nullptr) { + throw std::runtime_error("libcurl: call to curl_easy_init() failed"); + } + ret = UniqueHandlePtr(raw_handle, curl_easy_cleanup); + } + return ret; + } + + /** + * @brief Retain a curl handle for later use. + */ + void retain_handle(UniqueHandlePtr handle) + { + std::lock_guard const lock(_mutex); + _free_curl_handles.push_back(std::move(handle)); + } +}; + +/** + * @brief Representation of a curl easy handle pointer and its operations. + * + * An instance is given a `LibCurl::UniqueHandlePtr` on creation, which is + * later retained on destruction. + */ +class CurlHandle { + private: + char _errbuf[CURL_ERROR_SIZE]; + LibCurl::UniqueHandlePtr _handle; + std::string _source_file; + std::string _source_line; + + public: + /** + * @brief Construct a new curl handle. + * + * Typically, do not use this directly instead use the `create_curl_handle()` macro. + * + * @param handle An unused curl easy handle pointer, which is retained on destruction. + * @param source_file Path of source file of the caller (for error messages). + * @param source_line Line of source file of the caller (for error messages). + */ + CurlHandle(LibCurl::UniqueHandlePtr handle, std::string source_file, std::string source_line) + : _handle{std::move(handle)}, + _source_file(std::move(source_file)), + _source_line(std::move(source_line)) + { + // Need CURLOPT_NOSIGNAL to support threading, see + // + setopt(CURLOPT_NOSIGNAL, 1L); + + // We always set CURLOPT_ERRORBUFFER to get better error messages. + setopt(CURLOPT_ERRORBUFFER, _errbuf); + + // Make curl_easy_perform() fail when receiving HTTP code errors. + setopt(CURLOPT_FAILONERROR, 1L); + } + ~CurlHandle() noexcept { LibCurl::instance().retain_handle(std::move(_handle)); } + + /** + * @brief CurlHandle support is not movable or copyable. + */ + CurlHandle(CurlHandle const&) = delete; + CurlHandle& operator=(CurlHandle const&) = delete; + CurlHandle(CurlHandle&& o) = delete; + CurlHandle& operator=(CurlHandle&& o) = delete; + + /** + * @brief Get the underlying curl easy handle pointer. + */ + CURL* handle() noexcept { return _handle.get(); } + + /** + * @brief Set option for the curl handle. + * + * See for available options. + * + * @tparam VAL The type of the value. + * @param option The curl option to set. + */ + template + void setopt(CURLoption option, VAL value) + { + CURLcode err = curl_easy_setopt(handle(), option, value); + if (err != CURLE_OK) { + std::stringstream ss; + ss << "curl_easy_setopt() error near " << _source_file << ":" << _source_line; + ss << "(" << curl_easy_strerror(err) << ")"; + throw std::runtime_error(ss.str()); + } + } + + /** + * @brief Perform a blocking network transfer using previously set options. + * + * See . + */ + void perform() + { + // Perform the curl operation and check for errors. + CURLcode err = curl_easy_perform(handle()); + if (err != CURLE_OK) { + std::string msg(_errbuf); + std::stringstream ss; + ss << "curl_easy_perform() error near " << _source_file << ":" << _source_line; + if (msg.empty()) { + ss << "(" << curl_easy_strerror(err) << ")"; + } else { + ss << "(" << msg << ")"; + } + throw std::runtime_error(ss.str()); + } + } + + /** + * @brief Extract information from a curl handle. + * + * See for available options. + * + * @tparam OUTPUT The type of the output. + * @param output The output, which is used as-is: `curl_easy_getinfo(..., output)`. + */ + template + void getinfo(CURLINFO info, OUTPUT* output) + { + CURLcode err = curl_easy_getinfo(handle(), info, output); + if (err != CURLE_OK) { + std::stringstream ss; + ss << "curl_easy_getinfo() error near " << _source_file << ":" << _source_line; + ss << "(" << curl_easy_strerror(err) << ")"; + throw std::runtime_error(ss.str()); + } + } +}; + +/** + * @brief Create a new curl handle. + * + * @returns A `kvikio::CurlHandle` instance ready to be used. + */ +#define create_curl_handle() \ + kvikio::CurlHandle(kvikio::LibCurl::instance().get_handle(), __FILE__, KVIKIO_STRINGIFY(__LINE__)) + +} // namespace kvikio diff --git a/dependencies.yaml b/dependencies.yaml index 7a8a3a9bcc..39ba3aaa17 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -110,6 +110,7 @@ dependencies: packages: - c-compiler - cxx-compiler + - libcurl>=7.87.0 # Need CURL_WRITEFUNC_ERROR specific: - output_types: conda matrices: @@ -343,6 +344,7 @@ dependencies: - &dask dask>=2022.05.2 - pytest - pytest-cov + - rangehttpserver specific: - output_types: [conda, requirements, pyproject] matrices: diff --git a/docs/source/api.rst b/docs/source/api.rst index 4d19c09bbb..fd34367a00 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -18,6 +18,13 @@ Zarr .. autoclass:: GDSStore :members: +RemoteFile +---------- +.. currentmodule:: kvikio.remote_file + +.. autoclass:: RemoteFile + :members: + Defaults -------- .. currentmodule:: kvikio.defaults diff --git a/docs/source/index.rst b/docs/source/index.rst index 4dd491fd96..9e302b5f44 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,7 @@ Contents install quickstart zarr + remote_file runtime_settings api genindex diff --git a/docs/source/remote_file.rst b/docs/source/remote_file.rst new file mode 100644 index 0000000000..ed6fe45b7b --- /dev/null +++ b/docs/source/remote_file.rst @@ -0,0 +1,11 @@ +Remote File +=========== + +KvikIO provides direct access to remote files. + + +Example +------- + +.. literalinclude:: ../../python/kvikio/examples/http_io.py + :language: python diff --git a/python/kvikio/examples/http_io.py b/python/kvikio/examples/http_io.py new file mode 100644 index 0000000000..26c9af1d44 --- /dev/null +++ b/python/kvikio/examples/http_io.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +import pathlib +import tempfile + +import cupy +import numpy + +import kvikio +from kvikio.utils import LocalHttpServer + + +def main(tmpdir: pathlib.Path): + a = cupy.arange(100) + a.tofile(tmpdir / "myfile") + b = cupy.empty_like(a) + + # Start a local server that serves files in `tmpdir` + with LocalHttpServer(root_path=tmpdir) as server: + # Open remote file from a http url + with kvikio.RemoteFile.open_http(f"{server.url}/myfile") as f: + # KvikIO fetch the file size + assert f.nbytes() == a.nbytes + # Read the remote file into `b` as if it was a local file. + f.read(b) + assert all(a == b) + # We can also read into host memory seamlessly + a = cupy.asnumpy(a) + c = numpy.empty_like(a) + f.read(c) + assert all(a == c) + + +if __name__ == "__main__": + with tempfile.TemporaryDirectory() as tmpdir: + main(pathlib.Path(tmpdir)) diff --git a/python/kvikio/kvikio/__init__.py b/python/kvikio/kvikio/__init__.py index 883ac9e784..749d87ec1f 100644 --- a/python/kvikio/kvikio/__init__.py +++ b/python/kvikio/kvikio/__init__.py @@ -4,9 +4,16 @@ from kvikio._lib import driver_properties # type: ignore from kvikio._version import __git_commit__, __version__ from kvikio.cufile import CuFile +from kvikio.remote_file import RemoteFile, is_remote_file_available # TODO: Wrap nicely, maybe as a dataclass? DriverProperties = driver_properties.DriverProperties -__all__ = ["__git_commit__", "__version__", "CuFile"] +__all__ = [ + "__git_commit__", + "__version__", + "CuFile", + "RemoteFile", + "is_remote_file_available", +] diff --git a/python/kvikio/kvikio/_lib/CMakeLists.txt b/python/kvikio/kvikio/_lib/CMakeLists.txt index 74a6f6562f..18bb46c0fb 100644 --- a/python/kvikio/kvikio/_lib/CMakeLists.txt +++ b/python/kvikio/kvikio/_lib/CMakeLists.txt @@ -17,6 +17,16 @@ set(cython_modules arr.pyx buffer.pyx defaults.pyx driver_properties.pyx file_ha libnvcomp.pyx libnvcomp_ll.pyx ) +if(TARGET CURL::libcurl) + message(STATUS "Building remote_handle.pyx (libcurl found)") + list(APPEND cython_modules remote_handle.pyx) +else() + message( + STATUS + "Skipping remote_handle.pyx (please set KvikIO_REMOTE_SUPPORT=ON for remote file support)" + ) +endif() + rapids_cython_create_modules( CXX SOURCE_FILES "${cython_modules}" diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx new file mode 100644 index 0000000000..5e58da32f0 --- /dev/null +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +# distutils: language = c++ +# cython: language_level=3 + +from typing import Optional + +from cython.operator cimport dereference as deref +from libc.stdint cimport uintptr_t +from libcpp.memory cimport make_unique, unique_ptr +from libcpp.string cimport string +from libcpp.utility cimport move, pair + +from kvikio._lib.arr cimport parse_buffer_argument +from kvikio._lib.future cimport IOFuture, _wrap_io_future, future + + +cdef extern from "" nogil: + cdef cppclass cpp_RemoteEndpoint "kvikio::RemoteEndpoint": + pass + + cdef cppclass cpp_HttpEndpoint "kvikio::HttpEndpoint": + cpp_HttpEndpoint(string url) except + + + cdef cppclass cpp_RemoteHandle "kvikio::RemoteHandle": + cpp_RemoteHandle( + unique_ptr[cpp_RemoteEndpoint] endpoint, size_t nbytes + ) except + + cpp_RemoteHandle(unique_ptr[cpp_RemoteEndpoint] endpoint) except + + int nbytes() except + + size_t read( + void* buf, + size_t size, + size_t file_offset + ) except + + future[size_t] pread( + void* buf, + size_t size, + size_t file_offset + ) except + + +cdef string _to_string(str_or_none): + """Convert Python object to a C++ string (if None, return the empty string)""" + if str_or_none is None: + return string() + return str.encode(str(str_or_none)) + + +cdef class RemoteFile: + cdef unique_ptr[cpp_RemoteHandle] _handle + + @classmethod + def open_http( + cls, + url: str, + nbytes: Optional[int], + ): + cdef RemoteFile ret = RemoteFile() + cdef unique_ptr[cpp_HttpEndpoint] ep = make_unique[cpp_HttpEndpoint]( + _to_string(url) + ) + if nbytes is None: + ret._handle = make_unique[cpp_RemoteHandle](move(ep)) + return ret + cdef size_t n = nbytes + ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) + return ret + + def nbytes(self) -> int: + return deref(self._handle).nbytes() + + def read(self, buf, size: Optional[int], file_offset: int) -> int: + cdef pair[uintptr_t, size_t] info = parse_buffer_argument(buf, size, True) + return deref(self._handle).read( + info.first, + info.second, + file_offset, + ) + + def pread(self, buf, size: Optional[int], file_offset: int) -> IOFuture: + cdef pair[uintptr_t, size_t] info = parse_buffer_argument(buf, size, True) + return _wrap_io_future( + deref(self._handle).pread( + info.first, + info.second, + file_offset, + ) + ) diff --git a/python/kvikio/kvikio/benchmarks/http_io.py b/python/kvikio/kvikio/benchmarks/http_io.py new file mode 100644 index 0000000000..68d4643004 --- /dev/null +++ b/python/kvikio/kvikio/benchmarks/http_io.py @@ -0,0 +1,174 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +import argparse +import contextlib +import pathlib +import statistics +import tempfile +import time +from functools import partial + +import cupy +import numpy +from dask.utils import format_bytes + +import kvikio +import kvikio.defaults +from kvikio.utils import LocalHttpServer + + +def run_numpy_like(args, xp): + src = numpy.arange(args.nelem, dtype=args.dtype) + src.tofile(args.server_root_path / "data") + dst = xp.empty_like(src) + url = f"{args.server_url}/data" + + def run() -> float: + t0 = time.perf_counter() + with kvikio.RemoteFile.open_http(url, nbytes=src.nbytes) as f: + res = f.read(dst) + t1 = time.perf_counter() + assert res == args.nbytes, f"IO mismatch, expected {args.nbytes} got {res}" + xp.testing.assert_array_equal(src, dst) + return t1 - t0 + + for _ in range(args.nruns): + yield run() + + +API = { + "cupy": partial(run_numpy_like, xp=cupy), + "numpy": partial(run_numpy_like, xp=numpy), +} + + +def main(args): + cupy.cuda.set_allocator(None) # Disable CuPy's default memory pool + cupy.arange(10) # Make sure CUDA is initialized + + kvikio.defaults.num_threads_reset(args.nthreads) + print("Roundtrip benchmark") + print("--------------------------------------") + print(f"nelem | {args.nelem} ({format_bytes(args.nbytes)})") + print(f"dtype | {args.dtype}") + print(f"nthreads | {args.nthreads}") + print(f"nruns | {args.nruns}") + print(f"server | {args.server}") + if args.server is None: + print("--------------------------------------") + print("WARNING: the bundled server is slow, ") + print("consider using --server.") + print("======================================") + + # Run each benchmark using the requested APIs + for api in args.api: + res = [] + for elapsed in API[api](args): + res.append(elapsed) + + def pprint_api_res(name, samples): + samples = [args.nbytes / s for s in samples] # Convert to throughput + mean = statistics.harmonic_mean(samples) if len(samples) > 1 else samples[0] + ret = f"{api}-{name}".ljust(18) + ret += f"| {format_bytes(mean).rjust(10)}/s".ljust(14) + if len(samples) > 1: + stdev = statistics.stdev(samples) / mean * 100 + ret += " ± %5.2f %%" % stdev + ret += " (" + for sample in samples: + ret += f"{format_bytes(sample)}/s, " + ret = ret[:-2] + ")" # Replace trailing comma + return ret + + print(pprint_api_res("read", res)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="HTTP benchmark") + parser.add_argument( + "-n", + "--nelem", + metavar="NELEM", + default="1024", + type=int, + help="Number of elements (default: %(default)s).", + ) + parser.add_argument( + "--dtype", + metavar="DATATYPE", + default="float32", + type=numpy.dtype, + help="The data type of each element (default: %(default)s).", + ) + parser.add_argument( + "--nruns", + metavar="RUNS", + default=1, + type=int, + help="Number of runs per API (default: %(default)s).", + ) + parser.add_argument( + "-t", + "--nthreads", + metavar="THREADS", + default=1, + type=int, + help="Number of threads to use (default: %(default)s).", + ) + parser.add_argument( + "--server", + default=None, + help=( + "Connect to an external http server as opposed " + "to the bundled (very slow) HTTP server. " + "Remember to also set --server-root-path." + ), + ) + parser.add_argument( + "--server-root-path", + default=None, + help="Path to the root directory that `--server` serves (local path).", + ) + parser.add_argument( + "--bundled-server-lifetime", + metavar="SECONDS", + default=3600, + type=int, + help="Maximum lifetime of the bundled server (default: %(default)s).", + ) + parser.add_argument( + "--api", + metavar="API", + default=list(API.keys())[0], # defaults to the first API + nargs="+", + choices=tuple(API.keys()) + ("all",), + help="List of APIs to use {%(choices)s} (default: %(default)s).", + ) + args = parser.parse_args() + args.nbytes = args.nelem * args.dtype.itemsize + if "all" in args.api: + args.api = tuple(API.keys()) + + with contextlib.ExitStack() as context_stack: + if args.server is None: + # Create a tmp dir for the bundled server to serve + temp_dir = tempfile.TemporaryDirectory() + args.bundled_server_root_dir = pathlib.Path(temp_dir.name) + context_stack.enter_context(temp_dir) + + # Create the bundled server + bundled_server = LocalHttpServer( + root_path=args.bundled_server_root_dir, + range_support=True, + max_lifetime=args.bundled_server_lifetime, + ) + context_stack.enter_context(bundled_server) + args.server_url = bundled_server.url + args.server_root_path = args.bundled_server_root_dir + else: + args.server_url = args.server + if args.server_root_path is None: + raise ValueError("please set --server-root-path") + args.server_root_path = pathlib.Path(args.server_root_path) + main(args) diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py new file mode 100644 index 0000000000..52bbe8010f --- /dev/null +++ b/python/kvikio/kvikio/remote_file.py @@ -0,0 +1,127 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +from __future__ import annotations + +import functools +from typing import Optional + +from kvikio.cufile import IOFuture + + +@functools.cache +def is_remote_file_available() -> bool: + """Check if the remote module is available""" + try: + import kvikio._lib.remote_handle # noqa: F401 + except ImportError: + return False + else: + return True + + +@functools.cache +def _get_remote_module(): + """Get the remote module or raise an error""" + if not is_remote_file_available(): + raise RuntimeError( + "RemoteFile not available, please build KvikIO " + "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)" + ) + import kvikio._lib.remote_handle + + return kvikio._lib.remote_handle + + +class RemoteFile: + """File handle of a remote file.""" + + def __init__(self, handle): + """Create a remote file from a Cython handle. + + This constructor should not be called directly instead use a + factory method like `RemoteFile.open_http()` + + Parameters + ---------- + handle : kvikio._lib.remote_handle.RemoteFile + The Cython handle + """ + assert isinstance(handle, _get_remote_module().RemoteFile) + self._handle = handle + + @classmethod + def open_http( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + """Open a http file. + + Parameters + ---------- + url + URL to the remote file. + nbytes + The size of the file. If None, KvikIO will ask the server + for the file size. + """ + return RemoteFile(_get_remote_module().RemoteFile.open_http(url, nbytes)) + + def close(self) -> None: + """Close the file""" + pass + + def __enter__(self) -> RemoteFile: + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def nbytes(self) -> int: + """Get the file size. + + Note, this is very fast, no communication needed. + + Returns + ------- + The number of bytes. + """ + return self._handle.nbytes() + + def read(self, buf, size: Optional[int] = None, file_offset: int = 0) -> int: + """Read from remote source into buffer (host or device memory) in parallel. + + Parameters + ---------- + buf : buffer-like or array-like + Device or host buffer to read into. + size + Size in bytes to read. + file_offset + Offset in the file to read from. + + Returns + ------- + The size of bytes that were successfully read. + """ + return self.pread(buf, size, file_offset).get() + + def pread(self, buf, size: Optional[int] = None, file_offset: int = 0) -> IOFuture: + """Read from remote source into buffer (host or device memory) in parallel. + + Parameters + ---------- + buf : buffer-like or array-like + Device or host buffer to read into. + size + Size in bytes to read. + file_offset + Offset in the file to read from. + + Returns + ------- + Future that on completion returns the size of bytes that were successfully + read. + """ + return IOFuture(self._handle.pread(buf, size, file_offset)) diff --git a/python/kvikio/kvikio/utils.py b/python/kvikio/kvikio/utils.py new file mode 100644 index 0000000000..09a9f2062a --- /dev/null +++ b/python/kvikio/kvikio/utils.py @@ -0,0 +1,81 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +import functools +import multiprocessing +import pathlib +import threading +import time +from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer + + +class LocalHttpServer: + """Local http server - slow but convenient""" + + @staticmethod + def _server( + queue: multiprocessing.Queue, + root_path: str, + range_support: bool, + max_lifetime: int, + ): + if range_support: + from RangeHTTPServer import RangeRequestHandler + + handler = RangeRequestHandler + else: + handler = SimpleHTTPRequestHandler + httpd = ThreadingHTTPServer( + ("127.0.0.1", 0), functools.partial(handler, directory=root_path) + ) + thread = threading.Thread(target=httpd.serve_forever) + thread.start() + queue.put(httpd.server_address) + time.sleep(max_lifetime) + print( + f"ThreadingHTTPServer shutting down because of timeout ({max_lifetime}sec)" + ) + + def __init__( + self, + root_path: str | pathlib.Path, + range_support: bool = True, + max_lifetime: int = 120, + ) -> None: + """Create a context that starts a local http server. + + Example + ------- + >>> with LocalHttpServer(root_path="/my/server/") as server: + ... with kvikio.RemoteFile.open_http(f"{server.url}/myfile") as f: + ... f.read(...) + + Parameters + ---------- + root_path + Path to the directory the server will serve. + range_support + Whether to support the ranges, required by `RemoteFile.open_http()`. + Depend on the `RangeHTTPServer` module (`pip install rangehttpserver`). + max_lifetime + Maximum lifetime of the server (in seconds). + """ + self.root_path = root_path + self.range_support = range_support + self.max_lifetime = max_lifetime + + def __enter__(self): + queue = multiprocessing.Queue() + self.process = multiprocessing.Process( + target=LocalHttpServer._server, + args=(queue, str(self.root_path), self.range_support, self.max_lifetime), + ) + self.process.start() + ip, port = queue.get() + self.ip = ip + self.port = port + self.url = f"http://{ip}:{port}" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.process.kill() diff --git a/python/kvikio/pyproject.toml b/python/kvikio/pyproject.toml index e958b9fb36..04f04cfa6f 100644 --- a/python/kvikio/pyproject.toml +++ b/python/kvikio/pyproject.toml @@ -43,6 +43,7 @@ test = [ "dask>=2022.05.2", "pytest", "pytest-cov", + "rangehttpserver", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project.urls] diff --git a/python/kvikio/tests/test_benchmarks.py b/python/kvikio/tests/test_benchmarks.py index 3bdaf6613e..5b5602e53a 100644 --- a/python/kvikio/tests/test_benchmarks.py +++ b/python/kvikio/tests/test_benchmarks.py @@ -8,6 +8,8 @@ import pytest +import kvikio + benchmarks_path = ( Path(os.path.realpath(__file__)).parent.parent / "kvikio" / "benchmarks" ) @@ -78,3 +80,32 @@ def test_zarr_io(run_cmd, tmp_path, api): cwd=benchmarks_path, ) assert retcode == 0 + + +@pytest.mark.parametrize( + "api", + [ + "cupy", + "numpy", + ], +) +def test_http_io(run_cmd, api): + """Test benchmarks/http_io.py""" + + if not kvikio.is_remote_file_available(): + pytest.skip( + "RemoteFile not available, please build KvikIO " + "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)" + ) + retcode = run_cmd( + cmd=[ + sys.executable, + "http_io.py", + "-n", + "1000", + "--api", + api, + ], + cwd=benchmarks_path, + ) + assert retcode == 0 diff --git a/python/kvikio/tests/test_examples.py b/python/kvikio/tests/test_examples.py index e9e1f83d08..07be1fc156 100644 --- a/python/kvikio/tests/test_examples.py +++ b/python/kvikio/tests/test_examples.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. import os @@ -7,6 +7,8 @@ import pytest +import kvikio + examples_path = Path(os.path.realpath(__file__)).parent / ".." / "examples" @@ -26,3 +28,16 @@ def test_zarr_cupy_nvcomp(tmp_path, monkeypatch): monkeypatch.syspath_prepend(str(examples_path)) import_module("zarr_cupy_nvcomp").main(tmp_path / "test-file") + + +def test_http_io(tmp_path, monkeypatch): + """Test examples/http_io.py""" + + if not kvikio.is_remote_file_available(): + pytest.skip( + "RemoteFile not available, please build KvikIO " + "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)" + ) + + monkeypatch.syspath_prepend(str(examples_path)) + import_module("http_io").main(tmp_path) diff --git a/python/kvikio/tests/test_http_io.py b/python/kvikio/tests/test_http_io.py new file mode 100644 index 0000000000..70abec71b6 --- /dev/null +++ b/python/kvikio/tests/test_http_io.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + + +import numpy as np +import pytest + +import kvikio +import kvikio.defaults +from kvikio.utils import LocalHttpServer + +pytestmark = pytest.mark.skipif( + not kvikio.is_remote_file_available(), + reason=( + "RemoteFile not available, please build KvikIO " + "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)" + ), +) + + +@pytest.fixture +def http_server(request, tmpdir): + """Fixture to set up http server in separate process""" + range_support = True + if hasattr(request, "param"): + range_support = request.param.get("range_support", True) + + with LocalHttpServer(tmpdir, range_support, max_lifetime=60) as server: + yield server.url + + +def test_file_size(http_server, tmpdir): + a = np.arange(100) + a.tofile(tmpdir / "a") + with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: + assert f.nbytes() == a.nbytes + + +@pytest.mark.parametrize("size", [10, 100, 1000]) +@pytest.mark.parametrize("nthreads", [1, 3]) +@pytest.mark.parametrize("tasksize", [99, 999]) +def test_read(http_server, tmpdir, xp, size, nthreads, tasksize): + a = xp.arange(size) + a.tofile(tmpdir / "a") + + with kvikio.defaults.set_num_threads(nthreads): + with kvikio.defaults.set_task_size(tasksize): + with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: + assert f.nbytes() == a.nbytes + b = xp.empty_like(a) + assert f.read(b) == a.nbytes + xp.testing.assert_array_equal(a, b) + + +@pytest.mark.parametrize("nthreads", [1, 10]) +def test_large_read(http_server, tmpdir, xp, nthreads): + a = xp.arange(16_000_000) + a.tofile(tmpdir / "a") + + with kvikio.defaults.set_num_threads(nthreads): + with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: + assert f.nbytes() == a.nbytes + b = xp.empty_like(a) + assert f.read(b) == a.nbytes + xp.testing.assert_array_equal(a, b) + + +def test_error_too_small_file(http_server, tmpdir, xp): + a = xp.arange(10, dtype="uint8") + b = xp.empty(100, dtype="uint8") + a.tofile(tmpdir / "a") + with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: + assert f.nbytes() == a.nbytes + with pytest.raises( + ValueError, match=r"cannot read 0\+100 bytes into a 10 bytes file" + ): + f.read(b) + with pytest.raises( + ValueError, match=r"cannot read 100\+5 bytes into a 10 bytes file" + ): + f.read(b, size=5, file_offset=100) + + +@pytest.mark.parametrize("http_server", [{"range_support": False}], indirect=True) +def test_no_range_support(http_server, tmpdir, xp): + a = xp.arange(100, dtype="uint8") + a.tofile(tmpdir / "a") + b = xp.empty_like(a) + with kvikio.RemoteFile.open_http(f"{http_server}/a") as f: + assert f.nbytes() == a.nbytes + with pytest.raises( + OverflowError, match="maybe the server doesn't support file ranges?" + ): + f.read(b, size=10, file_offset=0) + with pytest.raises( + OverflowError, match="maybe the server doesn't support file ranges?" + ): + f.read(b, size=10, file_offset=10)