Merge branch 'branch-23.11' into david-mlflow-27-1199
dagardner-nv authored Sep 25, 2023
2 parents e4d7720 + 23e3e59 commit eead924
Showing 15 changed files with 192 additions and 128 deletions.
20 changes: 12 additions & 8 deletions docs/source/developer_guide/guides/2_real_world_phishing.md
@@ -406,7 +406,8 @@ Note that the tokenizer parameters and vocabulary hash file should exactly match

At this point, we have a pipeline that reads in a set of records and pre-processes them with the metadata required for our classifier to make predictions. Our next step is to define a stage that applies a machine learning model to our `MessageMeta` object. To accomplish this, we will be using Morpheus' `TritonInferenceStage`. This stage will handle communication with the `phishing-bert-onnx` model, which we provided to the Triton Docker container via the `models` directory mount.

Next we will add a monitor stage to measure the inference rate as well as a filter stage to filter out any results below a probability threshold of `0.9`.
Next we will add a monitor stage to measure the inference rate:

```python
# Add an inference stage. (The TritonInferenceStage arguments were collapsed in this
# diff view; they are reconstructed here from the equivalent CLI invocation shown
# later in this guide.)
pipeline.add_stage(
    TritonInferenceStage(config,
                         model_name="phishing-bert-onnx",
                         server_url="localhost:8001",
                         force_convert_inputs=True))

pipeline.add_stage(MonitorStage(config, description="Inference Rate", smoothing=0.001, unit="inf"))
```

# Filter values lower than 0.9
pipeline.add_stage(FilterDetectionsStage(config, threshold=0.9))
Here we add a postprocessing stage that adds the probability score for `is_phishing`:

```python
pipeline.add_stage(AddScoresStage(config, labels=["is_phishing"]))
```

Lastly, we will save our results to disk. For this purpose, we are using two stages that are often used in conjunction with each other: `SerializeStage` and `WriteToFileStage`.

The `SerializeStage` is used to include and exclude columns as desired in the output. Importantly, it also handles conversion from the `MultiMessage`-derived output type that is used by the `FilterDetectionsStage` to the `MessageMeta` class that is expected as input by the `WriteToFileStage`.
The `SerializeStage` is used to include and exclude columns as desired in the output. Importantly, it also handles conversion from the `MultiMessage`-derived output type to the `MessageMeta` class that is expected as input by the `WriteToFileStage`.

The `WriteToFileStage` will append message data to the output file as messages are received. Note however that for performance reasons the `WriteToFileStage` does not flush its contents out to disk every time a message is received. Instead, it relies on the underlying [buffered output stream](https://gcc.gnu.org/onlinedocs/libstdc++/manual/streambufs.html) to flush as needed, and then will close the file handle on shutdown.
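
As a minimal sketch of wiring these two stages together (the output path mirrors the CLI example later in this guide and is otherwise arbitrary):

```python
# Convert the MultiMessage-derived payloads back to MessageMeta and write them out.
pipeline.add_stage(SerializeStage(config))

# The filename below is an example value; `overwrite=True` replaces any existing file.
pipeline.add_stage(WriteToFileStage(config, filename="/tmp/detections.jsonlines", overwrite=True))
```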

@@ -456,7 +460,7 @@ from morpheus.stages.general.monitor_stage import MonitorStage
from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage
from morpheus.stages.input.file_source_stage import FileSourceStage
from morpheus.stages.output.write_to_file_stage import WriteToFileStage
from morpheus.stages.postprocess.filter_detections_stage import FilterDetectionsStage
from morpheus.stages.postprocess.add_scores_stage import AddScoresStage
from morpheus.stages.postprocess.serialize_stage import SerializeStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus.stages.preprocess.preprocess_nlp_stage import PreprocessNLPStage
@@ -522,8 +526,8 @@ def run_pipeline():
# Monitor the inference rate
pipeline.add_stage(MonitorStage(config, description="Inference Rate", smoothing=0.001, unit="inf"))

# Filter values lower than 0.9
pipeline.add_stage(FilterDetectionsStage(config, threshold=0.9))
# Add probability score for is_phishing
pipeline.add_stage(AddScoresStage(config, labels=["is_phishing"]))

# Write the results to the output file
pipeline.add_stage(SerializeStage(config))
@@ -550,7 +554,7 @@ morpheus --log_level=debug --plugin examples/developer_guide/2_1_real_world_phis
preprocess --vocab_hash_file=data/bert-base-uncased-hash.txt --truncation=true --do_lower_case=true --add_special_tokens=false \
inf-triton --model_name=phishing-bert-onnx --server_url=localhost:8001 --force_convert_inputs=true \
monitor --description="Inference Rate" --smoothing=0.001 --unit=inf \
filter --threshold=0.9 --filter_source=TENSOR \
add-scores --label=is_phishing \
serialize \
to-file --filename=/tmp/detections.jsonlines --overwrite
```
@@ -233,7 +233,7 @@ The `DFPFileToDataFrameStage` (examples/digital_fingerprinting/production/morphe
| `parser_kwargs` | `dict` or `None` | Optional: additional keyword arguments passed to the `DataFrame` parser; currently this will be one of [`pandas.read_csv`](https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html), [`pandas.read_json`](https://pandas.pydata.org/docs/reference/api/pandas.read_json.html) or [`pandas.read_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html) |
| `cache_dir` | `str` | Optional: path to cache location, defaults to `./.cache/dfp` |

This stage can download and load data files concurrently using one of several methods. Currently supported methods are: `single_thread`, `multiprocess`, `dask`, and `dask_thread`. The method is chosen by setting the {envvar}`MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable; `dask_thread` is used by default, and `single_thread` effectively disables concurrent loading.
This stage can download and load data files concurrently using one of several methods. Currently supported methods are: `single_thread`, `dask`, and `dask_thread`. The method is chosen by setting the {envvar}`MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable; `dask_thread` is used by default, and `single_thread` effectively disables concurrent loading.

This stage will cache the resulting `DataFrame` in `cache_dir`. Since we are caching the `DataFrame`s and not the source files, a cache hit avoids the cost of parsing the incoming data. In the case of remote storage systems, such as S3, this avoids both parsing and a download on a cache hit. One consequence of this is that any change to the `schema` will require purging cached files in the `cache_dir` before those changes are visible.
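
For example, the download method can be pinned from Python before the stage is used (a minimal sketch; `dask_thread` is already the default):

```python
import os

# Select how files are downloaded and parsed: "single_thread", "dask" or "dask_thread".
os.environ["MORPHEUS_FILE_DOWNLOAD_TYPE"] = "dask_thread"
```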

2 changes: 1 addition & 1 deletion docs/source/loaders/core/file_to_df_loader.md
@@ -17,7 +17,7 @@ limitations under the License.

## File to DataFrame Loader

The [DataLoader](../../modules/core/data_loader.md) module is used to load the contents of data files into a dataframe using a custom loader function. This loader function can be configured to use different processing methods, such as single-threaded, multiprocess, dask, or dask_thread, as determined by the `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable. When download_method starts with "dask", a dask client is created to process the files; otherwise, a single thread or multiprocess is used.
The [DataLoader](../../modules/core/data_loader.md) module is used to load the contents of data files into a dataframe using a custom loader function. This loader function can be configured to use different processing methods, such as single-threaded, dask, or dask_thread, as determined by the `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable. When download_method starts with "dask", a dask client is created to process the files; otherwise, a single thread is used.

After processing, the resulting dataframe is cached using a hash of the file paths. This loader also has the ability to load file content from S3 buckets, in addition to loading data from the disk.
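
The caching behaviour described above amounts to keying the cache on the set of requested paths, roughly as in the following sketch (the loader's actual hashing scheme is not shown here; this is only an illustration of the idea):

```python
import hashlib

def cache_key(file_paths: list[str]) -> str:
    # Illustrative only: derive a stable key from the batch of file paths so that a
    # repeated request for the same batch is a cache hit, whether the files live on
    # local disk or in an S3 bucket.
    digest = hashlib.sha256("\n".join(sorted(file_paths)).encode("utf-8"))
    return digest.hexdigest()

print(cache_key(["s3://bucket/logs/day1.json", "s3://bucket/logs/day2.json"]))
```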

7 changes: 3 additions & 4 deletions examples/abp_pcap_detection/README.md
@@ -27,14 +27,13 @@ docker pull nvcr.io/nvidia/tritonserver:23.06-py3
```

##### Deploy Triton Inference Server

Bind the provided `abp-pcap-xgb` directory to the docker container model repo at `/models`.

From the root of the Morpheus repo, navigate to the anomalous behavior profiling example directory:
```bash
cd examples/abp_pcap_detection
```

# Launch the container
The following creates the Triton container, mounts the `abp-pcap-xgb` directory to `/models/abp-pcap-xgb` in the Triton container, and starts the Triton server:
```bash
docker run --rm --gpus=all -p 8000:8000 -p 8001:8001 -p 8002:8002 -v $PWD/abp-pcap-xgb:/models/abp-pcap-xgb --name tritonserver nvcr.io/nvidia/tritonserver:23.06-py3 tritonserver --model-repository=/models --exit-on-error=false
```

@@ -93,7 +93,7 @@ Morpheus pipeline configurations for each workflow are managed using [pipelines_

When using the MRC SegmentModule in a pipeline, a module configuration is also required; it is generated within the test. Additional information is available in [Morpheus Pipeline with Modules](../../../../../docs/source/developer_guide/guides/6_digital_fingerprinting_reference.md#morpheus-pipeline-with-modules).

To ensure the [file_to_df_loader.py](../../../../../morpheus/loaders/file_to_df_loader.py) utilizes the same type of downloading mechanism, set the `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable to one of the supported values (`multiprocess`, `dask`, `dask_thread`, `single_thread`).
To ensure the [file_to_df_loader.py](../../../../../morpheus/loaders/file_to_df_loader.py) utilizes the same type of downloading mechanism, set the `MORPHEUS_FILE_DOWNLOAD_TYPE` environment variable to one of the supported values (`dask`, `dask_thread`, `single_thread`).

```
export MORPHEUS_FILE_DOWNLOAD_TYPE=dask
4 changes: 2 additions & 2 deletions models/data/labels_phishing.txt
@@ -1,2 +1,2 @@
score
pred
not_phishing
is_phishing
56 changes: 44 additions & 12 deletions morpheus/_lib/include/morpheus/stages/kafka_source.hpp
@@ -30,13 +30,16 @@
#include <mrc/segment/builder.hpp>
#include <mrc/segment/object.hpp>
#include <mrc/types.hpp>
#include <pybind11/pytypes.h>
#include <pymrc/node.hpp>
#include <rxcpp/rx.hpp> // for apply, make_subscriber, observable_member, is_on_error<>::not_void, is_on_next_of<>::not_void, trace_activity

#include <cstddef> // for size_t
#include <cstdint>  // for uint32_t
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <thread>
#include <vector>
@@ -53,6 +56,17 @@ namespace morpheus {
*/

#pragma GCC visibility push(default)

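/**
 * @brief RdKafka OAuth bearer token refresh callback which delegates to a user-supplied
 * function that returns the refreshed token configuration as key/value pairs.
 */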
class KafkaOAuthCallback : public RdKafka::OAuthBearerTokenRefreshCb
{
public:
KafkaOAuthCallback(const std::function<std::map<std::string, std::string>()>& oauth_callback);

void oauthbearer_token_refresh_cb(RdKafka::Handle* handle, const std::string& oauthbearer_config) override;

private:
const std::function<std::map<std::string, std::string>()>& m_oauth_callback;
};
/**
* This class loads messages from the Kafka cluster by serving as a Kafka consumer.
*/
@@ -82,10 +96,11 @@ class KafkaSourceStage : public mrc::pymrc::PythonSource<std::shared_ptr<Message
std::string topic,
uint32_t batch_timeout_ms,
std::map<std::string, std::string> config,
bool disable_commit = false,
bool disable_pre_filtering = false,
std::size_t stop_after = 0,
bool async_commits = true);
bool disable_commit = false,
bool disable_pre_filtering = false,
std::size_t stop_after = 0,
bool async_commits = true,
std::unique_ptr<KafkaOAuthCallback> oauth_callback = nullptr);

/**
* @brief Construct a new Kafka Source Stage object
@@ -106,10 +121,11 @@ class KafkaSourceStage : public mrc::pymrc::PythonSource<std::shared_ptr<Message
std::vector<std::string> topics,
uint32_t batch_timeout_ms,
std::map<std::string, std::string> config,
bool disable_commit = false,
bool disable_pre_filtering = false,
std::size_t stop_after = 0,
bool async_commits = true);
bool disable_commit = false,
bool disable_pre_filtering = false,
std::size_t stop_after = 0,
bool async_commits = true,
std::unique_ptr<KafkaOAuthCallback> oauth_callback = nullptr);

~KafkaSourceStage() override = default;

@@ -176,6 +192,8 @@ class KafkaSourceStage : public mrc::pymrc::PythonSource<std::shared_ptr<Message
std::size_t m_stop_after{0};

void* m_rebalancer;

std::unique_ptr<KafkaOAuthCallback> m_oauth_callback;
};

/****** KafkaSourceStageInterfaceProxy **********************/
@@ -200,6 +218,7 @@ struct KafkaSourceStageInterfaceProxy
* @param stop_after : Stops ingesting after emitting `stop_after` records (rows in the table).
* Useful for testing. Disabled if `0`
* @param async_commits : Asynchronously acknowledge consuming Kafka messages
* @param oauth_callback : Callback used when an OAuth token needs to be generated.
*/
static std::shared_ptr<mrc::segment::Object<KafkaSourceStage>> init_with_single_topic(
mrc::segment::Builder& builder,
Expand All @@ -210,8 +229,9 @@ struct KafkaSourceStageInterfaceProxy
std::map<std::string, std::string> config,
bool disable_commit,
bool disable_pre_filtering,
std::size_t stop_after = 0,
bool async_commits = true);
std::size_t stop_after = 0,
bool async_commits = true,
std::optional<pybind11::function> oauth_callback = std::nullopt);

/**
* @brief Create and initialize a KafkaSourceStage, and return the result
@@ -229,6 +249,7 @@ struct KafkaSourceStageInterfaceProxy
* @param stop_after : Stops ingesting after emitting `stop_after` records (rows in the table).
* Useful for testing. Disabled if `0`
* @param async_commits : Asynchronously acknowledge consuming Kafka messages
* @param oauth_callback : Callback used when an OAuth token needs to be generated.
*/
static std::shared_ptr<mrc::segment::Object<KafkaSourceStage>> init_with_multiple_topics(
mrc::segment::Builder& builder,
Expand All @@ -239,8 +260,19 @@ struct KafkaSourceStageInterfaceProxy
std::map<std::string, std::string> config,
bool disable_commit,
bool disable_pre_filtering,
std::size_t stop_after = 0,
bool async_commits = true);
std::size_t stop_after = 0,
bool async_commits = true,
std::optional<pybind11::function> oauth_callback = std::nullopt);

private:
/**
* @brief Create a KafkaOAuthCallback or return nullptr. If oauth_callback is std::nullopt,
* returns nullptr, otherwise wraps the callback in a KafkaOAuthCallback such that the values
* returned from the python callback are converted for use in c++.
* @param oauth_callback : The callback to wrap, if any.
*/
static std::unique_ptr<KafkaOAuthCallback> make_kafka_oauth_callback(
std::optional<pybind11::function>&& oauth_callback);
};
#pragma GCC visibility pop
/** @} */ // end of group
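
On the Python side, the `oauth_callback` parameter introduced above is a callable whose return value is converted into the `std::map<std::string, std::string>` consumed by `KafkaOAuthCallback`. A hypothetical callback might look like the following sketch (the key names and token source are illustrative assumptions, not the actual contract expected by the stage):

```python
import time

def oauth_callback() -> dict[str, str]:
    # Hypothetical token refresh: a real implementation would request a token from an
    # identity provider. The key names below are placeholders for illustration only.
    expires_at_ms = str(int((time.time() + 3600) * 1000))
    return {
        "token": "<bearer-token>",
        "expires_at_ms": expires_at_ms,
    }
```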