Skip to content

Commit

Permalink
Add a stage to create a unique output field by concatenating the time…
Browse files Browse the repository at this point in the history
…stamp and pid_process columns.

Add a `--pipeline_batch_size` flag defaulting to the `model_max_batch_size` avoiding a config warning
Add type hints for `run_pipeline` function arguments
  • Loading branch information
dagardner-nv committed Nov 7, 2024
1 parent 8df8108 commit db07ec1
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions examples/ransomware_detection/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
from stages.create_features import CreateFeaturesRWStage
from stages.preprocessing import PreprocessingRWStage

from morpheus.common import TypeId
from morpheus.config import Config
from morpheus.config import PipelineModes
from morpheus.messages import MessageMeta
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.pipeline.stage_decorator import stage
from morpheus.stages.general.monitor_stage import MonitorStage
from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage
from morpheus.stages.input.appshield_source_stage import AppShieldSourceStage
Expand Down Expand Up @@ -61,6 +64,12 @@
type=click.IntRange(min=1),
help="Max batch size to use for the model.",
)
@click.option(
"--pipeline_batch_size",
default=1024,
type=click.IntRange(min=1),
help=("Internal batch size for the pipeline. Can be much larger than the model batch size."),
)
@click.option(
"--conf_file",
type=click.STRING,
Expand Down Expand Up @@ -98,18 +107,19 @@
default="./ransomware_detection_output.jsonlines",
help="The path to the file where the inference output will be saved.",
)
def run_pipeline(debug,
num_threads,
n_dask_workers,
threads_per_dask_worker,
model_max_batch_size,
conf_file,
model_name,
server_url,
sliding_window,
input_glob,
watch_directory,
output_file):
def run_pipeline(debug: bool,
num_threads: int,
n_dask_workers: int,
threads_per_dask_worker: int,
model_max_batch_size: int,
pipeline_batch_size: int,
conf_file: str,
model_name: str,
server_url: str,
sliding_window: int,
input_glob: str,
watch_directory: bool,
output_file: str):

if debug:
configure_logging(log_level=logging.DEBUG)
Expand All @@ -125,6 +135,7 @@ def run_pipeline(debug,
# Below properties are specified by the command line.
config.num_threads = num_threads
config.model_max_batch_size = model_max_batch_size
config.pipeline_batch_size = pipeline_batch_size
config.feature_length = snapshot_fea_length * sliding_window
config.class_labels = ["pred", "score"]

Expand Down Expand Up @@ -222,6 +233,18 @@ def run_pipeline(debug,
# This stage logs the metrics (msg/sec) from the above stage.
pipeline.add_stage(MonitorStage(config, description="Serialize rate"))

@stage(needed_columns={'timestamp_process': TypeId.STRING})
def concat_columns(msg: MessageMeta) -> MessageMeta:
    """
    Create a unique ``timestamp_process`` column by concatenating the
    ``timestamp`` and ``pid_process`` columns of the message's DataFrame.

    The new column is declared to the pipeline via ``needed_columns`` as
    ``TypeId.STRING``, so ``timestamp`` and ``pid_process`` are presumably
    string-typed and ``+`` performs string concatenation — TODO confirm
    against the upstream source stage.

    Parameters
    ----------
    msg : MessageMeta
        Incoming message; its payload DataFrame must contain the
        ``timestamp`` and ``pid_process`` columns.

    Returns
    -------
    MessageMeta
        The same message instance, with ``timestamp_process`` added in place.
    """
    # The context manager yields the payload DataFrame; the new column is
    # written directly into it (in-place mutation, no copy returned).
    with msg.mutable_dataframe() as df:
        df['timestamp_process'] = df['timestamp'] + df['pid_process']

    return msg

pipeline.add_stage(concat_columns(config))

# Add a write file stage.
# This stage writes all messages to a file.
pipeline.add_stage(WriteToFileStage(config, filename=output_file, overwrite=True))
Expand Down

0 comments on commit db07ec1

Please sign in to comment.