diff --git a/docs/source/developer_guide/guides/2_real_world_phishing.md b/docs/source/developer_guide/guides/2_real_world_phishing.md index 0e5ed2f5ef..71c132d608 100644 --- a/docs/source/developer_guide/guides/2_real_world_phishing.md +++ b/docs/source/developer_guide/guides/2_real_world_phishing.md @@ -406,7 +406,8 @@ Note that the tokenizer parameters and vocabulary hash file should exactly match At this point, we have a pipeline that reads in a set of records and pre-processes them with the metadata required for our classifier to make predictions. Our next step is to define a stage that applies a machine learning model to our `MessageMeta` object. To accomplish this, we will be using Morpheus' `TritonInferenceStage`. This stage will handle communication with the `phishing-bert-onnx` model, which we provided to the Triton Docker container via the `models` directory mount. -Next we will add a monitor stage to measure the inference rate as well as a filter stage to filter out any results below a probability threshold of `0.9`. +Next we will add a monitor stage to measure the inference rate: + ```python # Add an inference stage pipeline.add_stage( @@ -418,14 +419,17 @@ pipeline.add_stage( )) pipeline.add_stage(MonitorStage(config, description="Inference Rate", smoothing=0.001, unit="inf")) +``` -# Filter values lower than 0.9 -pipeline.add_stage(FilterDetectionsStage(config, threshold=0.9)) +Here we add a postprocessing stage that adds the probability score for `is_phishing`: + +```python +pipeline.add_stage(AddScoresStage(config, labels=["is_phishing"])) ``` Lastly, we will save our results to disk. For this purpose, we are using two stages that are often used in conjunction with each other: `SerializeStage` and `WriteToFileStage`. -The `SerializeStage` is used to include and exclude columns as desired in the output. Importantly, it also handles conversion from the `MultiMessage`-derived output type that is used by the `FilterDetectionsStage` to the `MessageMeta` class that is expected as input by the `WriteToFileStage`. +The `SerializeStage` is used to include and exclude columns as desired in the output. Importantly, it also handles conversion from the `MultiMessage`-derived output type to the `MessageMeta` class that is expected as input by the `WriteToFileStage`. The `WriteToFileStage` will append message data to the output file as messages are received. Note however that for performance reasons the `WriteToFileStage` does not flush its contents out to disk every time a message is received. Instead, it relies on the underlying [buffered output stream](https://gcc.gnu.org/onlinedocs/libstdc++/manual/streambufs.html) to flush as needed, and then will close the file handle on shutdown. 
@@ -456,7 +460,7 @@ from morpheus.stages.general.monitor_stage import MonitorStage from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage from morpheus.stages.input.file_source_stage import FileSourceStage from morpheus.stages.output.write_to_file_stage import WriteToFileStage -from morpheus.stages.postprocess.filter_detections_stage import FilterDetectionsStage +from morpheus.stages.postprocess.add_scores_stage import AddScoresStage from morpheus.stages.postprocess.serialize_stage import SerializeStage from morpheus.stages.preprocess.deserialize_stage import DeserializeStage from morpheus.stages.preprocess.preprocess_nlp_stage import PreprocessNLPStage @@ -522,8 +526,8 @@ def run_pipeline(): # Monitor the inference rate pipeline.add_stage(MonitorStage(config, description="Inference Rate", smoothing=0.001, unit="inf")) - # Filter values lower than 0.9 - pipeline.add_stage(FilterDetectionsStage(config, threshold=0.9)) + # Add probability score for is_phishing + pipeline.add_stage(AddScoresStage(config, labels=["is_phishing"])) # Write the to the output file pipeline.add_stage(SerializeStage(config)) @@ -550,7 +554,7 @@ morpheus --log_level=debug --plugin examples/developer_guide/2_1_real_world_phis preprocess --vocab_hash_file=data/bert-base-uncased-hash.txt --truncation=true --do_lower_case=true --add_special_tokens=false \ inf-triton --model_name=phishing-bert-onnx --server_url=localhost:8001 --force_convert_inputs=true \ monitor --description="Inference Rate" --smoothing=0.001 --unit=inf \ - filter --threshold=0.9 --filter_source=TENSOR \ + add-scores --label=is_phishing \ serialize \ to-file --filename=/tmp/detections.jsonlines --overwrite ``` diff --git a/docs/source/developer_guide/guides/6_digital_fingerprinting_reference.md b/docs/source/developer_guide/guides/6_digital_fingerprinting_reference.md index 81789e2956..38518c17e0 100644 --- a/docs/source/developer_guide/guides/6_digital_fingerprinting_reference.md +++ b/docs/source/developer_guide/guides/6_digital_fingerprinting_reference.md @@ -233,7 +233,7 @@ The `DFPFileToDataFrameStage` (examples/digital_fingerprinting/production/morphe | `parser_kwargs` | `dict` or `None` | Optional: additional keyword arguments to be passed into the `DataFrame` parser, currently this is going to be either [`pandas.read_csv`](https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html), [`pandas.read_json`](https://pandas.pydata.org/docs/reference/api/pandas.read_json.html) or [`pandas.read_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html) | | `cache_dir` | `str` | Optional: path to cache location, defaults to `./.cache/dfp` | -This stage is able to download and load data files concurrently by multiple methods. Currently supported methods are: `single_thread`, `multiprocess`, `dask`, and `dask_thread`. The method used is chosen by setting the {envvar}`MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable, and `dask_thread` is used by default, and `single_thread` effectively disables concurrent loading. +This stage is able to download and load data files concurrently by multiple methods. Currently supported methods are: `single_thread`, `dask`, and `dask_thread`. The method used is chosen by setting the {envvar}`MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable, and `dask_thread` is used by default, and `single_thread` effectively disables concurrent loading. 
 This stage will cache the resulting `DataFrame` in `cache_dir`, since we are caching the `DataFrame`s and not the source files, a cache hit avoids the cost of parsing the incoming data. In the case of remote storage systems, such as S3, this avoids both parsing and a download on a cache hit.
 
 One consequence of this is that any change to the `schema` will require purging cached files in the `cache_dir` before those changes are visible.
diff --git a/docs/source/loaders/core/file_to_df_loader.md b/docs/source/loaders/core/file_to_df_loader.md
index 921154aafd..12c24f631c 100644
--- a/docs/source/loaders/core/file_to_df_loader.md
+++ b/docs/source/loaders/core/file_to_df_loader.md
@@ -17,7 +17,7 @@ limitations under the License.
 
 ## File to DataFrame Loader
 
-[DataLoader](../../modules/core/data_loader.md) module is used to load data files content into a dataframe using custom loader function. This loader function can be configured to use different processing methods, such as single-threaded, multiprocess, dask, or dask_thread, as determined by the `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable. When download_method starts with "dask," a dask client is created to process the files, otherwise, a single thread or multiprocess is used.
+The [DataLoader](../../modules/core/data_loader.md) module is used to load the contents of data files into a dataframe using a custom loader function. This loader function can be configured to use different processing methods, such as single-threaded, dask, or dask_thread, as determined by the `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable. When `download_method` starts with "dask", a dask client is created to process the files; otherwise, a single thread is used.
 
 After processing, the resulting dataframe is cached using a hash of the file paths. This loader also has the ability to load file content from S3 buckets, in addition to loading data from the disk.
 
diff --git a/examples/abp_pcap_detection/README.md b/examples/abp_pcap_detection/README.md
index 8221178c46..fe27f5e70c 100644
--- a/examples/abp_pcap_detection/README.md
+++ b/examples/abp_pcap_detection/README.md
@@ -27,14 +27,13 @@ docker pull nvcr.io/nvidia/tritonserver:23.06-py3
 ```
 
 ##### Deploy Triton Inference Server
 
-
-Bind the provided `abp-pcap-xgb` directory to the docker container model repo at `/models`.
-
 From the root of the Morpheus repo, navigate to the anomalous behavior profiling example directory:
 ```bash
 cd examples/abp_pcap_detection
+```
 
-# Launch the container
+The following creates the Triton container, mounts the `abp-pcap-xgb` directory to `/models/abp-pcap-xgb` in the Triton container, and starts the Triton server:
+```bash
 docker run --rm --gpus=all -p 8000:8000 -p 8001:8001 -p 8002:8002 -v $PWD/abp-pcap-xgb:/models/abp-pcap-xgb --name tritonserver nvcr.io/nvidia/tritonserver:23.06-py3 tritonserver --model-repository=/models --exit-on-error=false
 ```
 
diff --git a/examples/digital_fingerprinting/production/morpheus/benchmarks/README.md b/examples/digital_fingerprinting/production/morpheus/benchmarks/README.md
index 7e80c973c8..39ad984193 100644
--- a/examples/digital_fingerprinting/production/morpheus/benchmarks/README.md
+++ b/examples/digital_fingerprinting/production/morpheus/benchmarks/README.md
@@ -93,7 +93,7 @@ Morpheus pipeline configurations for each workflow are managed using [pipelines_
 When using the MRC SegmentModule in a pipeline, it will also require a module configuration which gets generated within the test.
 Additional information is included in the [Morpheus Pipeline with Modules](../../../../../docs/source/developer_guide/guides/6_digital_fingerprinting_reference.md#morpheus-pipeline-with-modules)
 
-To ensure the [file_to_df_loader.py](../../../../../morpheus/loaders/file_to_df_loader.py) utilizes the same type of downloading mechanism, set `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable with any one of given choices (`multiprocess`, `dask`, `dask thread`, `single thread`).
+To ensure the [file_to_df_loader.py](../../../../../morpheus/loaders/file_to_df_loader.py) utilizes the same type of downloading mechanism, set the `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable to any one of the supported choices (`dask`, `dask_thread`, `single_thread`).
 
 ```
 export MORPHEUS_FILE_DOWNLOAD_TYPE=dask
diff --git a/models/data/labels_phishing.txt b/models/data/labels_phishing.txt
index 8eeea320f8..1a91ac21f3 100644
--- a/models/data/labels_phishing.txt
+++ b/models/data/labels_phishing.txt
@@ -1,2 +1,2 @@
-score
-pred
+not_phishing
+is_phishing
diff --git a/morpheus/_lib/include/morpheus/stages/kafka_source.hpp b/morpheus/_lib/include/morpheus/stages/kafka_source.hpp
index eb4d17feaf..3770797d84 100644
--- a/morpheus/_lib/include/morpheus/stages/kafka_source.hpp
+++ b/morpheus/_lib/include/morpheus/stages/kafka_source.hpp
@@ -30,13 +30,16 @@
 #include
 #include
 #include
+#include
 #include
 #include  // for apply, make_subscriber, observable_member, is_on_error<>::not_void, is_on_next_of<>::not_void, trace_activity
 #include  // for size_t
 #include  // for uuint32_t
+#include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -53,6 +56,17 @@ namespace morpheus {
  */
 #pragma GCC visibility push(default)
+
+class KafkaOAuthCallback : public RdKafka::OAuthBearerTokenRefreshCb
+{
+  public:
+    KafkaOAuthCallback(const std::function<std::map<std::string, std::string>()>& oauth_callback);
+
+    void oauthbearer_token_refresh_cb(RdKafka::Handle* handle, const std::string& oauthbearer_config) override;
+
+  private:
+    const std::function<std::map<std::string, std::string>()>& m_oauth_callback;
+};
 /**
  * This class loads messages from the Kafka cluster by serving as a Kafka consumer.
  */
@@ -82,10 +96,11 @@ class KafkaSourceStage : public mrc::pymrc::PythonSource
                      std::map<std::string, std::string> config,
-                     bool disable_commit = false,
-                     bool disable_pre_filtering = false,
-                     std::size_t stop_after = 0,
-                     bool async_commits = true);
+                     bool disable_commit = false,
+                     bool disable_pre_filtering = false,
+                     std::size_t stop_after = 0,
+                     bool async_commits = true,
+                     std::unique_ptr<KafkaOAuthCallback> oauth_callback = nullptr);
 
     /**
      * @brief Construct a new Kafka Source Stage object
@@ -106,10 +121,11 @@ class KafkaSourceStage : public mrc::pymrc::PythonSource
                      topics,
                      uint32_t batch_timeout_ms,
                      std::map<std::string, std::string> config,
-                     bool disable_commit = false,
-                     bool disable_pre_filtering = false,
-                     std::size_t stop_after = 0,
-                     bool async_commits = true);
+                     bool disable_commit = false,
+                     bool disable_pre_filtering = false,
+                     std::size_t stop_after = 0,
+                     bool async_commits = true,
+                     std::unique_ptr<KafkaOAuthCallback> oauth_callback = nullptr);
 
     ~KafkaSourceStage() override = default;
 
@@ -176,6 +192,8 @@ class KafkaSourceStage : public mrc::pymrc::PythonSource
+    std::unique_ptr<KafkaOAuthCallback> m_oauth_callback;
 };
 
 /****** KafkaSourceStageInferenceProxy**********************/
@@ -200,6 +218,7 @@ struct KafkaSourceStageInterfaceProxy
     * @param stop_after : Stops ingesting after emitting `stop_after` records (rows in the table).
     * Useful for testing. Disabled if `0`
     * @param async_commits : Asynchronously acknowledge consuming Kafka messages
+    * @param oauth_callback : Callback used when an OAuth token needs to be generated.
     */
    static std::shared_ptr<mrc::segment::Object<KafkaSourceStage>> init_with_single_topic(
        mrc::segment::Builder& builder,
@@ -210,8 +229,9 @@ struct KafkaSourceStageInterfaceProxy
        std::map<std::string, std::string> config,
        bool disable_commit,
        bool disable_pre_filtering,
-        std::size_t stop_after = 0,
-        bool async_commits = true);
+        std::size_t stop_after = 0,
+        bool async_commits = true,
+        std::optional<pybind11::function> oauth_callback = std::nullopt);
 
     /**
      * @brief Create and initialize a KafkaSourceStage, and return the result
@@ -229,6 +249,7 @@ struct KafkaSourceStageInterfaceProxy
     * @param stop_after : Stops ingesting after emitting `stop_after` records (rows in the table).
     * Useful for testing. Disabled if `0`
     * @param async_commits : Asynchronously acknowledge consuming Kafka messages
+    * @param oauth_callback : Callback used when an OAuth token needs to be generated.
     */
    static std::shared_ptr<mrc::segment::Object<KafkaSourceStage>> init_with_multiple_topics(
        mrc::segment::Builder& builder,
@@ -239,8 +260,19 @@ struct KafkaSourceStageInterfaceProxy
        std::map<std::string, std::string> config,
        bool disable_commit,
        bool disable_pre_filtering,
-        std::size_t stop_after = 0,
-        bool async_commits = true);
+        std::size_t stop_after = 0,
+        bool async_commits = true,
+        std::optional<pybind11::function> oauth_callback = std::nullopt);
+
+  private:
+    /**
+     * @brief Create a KafkaOAuthCallback or return nullptr. If oauth_callback is std::nullopt,
+     * returns nullptr, otherwise wraps the callback in a KafkaOAuthCallback such that the values
+     * returned from the python callback are converted for use in c++.
+     * @param oauth_callback : The callback to wrap, if any.
+     */
+    static std::unique_ptr<KafkaOAuthCallback> make_kafka_oauth_callback(
+        std::optional<pybind11::function>&& oauth_callback);
 };
 #pragma GCC visibility pop
 /** @} */  // end of group
diff --git a/morpheus/_lib/src/stages/kafka_source.cpp b/morpheus/_lib/src/stages/kafka_source.cpp
index 9c490040a9..7d97396447 100644
--- a/morpheus/_lib/src/stages/kafka_source.cpp
+++ b/morpheus/_lib/src/stages/kafka_source.cpp
@@ -21,6 +21,7 @@
 #include "mrc/node/rx_source_base.hpp"
 #include "mrc/node/source_properties.hpp"
 #include "mrc/segment/object.hpp"
+#include "pymrc/utilities/function_wrappers.hpp"  // for PyFuncWrapper
 
 #include "morpheus/messages/meta.hpp"
 #include "morpheus/utilities/stage_util.hpp"
@@ -35,6 +36,8 @@
 #include
 #include  // for SharedFuture
 #include
+#include
+#include
 #include
 #include  // for find, min, transform
@@ -45,11 +48,14 @@
 #include
 #include  // for initializer_list
 #include  // for back_insert_iterator, back_inserter
+#include
 #include
 #include
+#include
 #include
 #include
 #include
+#include
 #include
 // IWYU thinks we need atomic for vector.emplace_back of a unique_ptr
 // and __alloc_traits<>::value_type for vector assignments
@@ -74,6 +80,29 @@
 #endif  // DOXYGEN_SHOULD_SKIP_THIS
 
 namespace morpheus {
+
+KafkaOAuthCallback::KafkaOAuthCallback(const std::function<std::map<std::string, std::string>()>& oauth_callback) :
+  m_oauth_callback(oauth_callback)
+{}
+
+void KafkaOAuthCallback::oauthbearer_token_refresh_cb(RdKafka::Handle* handle, const std::string& oauthbearer_config)
+{
+    try
+    {
+        auto response = m_oauth_callback();
+        // Build parameters to pass to librdkafka
+        std::string token = response["token"];
+        int64_t token_lifetime_ms = std::stoll(response["token_expiration_in_epoch"]);
+        std::list<std::string> extensions;  // currently not supported
+        std::string errstr;
+        auto result = handle->oauthbearer_set_token(token, token_lifetime_ms, "kafka", extensions, errstr);
+        CHECK(result == RdKafka::ErrorCode::ERR_NO_ERROR) << "Error occurred while setting the oauthbearer token";
+    } catch (const std::exception& ex)
+    {
+        LOG(FATAL) << "Exception occurred during OAuth refresh: " <<
ex.what(); + } +} + // Component-private classes. // ************ KafkaSourceStage__UnsubscribedException**************// class KafkaSourceStageUnsubscribedException : public std::exception @@ -264,7 +293,8 @@ KafkaSourceStage::KafkaSourceStage(TensorIndex max_batch_size, bool disable_commit, bool disable_pre_filtering, std::size_t stop_after, - bool async_commits) : + bool async_commits, + std::unique_ptr oauth_callback) : PythonSource(build()), m_max_batch_size(max_batch_size), m_topics(std::vector{std::move(topic)}), @@ -273,7 +303,8 @@ KafkaSourceStage::KafkaSourceStage(TensorIndex max_batch_size, m_disable_commit(disable_commit), m_disable_pre_filtering(disable_pre_filtering), m_stop_after{stop_after}, - m_async_commits(async_commits) + m_async_commits(async_commits), + m_oauth_callback(std::move(oauth_callback)) {} KafkaSourceStage::KafkaSourceStage(TensorIndex max_batch_size, @@ -283,7 +314,8 @@ KafkaSourceStage::KafkaSourceStage(TensorIndex max_batch_size, bool disable_commit, bool disable_pre_filtering, std::size_t stop_after, - bool async_commits) : + bool async_commits, + std::unique_ptr oauth_callback) : PythonSource(build()), m_max_batch_size(max_batch_size), m_topics(std::move(topics)), @@ -292,7 +324,8 @@ KafkaSourceStage::KafkaSourceStage(TensorIndex max_batch_size, m_disable_commit(disable_commit), m_disable_pre_filtering(disable_pre_filtering), m_stop_after{stop_after}, - m_async_commits(async_commits) + m_async_commits(async_commits), + m_oauth_callback(std::move(oauth_callback)) {} KafkaSourceStage::subscriber_fn_t KafkaSourceStage::build() @@ -453,6 +486,15 @@ std::unique_ptr KafkaSourceStage::create_consumer(RdKafk LOG(FATAL) << "Error occurred while setting Kafka rebalance function. Error: " << errstr; } + if (m_oauth_callback != nullptr) + { + if (RdKafka::Conf::ConfResult::CONF_OK != + kafka_conf->set("oauthbearer_token_refresh_cb", m_oauth_callback.get(), errstr)) + { + LOG(FATAL) << "Error occurred while setting Kafka OAuth Callback function. 
Error: " << errstr; + } + } + auto consumer = std::unique_ptr(RdKafka::KafkaConsumer::create(kafka_conf.get(), errstr)); if (!consumer) @@ -619,8 +661,11 @@ std::shared_ptr> KafkaSourceStageInterfac bool disable_commit, bool disable_pre_filtering, std::size_t stop_after, - bool async_commits) + bool async_commits, + std::optional oauth_callback) { + auto oauth_callback_cpp = KafkaSourceStageInterfaceProxy::make_kafka_oauth_callback(std::move(oauth_callback)); + auto stage = builder.construct_object(name, max_batch_size, topic, @@ -629,7 +674,8 @@ std::shared_ptr> KafkaSourceStageInterfac disable_commit, disable_pre_filtering, stop_after, - async_commits); + async_commits, + std::move(oauth_callback_cpp)); return stage; } @@ -644,8 +690,11 @@ std::shared_ptr> KafkaSourceStageInterfac bool disable_commit, bool disable_pre_filtering, std::size_t stop_after, - bool async_commits) + bool async_commits, + std::optional oauth_callback) { + auto oauth_callback_cpp = KafkaSourceStageInterfaceProxy::make_kafka_oauth_callback(std::move(oauth_callback)); + auto stage = builder.construct_object(name, max_batch_size, topics, @@ -654,8 +703,31 @@ std::shared_ptr> KafkaSourceStageInterfac disable_commit, disable_pre_filtering, stop_after, - async_commits); + async_commits, + std::move(oauth_callback_cpp)); return stage; } + +std::unique_ptr KafkaSourceStageInterfaceProxy::make_kafka_oauth_callback( + std::optional&& oauth_callback) +{ + if (oauth_callback == std::nullopt) + { + return static_cast>(nullptr); + } + + auto oauth_callback_wrapped = mrc::pymrc::PyFuncWrapper(std::move(oauth_callback.value())); + + return std::make_unique([oauth_callback_wrapped = std::move(oauth_callback_wrapped)]() { + auto kvp_cpp = std::map(); + auto kvp = oauth_callback_wrapped.operator()(); + for (auto [key, value] : kvp) + { + kvp_cpp[key.cast()] = value.cast(); + } + return kvp_cpp; + }); +} + } // namespace morpheus diff --git a/morpheus/_lib/stages/__init__.pyi b/morpheus/_lib/stages/__init__.pyi index d24f1b7d68..a9e8ac6445 100644 --- a/morpheus/_lib/stages/__init__.pyi +++ b/morpheus/_lib/stages/__init__.pyi @@ -54,9 +54,9 @@ class InferenceClientStage(mrc.core.segment.SegmentObject): pass class KafkaSourceStage(mrc.core.segment.SegmentObject): @typing.overload - def __init__(self, builder: mrc.core.segment.Builder, name: str, max_batch_size: int, topic: str, batch_timeout_ms: int, config: typing.Dict[str, str], disable_commits: bool = False, disable_pre_filtering: bool = False, stop_after: int = 0, async_commits: bool = True) -> None: ... + def __init__(self, builder: mrc.core.segment.Builder, name: str, max_batch_size: int, topic: str, batch_timeout_ms: int, config: typing.Dict[str, str], disable_commits: bool = False, disable_pre_filtering: bool = False, stop_after: int = 0, async_commits: bool = True, oauth_callback: typing.Optional[function] = None) -> None: ... @typing.overload - def __init__(self, builder: mrc.core.segment.Builder, name: str, max_batch_size: int, topics: typing.List[str], batch_timeout_ms: int, config: typing.Dict[str, str], disable_commits: bool = False, disable_pre_filtering: bool = False, stop_after: int = 0, async_commits: bool = True) -> None: ... 
+ def __init__(self, builder: mrc.core.segment.Builder, name: str, max_batch_size: int, topics: typing.List[str], batch_timeout_ms: int, config: typing.Dict[str, str], disable_commits: bool = False, disable_pre_filtering: bool = False, stop_after: int = 0, async_commits: bool = True, oauth_callback: typing.Optional[function] = None) -> None: ... pass class PreallocateMessageMetaStage(mrc.core.segment.SegmentObject): def __init__(self, builder: mrc.core.segment.Builder, name: str, needed_columns: typing.List[typing.Tuple[str, morpheus._lib.common.TypeId]]) -> None: ... diff --git a/morpheus/_lib/stages/module.cpp b/morpheus/_lib/stages/module.cpp index b4f979a142..dd034f2cbb 100644 --- a/morpheus/_lib/stages/module.cpp +++ b/morpheus/_lib/stages/module.cpp @@ -150,7 +150,8 @@ PYBIND11_MODULE(stages, _module) py::arg("disable_commits") = false, py::arg("disable_pre_filtering") = false, py::arg("stop_after") = 0, - py::arg("async_commits") = true) + py::arg("async_commits") = true, + py::arg("oauth_callback") = py::none()) .def(py::init<>(&KafkaSourceStageInterfaceProxy::init_with_multiple_topics), py::arg("builder"), py::arg("name"), @@ -161,7 +162,8 @@ PYBIND11_MODULE(stages, _module) py::arg("disable_commits") = false, py::arg("disable_pre_filtering") = false, py::arg("stop_after") = 0, - py::arg("async_commits") = true); + py::arg("async_commits") = true, + py::arg("oauth_callback") = py::none()); py::class_>, mrc::segment::ObjectProperties, diff --git a/morpheus/loaders/file_to_df_loader.py b/morpheus/loaders/file_to_df_loader.py index ff69d89366..3915ebbd0f 100644 --- a/morpheus/loaders/file_to_df_loader.py +++ b/morpheus/loaders/file_to_df_loader.py @@ -34,9 +34,9 @@ def file_to_df_loader(control_message: ControlMessage, task: dict): """ This function is used to load files containing data into a dataframe. Dataframe is created by - processing files either using a single thread, multiprocess, dask, or dask_thread. This function determines + processing files either using a single thread, dask, or dask_thread. This function determines the download method to use, and if it starts with "dask," it creates a dask client and uses it to process the files. - Otherwise, it uses a single thread or multiprocess to process the files. This function then caches the resulting + Otherwise, it uses a single thread to process the files. This function then caches the resulting dataframe using a hash of the file paths. The dataframe is wrapped in a MessageMeta and then attached as a payload to a ControlMessage object and passed on to further stages. diff --git a/morpheus/utils/downloader.py b/morpheus/utils/downloader.py index 722c2387b4..d2882afa93 100644 --- a/morpheus/utils/downloader.py +++ b/morpheus/utils/downloader.py @@ -33,8 +33,6 @@ class DownloadMethods(str, Enum): """Valid download methods for the `Downloader` class.""" SINGLE_THREAD = "single_thread" - MULTIPROCESS = "multiprocess" - MULTIPROCESSING = "multiprocessing" DASK = "dask" DASK_THREAD = "dask_thread" @@ -45,14 +43,12 @@ class DownloadMethods(str, Enum): class Downloader: """ Downloads a list of `fsspec.core.OpenFiles` files using one of the following methods: - single_thread, multiprocess, dask or dask_thread + single_thread, dask or dask_thread The download method can be passed in via the `download_method` parameter or via the `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable. If both are set, the environment variable takes precedence, by default `dask_thread` is used. 
-    When using single_thread, or multiprocess is used `dask` and `dask.distributed` is not reuiqrred to be installed.
-
-    For compatibility reasons "multiprocessing" is an alias for "multiprocess".
+    When using single_thread, `dask` and `dask.distributed` are not required to be installed.
 
     Parameters
     ----------
@@ -78,6 +74,10 @@ def __init__(self,
             download_method = os.environ.get("MORPHEUS_FILE_DOWNLOAD_TYPE", download_method)
 
         if isinstance(download_method, str):
+            if (download_method in ("multiprocess", "multiprocessing")):
+                raise ValueError(
+                    f"The '{download_method}' download method is no longer supported. Please use 'dask' or "
+                    "'single_thread' instead.")
             try:
                 download_method = DOWNLOAD_METHODS_MAP[download_method.lower()]
             except KeyError as exc:
@@ -165,10 +165,6 @@ def download(self,
                 dfs = dist.client.map(download_fn, download_buckets)
                 dfs = dist.client.gather(dfs)
 
-        elif (self._download_method in ("multiprocess", "multiprocessing")):
-            # Use multiprocessing here since parallel downloads are a pain
-            with mp.get_context("spawn").Pool(mp.cpu_count()) as pool:
-                dfs = pool.map(download_fn, download_buckets)
         else:
             # Simply loop
             for open_file in download_buckets:
diff --git a/tests/examples/digital_fingerprinting/test_dfp_file_to_df.py b/tests/examples/digital_fingerprinting/test_dfp_file_to_df.py
index e675ce2de7..716707ecfd 100644
--- a/tests/examples/digital_fingerprinting/test_dfp_file_to_df.py
+++ b/tests/examples/digital_fingerprinting/test_dfp_file_to_df.py
@@ -101,11 +101,9 @@ def test_constructor(config: Config):
 
 # pylint: disable=redefined-outer-name
 @pytest.mark.reload_modules(morpheus.utils.downloader)
-@pytest.mark.usefixtures("restore_environ")
-@pytest.mark.parametrize('dl_type', ["single_thread", "multiprocess", "multiprocessing", "dask", "dask_thread"])
+@pytest.mark.usefixtures("reload_modules", "restore_environ")
+@pytest.mark.parametrize('dl_type', ["single_thread", "dask", "dask_thread"])
 @pytest.mark.parametrize('use_convert_to_dataframe', [True, False])
-@pytest.mark.usefixtures("reload_modules")
-@mock.patch('multiprocessing.get_context')
 @mock.patch('dask.distributed.Client')
 @mock.patch('dask_cuda.LocalCUDACluster')
 @mock.patch('morpheus.controllers.file_to_df_controller.single_object_to_dataframe')
@@ -116,7 +114,6 @@ def test_get_or_create_dataframe_from_batch_cache_miss(mock_proc_df: mock.MagicM
                                                        mock_obf_to_df: mock.MagicMock,
                                                        mock_dask_cluster: mock.MagicMock,
                                                        mock_dask_client: mock.MagicMock,
-                                                       mock_mp_gc: mock.MagicMock,
                                                        config: Config,
                                                        dl_type: str,
                                                        use_convert_to_dataframe: bool,
@@ -136,13 +133,6 @@ def test_get_or_create_dataframe_from_batch_cache_miss(mock_proc_df: mock.MagicM
     mock_distributed.__enter__.return_value = mock_distributed
     mock_distributed.__exit__.return_value = False
 
-    mock_mp_gc.return_value = mock_mp_gc
-    mock_mp_pool = mock.MagicMock()
-    mock_mp_gc.Pool.return_value = mock_mp_pool
-    mock_mp_pool.return_value = mock_mp_pool
-    mock_mp_pool.__enter__.return_value = mock_mp_pool
-    mock_mp_pool.__exit__.return_value = False
-
     expected_hash = hashlib.md5(json.dumps([{
         'ukey': single_file_obj.fs.ukey(single_file_obj.path)
     }]).encode()).hexdigest()
@@ -161,8 +151,6 @@ def test_get_or_create_dataframe_from_batch_cache_miss(mock_proc_df: mock.MagicM
     if dl_type.startswith('dask'):
         mock_dist_client.map.return_value = [returned_df]
         mock_dist_client.gather.return_value = [returned_df]
-    elif dl_type in ("multiprocess", "multiprocessing"):
-        mock_mp_pool.map.return_value = [returned_df]
     else:
         mock_obf_to_df.return_value = returned_df
 
@@ -179,13 
+167,6 @@ def test_get_or_create_dataframe_from_batch_cache_miss(mock_proc_df: mock.MagicM (output_df, cache_hit) = stage._controller._get_or_create_dataframe_from_batch((batch, 1)) assert not cache_hit - if dl_type in ("multiprocess", "multiprocessing"): - mock_mp_gc.assert_called_once() - mock_mp_pool.map.assert_called_once() - else: - mock_mp_gc.assert_not_called() - mock_mp_pool.map.assert_not_called() - if dl_type == "single_thread": mock_obf_to_df.assert_called_once() else: @@ -209,9 +190,8 @@ def test_get_or_create_dataframe_from_batch_cache_miss(mock_proc_df: mock.MagicM @pytest.mark.usefixtures("restore_environ") -@pytest.mark.parametrize('dl_type', ["single_thread", "multiprocess", "multiprocessing", "dask", "dask_thread"]) +@pytest.mark.parametrize('dl_type', ["single_thread", "dask", "dask_thread"]) @pytest.mark.parametrize('use_convert_to_dataframe', [True, False]) -@mock.patch('multiprocessing.get_context') @mock.patch('dask.config') @mock.patch('dask.distributed.Client') @mock.patch('dask_cuda.LocalCUDACluster') @@ -220,7 +200,6 @@ def test_get_or_create_dataframe_from_batch_cache_hit(mock_obf_to_df: mock.Magic mock_dask_cluster: mock.MagicMock, mock_dask_client: mock.MagicMock, mock_dask_config: mock.MagicMock, - mock_mp_gc: mock.MagicMock, config: Config, dl_type: str, use_convert_to_dataframe: bool, @@ -233,13 +212,6 @@ def test_get_or_create_dataframe_from_batch_cache_hit(mock_obf_to_df: mock.Magic mock_dask_client.__enter__.return_value = mock_dask_client mock_dask_client.__exit__.return_value = False - mock_mp_gc.return_value = mock_mp_gc - mock_mp_pool = mock.MagicMock() - mock_mp_gc.Pool.return_value = mock_mp_pool - mock_mp_pool.return_value = mock_mp_pool - mock_mp_pool.__enter__.return_value = mock_mp_pool - mock_mp_pool.__exit__.return_value = False - file_specs = fsspec.open_files(os.path.abspath(os.path.join(TEST_DIRS.tests_data_dir, 'filter_probs.csv'))) # pylint: disable=no-member @@ -268,8 +240,6 @@ def test_get_or_create_dataframe_from_batch_cache_hit(mock_obf_to_df: mock.Magic assert cache_hit # When we get a cache hit, none of the download methods should be executed - mock_mp_gc.assert_not_called() - mock_mp_pool.map.assert_not_called() mock_obf_to_df.assert_not_called() mock_dask_cluster.assert_not_called() mock_dask_client.assert_not_called() @@ -279,9 +249,8 @@ def test_get_or_create_dataframe_from_batch_cache_hit(mock_obf_to_df: mock.Magic @pytest.mark.usefixtures("restore_environ") -@pytest.mark.parametrize('dl_type', ["single_thread", "multiprocess", "multiprocessing", "dask", "dask_thread"]) +@pytest.mark.parametrize('dl_type', ["single_thread", "dask", "dask_thread"]) @pytest.mark.parametrize('use_convert_to_dataframe', [True, False]) -@mock.patch('multiprocessing.get_context') @mock.patch('dask.config') @mock.patch('dask.distributed.Client') @mock.patch('dask_cuda.LocalCUDACluster') @@ -290,7 +259,6 @@ def test_get_or_create_dataframe_from_batch_none_noop(mock_obf_to_df: mock.Magic mock_dask_cluster: mock.MagicMock, mock_dask_client: mock.MagicMock, mock_dask_config: mock.MagicMock, - mock_mp_gc: mock.MagicMock, config: Config, dl_type: str, use_convert_to_dataframe: bool, @@ -299,10 +267,6 @@ def test_get_or_create_dataframe_from_batch_none_noop(mock_obf_to_df: mock.Magic mock_dask_cluster.return_value = mock_dask_cluster mock_dask_client.return_value = mock_dask_client - mock_mp_gc.return_value = mock_mp_gc - mock_mp_pool = mock.MagicMock() - mock_mp_gc.Pool.return_value = mock_mp_pool - os.environ['MORPHEUS_FILE_DOWNLOAD_TYPE'] = dl_type stage 
= DFPFileToDataFrameStage(config, DataFrameInputSchema(), cache_dir=tmp_path) if use_convert_to_dataframe: @@ -315,7 +279,5 @@ def test_get_or_create_dataframe_from_batch_none_noop(mock_obf_to_df: mock.Magic mock_dask_cluster.assert_not_called() mock_dask_client.assert_not_called() mock_dask_config.assert_not_called() - mock_mp_gc.assert_not_called() - mock_mp_pool.map.assert_not_called() assert os.listdir(tmp_path) == [] diff --git a/tests/test_cli.py b/tests/test_cli.py index 23408165ee..3cb0efea43 100755 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -685,8 +685,8 @@ def test_pipeline_nlp(self, config, callback_values): '--truncation=True', '--do_lower_case=True', '--add_special_tokens=False' - ] + INF_TRITON_ARGS + MONITOR_ARGS + ['add-class', '--label=pred', '--threshold=0.7'] + VALIDATE_ARGS + - ['serialize'] + TO_FILE_ARGS) + ] + INF_TRITON_ARGS + MONITOR_ARGS + ['add-class', '--label=is_phishing', '--threshold=0.7'] + + VALIDATE_ARGS + ['serialize'] + TO_FILE_ARGS) obj = {} runner = CliRunner() @@ -696,7 +696,7 @@ def test_pipeline_nlp(self, config, callback_values): # Ensure our config is populated correctly config = obj["config"] assert config.mode == PipelineModes.NLP - assert config.class_labels == ["score", "pred"] + assert config.class_labels == ["not_phishing", "is_phishing"] assert config.feature_length == 128 assert config.ae is None @@ -731,7 +731,7 @@ def test_pipeline_nlp(self, config, callback_values): assert monitor._mc._unit == 'inf' assert isinstance(add_class, AddClassificationsStage) - assert add_class._labels == ('pred', ) + assert add_class._labels == ('is_phishing', ) assert add_class._threshold == 0.7 assert isinstance(validation, ValidationStage) @@ -781,8 +781,8 @@ def test_pipeline_nlp_all(self, config, callback_values, tmp_path, mlflow_uri): 'mlflow-drift', '--tracking_uri', mlflow_uri - ] + INF_TRITON_ARGS + MONITOR_ARGS + ['add-class', '--label=pred', '--threshold=0.7'] + VALIDATE_ARGS + - ['serialize'] + TO_FILE_ARGS + TO_KAFKA_ARGS) + ] + INF_TRITON_ARGS + MONITOR_ARGS + ['add-class', '--label=is_phishing', '--threshold=0.7'] + + VALIDATE_ARGS + ['serialize'] + TO_FILE_ARGS + TO_KAFKA_ARGS) obj = {} runner = CliRunner() @@ -792,7 +792,7 @@ def test_pipeline_nlp_all(self, config, callback_values, tmp_path, mlflow_uri): # Ensure our config is populated correctly config = obj["config"] assert config.mode == PipelineModes.NLP - assert config.class_labels == ["score", "pred"] + assert config.class_labels == ["not_phishing", "is_phishing"] assert config.feature_length == 128 assert config.ae is None @@ -864,7 +864,7 @@ def test_pipeline_nlp_all(self, config, callback_values, tmp_path, mlflow_uri): assert monitor._mc._unit == 'inf' assert isinstance(add_class, AddClassificationsStage) - assert add_class._labels == ('pred', ) + assert add_class._labels == ('is_phishing', ) assert add_class._threshold == 0.7 assert isinstance(validation, ValidationStage) diff --git a/tests/test_downloader.py b/tests/test_downloader.py index 015af44457..f5c67a7afd 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -48,7 +48,7 @@ def dask_cuda(fail_missing: bool): @pytest.mark.usefixtures("restore_environ") @pytest.mark.parametrize('use_env', [True, False]) -@pytest.mark.parametrize('dl_method', ["single_thread", "multiprocess", "multiprocessing", "dask", "dask_thread"]) +@pytest.mark.parametrize('dl_method', ["single_thread", "dask", "dask_thread"]) def test_constructor_download_type(use_env: bool, dl_method: str): kwargs = {} if use_env: @@ -67,12 
+67,11 @@ def test_constructor_enum_vals(dl_method: DownloadMethods): @pytest.mark.usefixtures("restore_environ") -@pytest.mark.parametrize('dl_method', - [DownloadMethods.SINGLE_THREAD, DownloadMethods.DASK, DownloadMethods.DASK_THREAD]) +@pytest.mark.parametrize('dl_method', [DownloadMethods.DASK, DownloadMethods.DASK_THREAD]) def test_constructor_env_wins(dl_method: DownloadMethods): - os.environ['MORPHEUS_FILE_DOWNLOAD_TYPE'] = "multiprocessing" + os.environ['MORPHEUS_FILE_DOWNLOAD_TYPE'] = "single_thread" downloader = Downloader(download_method=dl_method) - assert downloader.download_method == DownloadMethods.MULTIPROCESSING + assert downloader.download_method == DownloadMethods.SINGLE_THREAD @pytest.mark.usefixtures("restore_environ") @@ -119,7 +118,7 @@ def test_close(mock_dask_cluster: mock.MagicMock, dl_method: str): @mock.patch('dask_cuda.LocalCUDACluster') -@pytest.mark.parametrize('dl_method', ["single_thread", "multiprocess", "multiprocessing"]) +@pytest.mark.parametrize('dl_method', ["single_thread"]) def test_close_noop(mock_dask_cluster: mock.MagicMock, dl_method: str): mock_dask_cluster.return_value = mock_dask_cluster downloader = Downloader(download_method=dl_method) @@ -133,15 +132,13 @@ def test_close_noop(mock_dask_cluster: mock.MagicMock, dl_method: str): @pytest.mark.reload_modules(morpheus.utils.downloader) @pytest.mark.usefixtures("reload_modules", "restore_environ") -@pytest.mark.parametrize('dl_method', ["single_thread", "multiprocess", "multiprocessing", "dask", "dask_thread"]) -@mock.patch('multiprocessing.get_context') +@pytest.mark.parametrize('dl_method', ["single_thread", "dask", "dask_thread"]) @mock.patch('dask.config') @mock.patch('dask.distributed.Client') @mock.patch('dask_cuda.LocalCUDACluster') def test_download(mock_dask_cluster: mock.MagicMock, mock_dask_client: mock.MagicMock, mock_dask_config: mock.MagicMock, - mock_mp_gc: mock.MagicMock, dl_method: str): mock_dask_config.get = lambda key: 1.0 if (key == "distributed.comm.timesouts.connect") else None mock_dask_cluster.return_value = mock_dask_cluster @@ -149,13 +146,6 @@ def test_download(mock_dask_cluster: mock.MagicMock, mock_dask_client.__enter__.return_value = mock_dask_client mock_dask_client.__exit__.return_value = False - mock_mp_gc.return_value = mock_mp_gc - mock_mp_pool = mock.MagicMock() - mock_mp_gc.Pool.return_value = mock_mp_pool - mock_mp_pool.return_value = mock_mp_pool - mock_mp_pool.__enter__.return_value = mock_mp_pool - mock_mp_pool.__exit__.return_value = False - input_glob = os.path.join(TEST_DIRS.tests_data_dir, 'appshield/snapshot-1/*.json') download_buckets = fsspec.open_files(input_glob) num_buckets = len(download_buckets) @@ -165,8 +155,6 @@ def test_download(mock_dask_cluster: mock.MagicMock, returnd_df = mock.MagicMock() if dl_method.startswith('dask'): mock_dask_client.gather.return_value = [returnd_df for _ in range(num_buckets)] - elif dl_method in ("multiprocess", "multiprocessing"): - mock_mp_pool.map.return_value = [returnd_df for _ in range(num_buckets)] else: download_fn.return_value = returnd_df @@ -175,13 +163,6 @@ def test_download(mock_dask_cluster: mock.MagicMock, results = downloader.download(download_buckets, download_fn) assert results == [returnd_df for _ in range(num_buckets)] - if dl_method in ("multiprocess", "multiprocessing"): - mock_mp_gc.assert_called_once() - mock_mp_pool.map.assert_called_once() - else: - mock_mp_gc.assert_not_called() - mock_mp_pool.map.assert_not_called() - if dl_method == "single_thread": 
download_fn.assert_has_calls([mock.call(bucket) for bucket in download_buckets]) else: @@ -195,3 +176,19 @@ def test_download(mock_dask_cluster: mock.MagicMock, mock_dask_cluster.assert_not_called() mock_dask_client.assert_not_called() mock_dask_config.assert_not_called() + + +@pytest.mark.usefixtures("restore_environ") +@pytest.mark.parametrize('use_env', [True, False]) +@pytest.mark.parametrize('dl_method', ["multiprocess", "multiprocessing"]) +def test_constructor_multiproc_dltype_not_supported(use_env: bool, dl_method: str): + kwargs = {} + if use_env: + os.environ['MORPHEUS_FILE_DOWNLOAD_TYPE'] = dl_method + else: + kwargs['download_method'] = dl_method + + with pytest.raises(ValueError) as excinfo: + Downloader(**kwargs) + + assert "no longer supported" in str(excinfo.value)
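
For reference, the `oauth_callback` parameter added to the `KafkaSourceStage` bindings above expects a Python callable that returns the two string fields read by `kafka_source.cpp`: `"token"` and `"token_expiration_in_epoch"`. The sketch below is illustrative only; the function name, token value, and expiry computation are placeholders, and obtaining a real token from an identity provider is outside the scope of this patch.

```python
import time


def fetch_oauth_token() -> dict:
    """Illustrative OAuth callback (hypothetical values).

    kafka_source.cpp reads two string keys from the returned mapping:
    "token" and "token_expiration_in_epoch" (expiry in milliseconds since
    the epoch, forwarded to RdKafka::Handle::oauthbearer_set_token).
    """
    token = "example-bearer-token"  # placeholder: fetch from your identity provider
    expires_at_ms = int(time.time() * 1000) + 60 * 60 * 1000  # one hour from now
    return {"token": token, "token_expiration_in_epoch": str(expires_at_ms)}
```

Per the updated stub in `morpheus/_lib/stages/__init__.pyi`, such a callable would be passed as the `oauth_callback` argument when constructing the stage, and it is invoked whenever librdkafka requests an OAUTHBEARER token refresh.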