diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..0e0641f --- /dev/null +++ b/.dockerignore @@ -0,0 +1,7 @@ +dmf_models/* +dmf_datasets/* +dmf_tokenizers/* +models/ +build/ +vllm_env +env.sh \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b4fa83c --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +*.egg-info +*.pyc +__pycache__ +.coverage +.coverage.* +durations/* +coverage*.xml +coverage-* +dist +htmlcov +build +test +training_output + +# IDEs +.vscode/ +.idea/ + +# Env files +.env + +# Virtual Env +venv/ +.venv/ + +# Mac personalization files +*.DS_Store + +# Tox envs +.tox diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..a20bcbf --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,10 @@ +[settings] +profile=black +from_first=true +import_heading_future=Future +import_heading_stdlib=Standard +import_heading_thirdparty=Third Party +import_heading_firstparty=First Party +import_heading_localfolder=Local +known_firstparty=alog,aconfig +known_localfolder=vllm_detector_adapter,tests \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c26b962 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v2.1.2 + hooks: + - id: prettier + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + exclude: imports + - repo: https://github.com/PyCQA/isort + rev: 5.11.5 + hooks: + - id: isort + exclude: imports diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 0000000..1b8d15d --- /dev/null +++ b/.prettierignore @@ -0,0 +1,8 @@ +# Ignore artifacts +build +coverage-py + +*.jsonl +**/.github +**/*.html +*.md diff --git a/.whitesource b/.whitesource new file mode 100644 index 0000000..545628a --- /dev/null +++ b/.whitesource @@ -0,0 +1,3 @@ +{ + "settingsInheritedFrom": "whitesource-config/whitesource-config@master" +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..e64be67 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,59 @@ +# Contributing + +## Development + +### Set up your dev environment + +The following tools are required: + +- [git](https://git-scm.com) +- [python](https://www.python.org) (v3.11+) +- [pip](https://pypi.org/project/pip/) (v23.0+) + +You can setup your dev environment using [tox](https://tox.wiki/en/latest/), an environment orchestrator which allows for setting up environments for and invoking builds, unit tests, formatting, linting, etc. Install tox with: + +```sh +pip install -r setup_requirements.txt +``` + +If you want to manage your own virtual environment instead of using `tox`, you can install `vllm_detector_adapter` and all dependencies with: + +```sh +pip install . +``` + +### Unit tests + +Unit tests are enforced by the CI system. When making changes, run the tests before pushing the changes to avoid CI issues. + +Running unit tests against all supported Python versions is as simple as: + +```sh +tox +``` + +Running tests against a single Python version can be done with: + +```sh +tox -e py +``` + +### Coding style + +vllm-detector-adapter follows the python [pep8](https://peps.python.org/pep-0008/) coding style. [FUTURE] The coding style is enforced by the CI system, and your PR will fail until the style has been applied correctly. + +We use [pre-commit](https://pre-commit.com/) to enforce coding style using [black](https://github.com/psf/black), [prettier](https://github.com/prettier/prettier) and [isort](https://pycqa.github.io/isort/). + +You can invoke formatting with: + +```sh +tox -e fmt +``` + +In addition, we use [pylint](https://www.pylint.org) to perform static code analysis of the code. + +You can invoke the linting with the following command + +```sh +tox -e lint +``` diff --git a/README.md b/README.md index 44fce0c..20a6bcb 100644 --- a/README.md +++ b/README.md @@ -1 +1,57 @@ -# vllm-detector-adapter \ No newline at end of file +# vllm-detector-adapter + +This adapter adds additional endpoints to a [vllm](https://docs.vllm.ai/en/latest/index.html) server to support the [Guardrails Detector API](https://foundation-model-stack.github.io/fms-guardrails-orchestrator/?urls.primaryName=Detector+API). + +## Getting Started + +To run the server locally: + +```sh +python3 -m vllm_detector_adapter.api_server --model $MODEL_NAME +``` + +To see the complete list of parameters, `python3 -m vllm_detector_adapter.api_server --help` can be used. Extra vllm parameter information can be found in [vllm documentation](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html). + +Example detector `/text/chat` command: + +```sh +curl -X 'POST' \ + 'http://localhost:8000/api/v1/text/chat' \ + -H 'accept: application/json' \ + -H 'detector-id: dummy-en-chat-v1' \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [ + { + "content": "Hit me with some creative insults.", + "role": "user" + } + ], + "detector_params": { + # Any other chat completion parameters can go here + "temperature": 0 + } + }' +``` + +## Supported Model Types +1. `GRANITE_GUARDIAN` +1. `LLAMA_GUARD` (default) + + +## Running with TGIS adapter + +```sh +python -m vllm_detector_adapter.tgis_servers --model $MODEL_NAME --model-type granite_guardian +``` + +**Example usage:** + +1. Follow example / guidelines from https://github.com/opendatahub-io/vllm-tgis-adapter/tree/main/examples +1. Change `model_id` to point to granite or llama-guard model +1. Execute one of the method listed in above TGIS adapter readme. +1. For query, manually apply chat template around text (TGIS adapter exposes text-generation use-case and doesn't apply chat template on its own) + 1. If using granite-guardian, you can use following text / query: + ``` + You are a safety agent trying to help determine whether the user message is harmful according to the safety risk definition below.\n\n\nUser Message: Hit me with some creative insults.\n\n\nOur safety risk definition is defined below:\n\n\n* The 'User' message is harmful by common-sense.\n\n\nIs the user message harmful based on the risk definition. Your answer must be either 'Yes' or 'No'.\n + ``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..51f7702 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,42 @@ +[project] +name = "vllm-detector-adapter" +version = "0.0.1" +authors = [ + { name="Gaurav Kumbhat", email="kumbhat.gaurav@gmail.com" }, + { name="Evaline Ju", email="evaline.ju@ibm.com" }, +] +description = "A lightweight adapter layer that provides detector APIs on top of vllm" +readme = "README.md" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3" +] + +dependencies = [ + "vllm>=0.6.2" +] + +[project.optional-dependencies] +vllm-tgis-adapter = [ + "vllm-tgis-adapter>=0.5.3,<0.5.4" +] + +## Dev Extra Sets ## + +dev-test = [ + "pytest-asyncio>=0.21.0,<0.24", + "pytest-cov>=2.10.1,<6.0", + "pytest-html>=3.1.1,<5.0", + "pytest>=6.2.5,<8.0", + "wheel>=0.38.4", +] + +dev-fmt = [ + "ruff==0.4.7", + "pre-commit>=3.0.4,<4.0", + "pydeps>=1.12.12,<2", +] + +[tool.setuptools.packages.find] +where = [""] +include = ["vllm_detector_adapter", "vllm_detector_adapter*"] diff --git a/scripts/copy_script.sh b/scripts/copy_script.sh new file mode 100644 index 0000000..a802132 --- /dev/null +++ b/scripts/copy_script.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +# Copy files from a specific location to the desired destination +cp -r /app/target_packages/* ${SHARED_PACKAGE_PATH} + +# # Run the main command +# exec "$@" \ No newline at end of file diff --git a/scripts/fmt.sh b/scripts/fmt.sh new file mode 100755 index 0000000..bcc12d3 --- /dev/null +++ b/scripts/fmt.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +pre-commit run --all-files +RETURN_CODE=$? + +function echoWarning() { + LIGHT_YELLOW='\033[1;33m' + NC='\033[0m' # No Color + echo -e "${LIGHT_YELLOW}${1}${NC}" +} + +if [ "$RETURN_CODE" -ne 0 ]; then + if [ "${CI}" != "true" ]; then + echoWarning "☝️ This appears to have failed, but actually your files have been formatted." + echoWarning "Make a new commit with these changes before making a pull request." + else + echoWarning "This test failed because your code isn't formatted correctly." + echoWarning 'Locally, run `make run fmt`, it will appear to fail, but change files.' + echoWarning "Add the changed files to your commit and this stage will pass." + fi + + exit $RETURN_CODE +fi diff --git a/setup_requirements.txt b/setup_requirements.txt new file mode 100644 index 0000000..dafa932 --- /dev/null +++ b/setup_requirements.txt @@ -0,0 +1,2 @@ +tox>=4.4.2,<5 +build>=1.2.1,<2.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/generative_detectors/__init__.py b/tests/generative_detectors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/generative_detectors/test_base.py b/tests/generative_detectors/test_base.py new file mode 100644 index 0000000..60b10dd --- /dev/null +++ b/tests/generative_detectors/test_base.py @@ -0,0 +1,81 @@ +# Standard +from dataclasses import dataclass +import asyncio + +# Third Party +from vllm.config import MultiModalConfig +from vllm.entrypoints.openai.serving_engine import BaseModelPath +import jinja2 +import pytest_asyncio + +# Local +from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase + +MODEL_NAME = "openai-community/gpt2" +CHAT_TEMPLATE = "Dummy chat template for testing {}" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + embedding_mode = False + multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() + + +@dataclass +class MockEngine: + async def get_model_config(self): + return MockModelConfig() + + +async def _async_serving_detection_completion_init(): + """Initialize a chat completion base with string templates""" + engine = MockEngine() + model_config = await engine.get_model_config() + + detection_completion = ChatCompletionDetectionBase( + task_template="hello {{user_text}}", + output_template="bye {{text}}", + engine_client=engine, + model_config=model_config, + base_model_paths=BASE_MODEL_PATHS, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + return detection_completion + + +@pytest_asyncio.fixture +async def detection_base(): + return _async_serving_detection_completion_init() + + +### Tests ##################################################################### + + +def test_async_serving_detection_completion_init(detection_base): + detection_completion = asyncio.run(detection_base) + assert detection_completion.chat_template == CHAT_TEMPLATE + + # tests load_template + task_template = detection_completion.task_template + assert type(task_template) == jinja2.environment.Template + assert task_template.render(({"user_text": "moose"})) == "hello moose" + + output_template = detection_completion.output_template + assert type(output_template) == jinja2.environment.Template + assert output_template.render(({"text": "moose"})) == "bye moose" diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py new file mode 100644 index 0000000..7afb02a --- /dev/null +++ b/tests/generative_detectors/test_granite_guardian.py @@ -0,0 +1,246 @@ +# Standard +from dataclasses import dataclass +from http import HTTPStatus +from unittest.mock import patch +import asyncio + +# Third Party +from vllm.config import MultiModalConfig +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, + ChatCompletionLogProbs, + ChatCompletionLogProbsContent, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + ErrorResponse, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import BaseModelPath +import pytest +import pytest_asyncio + +# Local +from vllm_detector_adapter.generative_detectors.granite_guardian import GraniteGuardian +from vllm_detector_adapter.protocol import ( + ChatDetectionRequest, + ChatDetectionResponse, + DetectionChatMessageParam, +) + +MODEL_NAME = "ibm-granite/granite-guardian" # Example granite-guardian model +CHAT_TEMPLATE = "Dummy chat template for testing {}" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + embedding_mode = False + multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() + + +@dataclass +class MockEngine: + async def get_model_config(self): + return MockModelConfig() + + +async def _granite_guardian_init(): + """Initialize a granite guardian""" + engine = MockEngine() + model_config = await engine.get_model_config() + + granite_guardian = GraniteGuardian( + task_template=None, + output_template=None, + engine_client=engine, + model_config=model_config, + base_model_paths=BASE_MODEL_PATHS, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + return granite_guardian + + +@pytest_asyncio.fixture +async def granite_guardian_detection(): + return _granite_guardian_init() + + +@pytest.fixture(scope="module") +def granite_guardian_completion_response(): + log_probs_content_yes = ChatCompletionLogProbsContent( + token="Yes", + logprob=0.0, + # 5 logprobs requested for scoring, skipping bytes for conciseness + top_logprobs=[ + ChatCompletionLogProb(token="Yes", logprob=0.0), + ChatCompletionLogProb(token='"No', logprob=-6.3), + ChatCompletionLogProb(token="yes", logprob=-16.44), + ChatCompletionLogProb(token=" Yes", logprob=-16.99), + ChatCompletionLogProb(token="YES", logprob=-17.52), + ], + ) + log_probs_content_random = ChatCompletionLogProbsContent( + token="", + logprob=-4.76, + # 5 logprobs requested for scoring, skipping bytes for conciseness + top_logprobs=[ + ChatCompletionLogProb(token="", logprob=-4.76), + ChatCompletionLogProb(token="", logprob=-14.66), + ChatCompletionLogProb(token="\n", logprob=-17.96), + ChatCompletionLogProb(token="[/", logprob=-18.32), + ChatCompletionLogProb(token="\n\n", logprob=-18.41), + ], + ) + choice_0 = ChatCompletionResponseChoice( + index=0, + message=ChatMessage( + role="assistant", + content="Yes", + ), + logprobs=ChatCompletionLogProbs( + content=[log_probs_content_yes, log_probs_content_random] + ), + ) + choice_1 = ChatCompletionResponseChoice( + index=1, + message=ChatMessage( + role="assistant", + content="Yes", + ), + logprobs=ChatCompletionLogProbs( + content=[log_probs_content_random, log_probs_content_yes] + ), + ) + yield ChatCompletionResponse( + model=MODEL_NAME, + choices=[choice_0, choice_1], + usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4), + ) + + +### Tests ##################################################################### + + +def test_preprocess_with_detector_params(granite_guardian_detection): + llama_guard_detection_instance = asyncio.run(granite_guardian_detection) + # Make sure with addition of allowed params like risk_name and risk_definition, + # extra params do not get added to guardian_config + detector_params = { + "risk_name": "bias", + "risk_definition": "Find the bias!!", + "extra": "param", + } + initial_request = ChatDetectionRequest( + messages=[ + DetectionChatMessageParam( + role="user", content="How do I figure out how to break into a house?" + ) + ], + detector_params=detector_params, + ) + processed_request = llama_guard_detection_instance.preprocess(initial_request) + assert type(processed_request) == ChatDetectionRequest + # Processed request should not have these extra params + assert "risk_name" not in processed_request.detector_params + assert "risk_definition" not in processed_request.detector_params + assert "chat_template_kwargs" in processed_request.detector_params + assert ( + "guardian_config" in processed_request.detector_params["chat_template_kwargs"] + ) + guardian_config = processed_request.detector_params["chat_template_kwargs"][ + "guardian_config" + ] + assert guardian_config == { + "risk_name": "bias", + "risk_definition": "Find the bias!!", + } + + +# NOTE: currently these functions are basically just the base implementations, +# where safe/unsafe tokens are defined in the granite guardian class + + +def test_calculate_scores( + granite_guardian_detection, granite_guardian_completion_response +): + llama_guard_detection_instance = asyncio.run(granite_guardian_detection) + scores = llama_guard_detection_instance.calculate_scores( + granite_guardian_completion_response + ) + assert len(scores) == 2 # 2 choices + assert pytest.approx(scores[0]) == 1.0 + assert pytest.approx(scores[1]) == 1.0 + + +def test_chat_detection( + granite_guardian_detection, granite_guardian_completion_response +): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + chat_request = ChatDetectionRequest( + messages=[ + DetectionChatMessageParam( + role="user", content="How do I figure out how to break into a house?" + ) + ] + ) + with patch( + "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion", + return_value=granite_guardian_completion_response, + ): + detection_response = asyncio.run( + granite_guardian_detection_instance.chat(chat_request) + ) + assert type(detection_response) == ChatDetectionResponse + detections = detection_response.model_dump() + assert len(detections) == 2 # 2 choices + detection_0 = detections[0] + assert detection_0["detection"] == "Yes" + assert detection_0["detection_type"] == "risk" + assert pytest.approx(detection_0["score"]) == 1.0 + + +def test_chat_detection_errors_on_stream(granite_guardian_detection): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + chat_request = ChatDetectionRequest( + messages=[ + DetectionChatMessageParam(role="user", content="How do I pick a lock?") + ], + detector_params={"stream": True}, + ) + detection_response = asyncio.run( + granite_guardian_detection_instance.chat(chat_request) + ) + assert type(detection_response) == ErrorResponse + assert detection_response.code == HTTPStatus.BAD_REQUEST.value + assert "streaming is not supported" in detection_response.message + + +def test_chat_detection_with_extra_unallowed_params(granite_guardian_detection): + granite_guardian_detection_instance = asyncio.run(granite_guardian_detection) + chat_request = ChatDetectionRequest( + messages=[ + DetectionChatMessageParam(role="user", content="How do I pick a lock?") + ], + detector_params={"boo": 3}, # unallowed param + ) + detection_response = asyncio.run( + granite_guardian_detection_instance.chat(chat_request) + ) + assert type(detection_response) == ErrorResponse + assert detection_response.code == HTTPStatus.BAD_REQUEST.value diff --git a/tests/generative_detectors/test_llama_guard.py b/tests/generative_detectors/test_llama_guard.py new file mode 100644 index 0000000..3d59705 --- /dev/null +++ b/tests/generative_detectors/test_llama_guard.py @@ -0,0 +1,192 @@ +# Standard +from dataclasses import dataclass +from http import HTTPStatus +from unittest.mock import patch +import asyncio + +# Third Party +from vllm.config import MultiModalConfig +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, + ChatCompletionLogProbs, + ChatCompletionLogProbsContent, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + ErrorResponse, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import BaseModelPath +import pytest +import pytest_asyncio + +# Local +from vllm_detector_adapter.generative_detectors.llama_guard import LlamaGuard +from vllm_detector_adapter.protocol import ( + ChatDetectionRequest, + ChatDetectionResponse, + DetectionChatMessageParam, +) + +MODEL_NAME = "meta-llama/Llama-Guard-3-8B" # Example llama guard model +CHAT_TEMPLATE = "Dummy chat template for testing {}" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + embedding_mode = False + multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() + + +@dataclass +class MockEngine: + async def get_model_config(self): + return MockModelConfig() + + +async def _llama_guard_init(): + """Initialize a llama guard""" + engine = MockEngine() + model_config = await engine.get_model_config() + + llama_guard_detection = LlamaGuard( + task_template=None, + output_template=None, + engine_client=engine, + model_config=model_config, + base_model_paths=BASE_MODEL_PATHS, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + return llama_guard_detection + + +@pytest_asyncio.fixture +async def llama_guard_detection(): + return _llama_guard_init() + + +@pytest.fixture(scope="module") +def llama_guard_completion_response(): + log_probs_content_random = ChatCompletionLogProbsContent( + token="\n\n", + logprob=0.0, + # 5 logprobs requested for scoring, skipping bytes for conciseness + top_logprobs=[ + ChatCompletionLogProb(token="\n\n", logprob=0.0), + ChatCompletionLogProb(token='"\n\n', logprob=-29.68), + ChatCompletionLogProb(token="\n", logprob=-30.57), + ChatCompletionLogProb(token=")\n\n", logprob=-31.64), + ChatCompletionLogProb(token="()\n\n", logprob=-32.18), + ], + ) + log_probs_content_safe = ChatCompletionLogProbsContent( + token="safe", + logprob=-0.0013, + # 5 logprobs requested for scoring, skipping bytes for conciseness + top_logprobs=[ + ChatCompletionLogProb(token="safe", logprob=-0.0013), + ChatCompletionLogProb(token="unsafe", logprob=-6.61), + ChatCompletionLogProb(token="1", logprob=-16.90), + ChatCompletionLogProb(token="2", logprob=-17.39), + ChatCompletionLogProb(token="3", logprob=-17.61), + ], + ) + choice_0 = ChatCompletionResponseChoice( + index=0, + message=ChatMessage( + role="assistant", + content="safe", + ), + logprobs=ChatCompletionLogProbs( + content=[log_probs_content_random, log_probs_content_safe] + ), + ) + choice_1 = ChatCompletionResponseChoice( + index=1, + message=ChatMessage( + role="assistant", + content="safe", + ), + logprobs=ChatCompletionLogProbs( + content=[log_probs_content_random, log_probs_content_safe] + ), + ) + yield ChatCompletionResponse( + model=MODEL_NAME, + choices=[choice_0, choice_1], + usage=UsageInfo(prompt_tokens=200, total_tokens=206, completion_tokens=6), + ) + + +### Tests ##################################################################### + +# NOTE: currently these functions are basically just the base implementations, +# where safe/unsafe tokens are defined in the llama guard class + + +def test_calculate_scores(llama_guard_detection, llama_guard_completion_response): + llama_guard_detection_instance = asyncio.run(llama_guard_detection) + scores = llama_guard_detection_instance.calculate_scores( + llama_guard_completion_response + ) + assert len(scores) == 2 # 2 choices + assert pytest.approx(scores[0]) == 0.001346767 + assert pytest.approx(scores[1]) == 0.001346767 + + +def test_chat_detection(llama_guard_detection, llama_guard_completion_response): + llama_guard_detection_instance = asyncio.run(llama_guard_detection) + chat_request = ChatDetectionRequest( + messages=[ + DetectionChatMessageParam( + role="user", content="How do I search for moose?" + ), + DetectionChatMessageParam( + role="assistant", content="You could go to Canada" + ), + DetectionChatMessageParam(role="user", content="interesting"), + ] + ) + with patch( + "vllm_detector_adapter.generative_detectors.llama_guard.LlamaGuard.create_chat_completion", + return_value=llama_guard_completion_response, + ): + detection_response = asyncio.run( + llama_guard_detection_instance.chat(chat_request) + ) + assert type(detection_response) == ChatDetectionResponse + detections = detection_response.model_dump() + assert len(detections) == 2 # 2 choices + detection_0 = detections[0] + assert detection_0["detection"] == "safe" + assert detection_0["detection_type"] == "risk" + assert pytest.approx(detection_0["score"]) == 0.001346767 + + +def test_chat_detection_with_extra_unallowed_params(llama_guard_detection): + llama_guard_detection_instance = asyncio.run(llama_guard_detection) + chat_request = ChatDetectionRequest( + messages=[ + DetectionChatMessageParam(role="user", content="How do I search for moose?") + ], + detector_params={"moo": "unallowed"}, # unallowed param + ) + detection_response = asyncio.run(llama_guard_detection_instance.chat(chat_request)) + assert type(detection_response) == ErrorResponse + assert detection_response.code == HTTPStatus.BAD_REQUEST.value diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..93ae9e8 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,126 @@ +# Standard +from http import HTTPStatus + +# Third Party +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + ErrorResponse, + UsageInfo, +) + +# Local +from vllm_detector_adapter.protocol import ( + ChatDetectionRequest, + ChatDetectionResponse, + DetectionChatMessageParam, +) + +MODEL_NAME = "org/model-name" + +### Tests ##################################################################### + + +def test_detection_to_completion_request(): + chat_request = ChatDetectionRequest( + messages=[ + DetectionChatMessageParam( + role="user", content="How do I search for moose?" + ), + DetectionChatMessageParam( + role="assistant", content="You could go to Canada" + ), + ], + detector_params={"n": 3, "temperature": 0.5}, + ) + request = chat_request.to_chat_completion_request(MODEL_NAME) + assert type(request) == ChatCompletionRequest + assert request.messages[0]["role"] == "user" + assert request.messages[0]["content"] == "How do I search for moose?" + assert request.messages[1]["role"] == "assistant" + assert request.messages[1]["content"] == "You could go to Canada" + assert request.model == MODEL_NAME + assert request.temperature == 0.5 + assert request.n == 3 + + +def test_detection_to_completion_request_unknown_params(): + chat_request = ChatDetectionRequest( + messages=[ + DetectionChatMessageParam(role="user", content="How do I search for moose?") + ], + detector_params={"moo": 2}, + ) + request = chat_request.to_chat_completion_request(MODEL_NAME) + assert type(request) == ErrorResponse + assert request.code == HTTPStatus.BAD_REQUEST.value + + +def test_response_from_completion_response(): + # Simplified response without logprobs since not needed for this method + choice_0 = ChatCompletionResponseChoice( + index=0, + message=ChatMessage( + role="assistant", + content=" moose", + ), + ) + choice_1 = ChatCompletionResponseChoice( + index=1, + message=ChatMessage( + role="assistant", + content="goose\n\n", + ), + ) + response = ChatCompletionResponse( + model=MODEL_NAME, + choices=[choice_0, choice_1], + usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4), + ) + scores = [0.3, 0.7] + detection_type = "type" + detection_response = ChatDetectionResponse.from_chat_completion_response( + response, scores, detection_type + ) + assert type(detection_response) == ChatDetectionResponse + detections = detection_response.model_dump() + assert len(detections) == 2 # 2 choices + detection_0 = detections[0] + assert detection_0["detection"] == "moose" + assert detection_0["detection_type"] == "type" + assert detection_0["score"] == 0.3 + detection_1 = detections[1] + assert detection_1["detection"] == "goose" + assert detection_1["detection_type"] == "type" + assert detection_1["score"] == 0.7 + + +def test_response_from_completion_response_missing_content(): + choice_0 = ChatCompletionResponseChoice( + index=0, + message=ChatMessage( + role="assistant", + content=" moose", + ), + ) + choice_1 = ChatCompletionResponseChoice( + index=1, message=ChatMessage(role="assistant") + ) + response = ChatCompletionResponse( + model=MODEL_NAME, + choices=[choice_0, choice_1], + usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4), + ) + scores = [0.3, 0.7] + detection_type = "type" + detection_response = ChatDetectionResponse.from_chat_completion_response( + response, scores, detection_type + ) + assert type(detection_response) == ErrorResponse + assert ( + "Choice 1 from chat completion does not have content" + in detection_response.message + ) + assert detection_response.code == HTTPStatus.BAD_REQUEST.value diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..dfa8504 --- /dev/null +++ b/tox.ini @@ -0,0 +1,36 @@ +[tox] +envlist = py, lint, fmt + +[testenv] +description = run tests with pytest with coverage +extras = + all + dev-test +passenv = + LOG_LEVEL + LOG_FILTERS + LOG_FORMATTER + LOG_THREAD_ID + LOG_CHANNEL_WIDTH +setenv = + DFTYPE = pandas_all + +commands = pytest --cov=vllm_detector_adapter --cov-report=html:coverage-{env_name} --cov-report=xml:coverage-{env_name}.xml --html=durations/{env_name}.html {posargs:tests} -W error::UserWarning +; -W ignore::DeprecationWarning + +; Unclear: We probably want to test wheel packaging +; But! tox will fail when this is set and _any_ interpreter is missing +; Without this, sdist packaging is tested so that's a start. +package=wheel + +[testenv:fmt] +description = format with pre-commit +extras = dev-fmt +commands = ./scripts/fmt.sh +allowlist_externals = ./scripts/fmt.sh + +[testenv:lint] +description = lint with ruff +extras = + dev-fmt +commands = ruff check vllm_detector_adapter diff --git a/vllm_detector_adapter/__init__.py b/vllm_detector_adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_detector_adapter/api_server.py b/vllm_detector_adapter/api_server.py new file mode 100644 index 0000000..bff5ffb --- /dev/null +++ b/vllm_detector_adapter/api_server.py @@ -0,0 +1,195 @@ +# Standard +from argparse import Namespace +import signal +import socket + +# Third Party +from fastapi import Request +from fastapi.responses import JSONResponse +from starlette.datastructures import State +from vllm.config import ModelConfig +from vllm.engine.arg_utils import nullable_str +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai import api_server +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.protocol import ErrorResponse +from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.utils import FlexibleArgumentParser +from vllm.version import __version__ as VLLM_VERSION +import uvloop + +# Local +from vllm_detector_adapter import generative_detectors +from vllm_detector_adapter.logging import init_logger +from vllm_detector_adapter.protocol import ChatDetectionRequest, ChatDetectionResponse + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger("vllm_detector_adapter.api_server") + +# Use original vllm router and add to it +router = api_server.router + + +def chat_detection( + request: Request, +) -> generative_detectors.base.ChatCompletionDetectionBase: + return request.app.state.detectors_serving_chat_detection + + +def init_app_state_with_detectors( + engine_client: EngineClient, + model_config: ModelConfig, + state: State, + args: Namespace, +) -> None: + """Add detection capabilities to app state""" + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) for name in served_model_names + ] + + # Use vllm app state init + api_server.init_app_state(engine_client, model_config, state, args) + + generative_detector_class = generative_detectors.MODEL_CLASS_MAP[args.model_type] + + # Add chat detection + state.detectors_serving_chat_detection = generative_detector_class( + args.task_template, + args.output_template, + engine_client, + model_config, + base_model_paths, + args.response_role, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + chat_template=args.chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + ) + + +async def run_server(args, **uvicorn_kwargs) -> None: + """Server should include all vllm supported endpoints and any + newly added detection endpoints""" + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + temp_socket.bind(("", args.port)) + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with api_server.build_async_engine_client(args) as engine_client: + # Use vllm build_app which adds middleware + app = api_server.build_app(args) + + model_config = await engine_client.get_model_config() + init_app_state_with_detectors(engine_client, model_config, app.state, args) + + temp_socket.close() + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + +@router.post("/api/v1/text/chat") +async def create_chat_detection(request: ChatDetectionRequest, raw_request: Request): + """Support chat detection endpoint""" + + detector_response = await chat_detection(raw_request).chat(request, raw_request) + + if isinstance(detector_response, ErrorResponse): + # ErrorResponse includes code and message, corresponding to errors for the detectorAPI + return JSONResponse( + content=detector_response.model_dump(), status_code=detector_response.code + ) + + elif isinstance(detector_response, ChatDetectionResponse): + return JSONResponse(content=detector_response.model_dump()) + + return JSONResponse({}) + + +def add_chat_detection_params(parser): + parser.add_argument( + "--task-template", + type=nullable_str, + default=None, + help="The file path to the task template, " + "or the template in single-line form " + "for the specified model", + ) + parser.add_argument( + "--output-template", + type=nullable_str, + default=None, + help="The file path to the output template, " + "or the template in single-line form " + "for the specified model", + ) + parser.add_argument( + "--model-type", + type=generative_detectors.ModelTypes, + choices=[ + member.lower() for member in generative_detectors.ModelTypes._member_names_ + ], + default=generative_detectors.ModelTypes.LLAMA_GUARD, + help="The model type of the generative model", + ) + return parser + + +if __name__ == "__main__": + + # Verify vllm compatibility + # Local + from vllm_detector_adapter import package_validate + + package_validate.verify_vllm_compatibility() + + # NOTE(simon): + # This section should be in sync with vllm/scripts.py for CLI entrypoints. + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server." + ) + parser = make_arg_parser(parser) + + # Add chat detection params + parser = add_chat_detection_params(parser) + + args = parser.parse_args() + + uvloop.run(run_server(args)) diff --git a/vllm_detector_adapter/generative_detectors/__init__.py b/vllm_detector_adapter/generative_detectors/__init__.py new file mode 100644 index 0000000..66a17de --- /dev/null +++ b/vllm_detector_adapter/generative_detectors/__init__.py @@ -0,0 +1,48 @@ +# Standard +from collections import defaultdict +from enum import StrEnum, auto # Python3.11+ + +# Local +from vllm_detector_adapter.generative_detectors import ( + base, + granite_guardian, + llama_guard, +) + +# To only expose certain objects from this file +__all__ = ["ModelTypes", "MODEL_CLASS_MAP"] + + +class CaseInsensitiveStrEnum(StrEnum): + """Custom StrEnum to make member lookup case insensitive""" + + @classmethod + def _missing_(cls, value): + return cls.__members__.get(value.upper(), None) + + +class ModelTypes(CaseInsensitiveStrEnum): + + GRANITE_GUARDIAN = auto() + LLAMA_GUARD = auto() + + +# Use ChatCompletionDetectionBase as the base class +def __default_detection_base__(): + return base.ChatCompletionDetectionBase + + +# Dictionary mapping generative detection classes with model types. +# This gets used for configuring which detection processing to use with which model. +MODEL_CLASS_MAP = defaultdict(__default_detection_base__) + +# This is to add values to above MAP. This is private, to discourage these values to +# be used directly. +__MODEL_CLASS_MAP__ = { + ModelTypes.GRANITE_GUARDIAN: granite_guardian.GraniteGuardian, + ModelTypes.LLAMA_GUARD: llama_guard.LlamaGuard, +} + +# Feede all the values to the MODEL_CLASS_MAP +for key, value in __MODEL_CLASS_MAP__.items(): + MODEL_CLASS_MAP[key] = value diff --git a/vllm_detector_adapter/generative_detectors/base.py b/vllm_detector_adapter/generative_detectors/base.py new file mode 100644 index 0000000..9518cea --- /dev/null +++ b/vllm_detector_adapter/generative_detectors/base.py @@ -0,0 +1,187 @@ +# Standard +from http import HTTPStatus +from pathlib import Path +from typing import List, Optional, Union +import codecs +import math + +# Third Party +from fastapi import Request +from vllm.entrypoints.openai.protocol import ChatCompletionResponse, ErrorResponse +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +import jinja2 +import torch + +# Local +from vllm_detector_adapter.logging import init_logger +from vllm_detector_adapter.protocol import ChatDetectionRequest, ChatDetectionResponse + +logger = init_logger(__name__) + +START_PROB = 1e-50 + + +class ChatCompletionDetectionBase(OpenAIServingChat): + """Base class for developing chat completion based detectors""" + + def __init__(self, task_template: str, output_template: str, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.jinja_env = jinja2.Environment() + self.task_template = self.load_template(task_template) + + self.output_template = self.load_template(output_template) + + def load_template(self, template_path: Optional[Union[Path, str]]) -> str: + """Function to load template + Note: this function currently is largely taken from the chat template method + in vllm.entrypoints.chat_utils + """ + if template_path is None: + return None + try: + with open(template_path, "r") as f: + resolved_template = f.read() + # Addition to vllm's original load chat template method + # This prevents additional escaping of template characters + # such as \n (newlines) + resolved_template = codecs.decode(resolved_template, "unicode-escape") + except OSError as e: + if isinstance(template_path, Path): + raise + + JINJA_CHARS = "{}\n" + if not any(c in template_path for c in JINJA_CHARS): + msg = ( + f"The supplied template ({template_path}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}" + ) + raise ValueError(msg) from e + + # If opening a file fails, set template to be args to + # ensure we decode so our escape are interpreted correctly + resolved_template = codecs.decode(template_path, "unicode_escape") + + logger.info("Using supplied template:\n%s", resolved_template) + return self.jinja_env.from_string(resolved_template) + + def apply_task_template( + self, request: ChatDetectionRequest + ) -> Union[ChatDetectionRequest, ErrorResponse]: + """Apply task template on the request""" + return request + + def preprocess( + self, request: ChatDetectionRequest + ) -> Union[ChatDetectionRequest, ErrorResponse]: + """Preprocess request""" + return request + + def apply_output_template( + self, response: ChatCompletionResponse + ) -> Union[ChatCompletionResponse, ErrorResponse]: + """Apply output parsing template for the response""" + return response + + def calculate_scores(self, response: ChatCompletionResponse) -> List[float]: + """Extract scores from logprobs of the raw chat response""" + safe_token_prob = START_PROB + unsafe_token_prob = START_PROB + + choice_scores = [] + + # TODO: consider if this part can be optimized despite nested response structure + for choice in response.choices: + # Each choice will have logprobs for tokens + for logprob_info_i in choice.logprobs.content: + # NOTE: open-ai chat completion performs a max operation over top log probs + # and puts that result in `logprobs`, whereas we need to do a sum over these as + # per discussion with granite team. So we are pulling in `top_logprobs` + for top_logprob in logprob_info_i.top_logprobs: + token = top_logprob.token + if token.strip().lower() == self.SAFE_TOKEN.lower(): + safe_token_prob += math.exp(top_logprob.logprob) + if token.strip().lower() == self.UNSAFE_TOKEN.lower(): + unsafe_token_prob += math.exp(top_logprob.logprob) + + probabilities = torch.softmax( + torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), + dim=0, + ) + + # We calculate "probability of risk" here, therefore, only return probability related to + # unsafe_token_prob. Use .item() to get tensor float + choice_scores.append(probabilities[1].item()) + + return choice_scores + + ##### Detection methods #################################################### + # Base implementation of other detection endpoints like content can go here + + async def chat( + self, + request: ChatDetectionRequest, + raw_request: Optional[Request] = None, + ) -> Union[ChatDetectionResponse, ErrorResponse]: + """Function used to call chat detection and provide a /chat response""" + + # Fetch model name from super class: OpenAIServing + model_name = self.base_model_paths[0].name + + # Apply task template if it exists + if self.task_template: + request = self.apply_task_template(request) + if isinstance(request, ErrorResponse): + # Propagate any request problems that will not allow + # task template to be applied + return request + + # Optionally make model-dependent adjustments for the request + request = self.preprocess(request) + + chat_completion_request = request.to_chat_completion_request(model_name) + if isinstance(chat_completion_request, ErrorResponse): + # Propagate any request problems like extra unallowed parameters + return chat_completion_request + + # Return an error for streaming for now. Since the detector API is unary, + # results would not be streamed back anyway. The chat completion response + # object would look different, and content would have to be aggregated. + if chat_completion_request.stream: + return ErrorResponse( + message="streaming is not supported for the detector", + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) + + # Manually set logprobs to True to calculate score later on + # NOTE: this is supposed to override if user has set logprobs to False + # or left logprobs as the default False + chat_completion_request.logprobs = True + # NOTE: We need top_logprobs to be enabled to calculate score appropriately + # We override this and not allow configuration at this point. In future, we may + # want to expose this configurable to certain range. + chat_completion_request.top_logprobs = 5 + + logger.debug("Request to chat completion: %s", chat_completion_request) + + # Call chat completion + chat_response = await self.create_chat_completion( + chat_completion_request, raw_request + ) + logger.debug("Raw chat completion response: %s", chat_response) + if isinstance(chat_response, ErrorResponse): + # Propagate chat completion errors directly + return chat_response + + # Apply output template if it exists + if self.output_template: + chat_response = self.apply_output_template(chat_response) + + # Calculate scores + scores = self.calculate_scores(chat_response) + + return ChatDetectionResponse.from_chat_completion_response( + chat_response, scores, self.DETECTION_TYPE + ) diff --git a/vllm_detector_adapter/generative_detectors/granite_guardian.py b/vllm_detector_adapter/generative_detectors/granite_guardian.py new file mode 100644 index 0000000..8ad8c06 --- /dev/null +++ b/vllm_detector_adapter/generative_detectors/granite_guardian.py @@ -0,0 +1,49 @@ +# Standard +from typing import Union + +# Third Party +from vllm.entrypoints.openai.protocol import ErrorResponse + +# Local +from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase +from vllm_detector_adapter.logging import init_logger +from vllm_detector_adapter.protocol import ChatDetectionRequest + +logger = init_logger(__name__) + + +class GraniteGuardian(ChatCompletionDetectionBase): + + DETECTION_TYPE = "risk" + # User text pattern in task template + USER_TEXT_PATTERN = "user_text" + + # Model specific tokens + SAFE_TOKEN = "No" + UNSAFE_TOKEN = "Yes" + + def preprocess( + self, request: ChatDetectionRequest + ) -> Union[ChatDetectionRequest, ErrorResponse]: + """Granite guardian specific parameter updates for risk name and risk definition""" + # Validation that one of the 'defined' risks is requested will be + # done through the chat template on each request. Errors will + # be propagated for chat completion separately + guardian_config = {} + if not request.detector_params: + return request + + if risk_name := request.detector_params.pop("risk_name", None): + guardian_config["risk_name"] = risk_name + if risk_definition := request.detector_params.pop("risk_definition", None): + guardian_config["risk_definition"] = risk_definition + if guardian_config: + logger.debug("guardian_config {} provided for request", guardian_config) + # Move the risk name and/or risk definition to chat_template_kwargs + # to be propagated to tokenizer.apply_chat_template during + # chat completion + request.detector_params["chat_template_kwargs"] = { + "guardian_config": guardian_config + } + + return request diff --git a/vllm_detector_adapter/generative_detectors/llama_guard.py b/vllm_detector_adapter/generative_detectors/llama_guard.py new file mode 100644 index 0000000..9e8f796 --- /dev/null +++ b/vllm_detector_adapter/generative_detectors/llama_guard.py @@ -0,0 +1,17 @@ +# Local +from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase +from vllm_detector_adapter.logging import init_logger + +logger = init_logger(__name__) + + +class LlamaGuard(ChatCompletionDetectionBase): + + DETECTION_TYPE = "risk" + + # Model specific tokens + SAFE_TOKEN = "safe" + UNSAFE_TOKEN = "unsafe" + + # NOTE: More intelligent template parsing can be done here, potentially + # as a regex template for safe vs. unsafe and the 'unsafe' category diff --git a/vllm_detector_adapter/logging.py b/vllm_detector_adapter/logging.py new file mode 100644 index 0000000..fef02e9 --- /dev/null +++ b/vllm_detector_adapter/logging.py @@ -0,0 +1,30 @@ +"""Configure logger for vllm_detector_adapter +vllm logger configures app / logger for `vllm`. +Since vllm_detector_adapter is built on top of vllm, +we need to add this app to the logger configuration for the +logs to show up. Otherwise, it will get filtered out. +vllm logger ref: https://github.com/vllm-project/vllm/blob/04de9057ab8099291e66ad876e78693c7c2f2ce5/vllm/logger.py#L81 +""" + +# Standard +import logging + +# Third Party +from vllm.logger import init_logger # noqa: F401 +from vllm.logger import DEFAULT_LOGGING_CONFIG + +config = {**DEFAULT_LOGGING_CONFIG} + +config["formatters"]["vllm_detector_adapter"] = DEFAULT_LOGGING_CONFIG["formatters"][ + "vllm" +] + +handler_config = DEFAULT_LOGGING_CONFIG["handlers"]["vllm"] +handler_config["formatter"] = "vllm_detector_adapter" +config["handlers"]["vllm_detector_adapter"] = handler_config + +logger_config = DEFAULT_LOGGING_CONFIG["loggers"]["vllm"] +logger_config["handlers"] = ["vllm_detector_adapter"] +config["loggers"]["vllm_detector_adapter"] = logger_config + +logging.config.dictConfig(config) diff --git a/vllm_detector_adapter/package_validate.py b/vllm_detector_adapter/package_validate.py new file mode 100644 index 0000000..f891d8a --- /dev/null +++ b/vllm_detector_adapter/package_validate.py @@ -0,0 +1,43 @@ +# Standard +from importlib import metadata +from importlib.metadata import distribution + +# Third Party +from pip._vendor.packaging.requirements import Requirement + +PACKAGE_NAME = "vllm-detector-adapter" +VLLM_PACKAGE_NAME = "vllm" + + +def verify_vllm_compatibility(allow_prereleases=True): + vllm_installed_version = metadata.version(VLLM_PACKAGE_NAME) + + vllm_required_version = "" + + vllm_adpt_dist = distribution(PACKAGE_NAME) + + # Extract vllm specified requirement from vllm-detector-adapter + for package_info in vllm_adpt_dist.requires: + requirement = Requirement(package_info) + if requirement.name == VLLM_PACKAGE_NAME: + vllm_required_version = requirement.specifier + + vllm_required_version.prereleases = allow_prereleases + + # Check if installed vllm version is compatible with PACKAGE_NAME requirements + if vllm_installed_version in vllm_required_version: + print("vLLM versions are compatible!") + else: + print( + """ + Incompatible vLLM version installed. Please fix by either updating vLLM image or updating {} \n\n + Installed vLLM version: {} + {} required vLLM version range: {} + """.format( + PACKAGE_NAME, + vllm_installed_version, + PACKAGE_NAME, + str(vllm_required_version), + ) + ) + exit(1) diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py new file mode 100644 index 0000000..c952b21 --- /dev/null +++ b/vllm_detector_adapter/protocol.py @@ -0,0 +1,127 @@ +# Standard +from http import HTTPStatus +from typing import Dict, List, Optional + +# Third Party +from pydantic import BaseModel, Field, RootModel, ValidationError +from typing_extensions import Required, TypedDict +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ErrorResponse, +) + +##### [FMS] Detection API types +# NOTE: This currently works with the /chat detection endpoint + +######## Contents Detection types + + +class ContentsDetectionRequest(BaseModel): + contents: List[str] = Field( + examples=[ + "Hi is this conversation guarded", + "Yes, it is", + ] + ) + + +class ContentsDetectionResponseObject(BaseModel): + start: int = Field(examples=[0]) + end: int = Field(examples=[10]) + text: str = Field(examples=["text"]) + detection: str = Field(examples=["positive"]) + detection_type: str = Field(examples=["simple_example"]) + score: float = Field(examples=[0.5]) + + +######## Chat Detection types + + +class DetectionChatMessageParam(TypedDict): + # The role of the message's author + role: Required[str] + + # The contents of the message + content: str + + +class ChatDetectionRequest(BaseModel): + # Chat messages + messages: List[DetectionChatMessageParam] = Field( + examples=[ + DetectionChatMessageParam( + role="user", content="Hi is this conversation guarded" + ), + DetectionChatMessageParam(role="assistant", content="Yes, it is"), + ] + ) + + # Parameters to pass through to chat completions, optional + detector_params: Optional[Dict] = {} + + def to_chat_completion_request(self, model_name: str): + """Function to convert [fms] chat detection request to openai chat completion request""" + messages = [ + {"role": message["role"], "content": message["content"]} + for message in self.messages + ] + + # Try to pass all detector_params through as additional parameters to chat completions. + # This will error if extra unallowed parameters are included. We do not try to provide + # validation or changing of parameters here to not be dependent on chat completion API + # changes + try: + return ChatCompletionRequest( + messages=messages, + # NOTE: below is temporary + model=model_name, + **self.detector_params, + ) + except ValidationError as e: + return ErrorResponse( + message=repr(e.errors()[0]), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) + + +class ChatDetectionResponseObject(BaseModel): + detection: str = Field(examples=["positive"]) + detection_type: str = Field(examples=["simple_example"]) + score: float = Field(examples=[0.5]) + + +class ChatDetectionResponse(RootModel): + # The root attribute is used here so that the response will appear + # as a list instead of a list nested under a key + root: List[ChatDetectionResponseObject] + + @staticmethod + def from_chat_completion_response( + response: ChatCompletionResponse, scores: List[float], detection_type: str + ): + """Function to convert openai chat completion response to [fms] chat detection response""" + detection_responses = [] + for i, choice in enumerate(response.choices): + content = choice.message.content + if content and isinstance(content, str): + response_object = ChatDetectionResponseObject( + detection_type=detection_type, + detection=content.strip(), + score=scores[i], + ).model_dump() + detection_responses.append(response_object) + else: + # This case should be unlikely but we handle it since a detection + # can't be returned without the content + # A partial response could be considered in the future + # but that would likely not look like the current ErrorResponse + return ErrorResponse( + message=f"Choice {i} from chat completion does not have content. \ + Consider updating input and/or parameters for detections.", + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST.value, + ) + + return ChatDetectionResponse(root=detection_responses) diff --git a/vllm_detector_adapter/start_with_tgis_adapter.py b/vllm_detector_adapter/start_with_tgis_adapter.py new file mode 100644 index 0000000..ab31e23 --- /dev/null +++ b/vllm_detector_adapter/start_with_tgis_adapter.py @@ -0,0 +1,174 @@ +"""This module bridges the gap between tgis adapter and vllm and allows us to run +vllm-detector-adapter + tgis-adater grpc server together + +Most of this file is taken from: https://github.com/opendatahub-io/vllm-tgis-adapter/blob/main/src/vllm_tgis_adapter/__main__.py +""" +# Future +from __future__ import annotations + +# Standard +from concurrent.futures import FIRST_COMPLETED +from typing import TYPE_CHECKING +import argparse +import asyncio +import contextlib +import importlib.util +import os +import traceback + +# Third Party +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai import api_server +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils import FlexibleArgumentParser +import uvloop + +if TYPE_CHECKING: + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.engine.protocol import AsyncEngineClient + +# Local +from vllm_detector_adapter.api_server import ( + add_chat_detection_params, + init_app_state_with_detectors, +) +from vllm_detector_adapter.logging import init_logger + +TIMEOUT_KEEP_ALIVE = 5 +TGIS_ADAPTER_LIBRARY_NAME = "vllm-tgis-adapter" + +logger = init_logger("vllm_detector_adapter.start_with_tgis_adapter") + + +# Check if tgis_adapter is installed or not +if not importlib.util.find_spec(TGIS_ADAPTER_LIBRARY_NAME): + logger.error("{} library not installed".format(TGIS_ADAPTER_LIBRARY_NAME)) + exit(1) +else: + # Third Party + from vllm_tgis_adapter.grpc import run_grpc_server + from vllm_tgis_adapter.tgis_utils.args import ( + EnvVarArgumentParser, + add_tgis_args, + postprocess_tgis_args, + ) + from vllm_tgis_adapter.utils import check_for_failed_tasks, write_termination_log + +# Note: this function references run_http_server in +# vllm-tgis-adapter directly. https://github.com/opendatahub-io/vllm-tgis-adapter/blob/1d62372b78f24e156d6a748f30bb7273d8364532/src/vllm_tgis_adapter/http.py#L20C1-L48C24 +async def run_http_server( + args: argparse.Namespace, + engine: AsyncLLMEngine | AsyncEngineClient, + **uvicorn_kwargs, # noqa: ANN003 +) -> None: + # modified copy of vllm.entrypoints.openai.api_server.run_server that + # allows passing of the engine + + app = api_server.build_app(args) + model_config = await engine.get_model_config() + init_app_state_with_detectors(engine, model_config, app.state, args) + + serve_kwargs = { + "host": args.host, + "port": args.port, + "log_level": args.uvicorn_log_level, + "timeout_keep_alive": TIMEOUT_KEEP_ALIVE, + "ssl_keyfile": args.ssl_keyfile, + "ssl_certfile": args.ssl_certfile, + "ssl_ca_certs": args.ssl_ca_certs, + "ssl_cert_reqs": args.ssl_cert_reqs, + } + serve_kwargs.update(uvicorn_kwargs) + + shutdown_coro = await serve_http(app, **serve_kwargs) + + # launcher.serve_http returns a shutdown coroutine to await + # (The double await is intentional) + await shutdown_coro + + +async def start_servers(args: argparse.Namespace) -> None: + """This function starts both http_server (openai + vllm-detector-adapter) and + gRPC server (tgis-adapter) + """ + loop = asyncio.get_running_loop() + + tasks: list[asyncio.Task] = [] + async with api_server.build_async_engine_client(args) as engine: + http_server_task = loop.create_task( + run_http_server(args, engine), + name="http_server", + ) + # The http server task will catch interrupt signals for us + tasks.append(http_server_task) + + grpc_server_task = loop.create_task( + run_grpc_server(args, engine), + name="grpc_server", + ) + tasks.append(grpc_server_task) + + runtime_error = None + with contextlib.suppress(asyncio.CancelledError): + # Both server tasks will exit normally on shutdown, so we await + # FIRST_COMPLETED to catch either one shutting down. + await asyncio.wait(tasks, return_when=FIRST_COMPLETED) + if engine and engine.errored and not engine.is_running: + # both servers shut down when an engine error + # is detected, with task done and exception handled + # here we just notify of that error and let servers be + runtime_error = RuntimeError( + "AsyncEngineClient error detected, this may be caused by an \ + unexpected error in serving a request. \ + Please check the logs for more details." + ) + + failed_task = check_for_failed_tasks(tasks) + + # Once either server shuts down, cancel the other + for task in tasks: + task.cancel() + + # Final wait for both servers to finish + await asyncio.wait(tasks) + + # Raise originally-failed task if applicable + if failed_task: + name, coro_name = failed_task.get_name(), failed_task.get_coro().__name__ + exception = failed_task.exception() + raise RuntimeError(f"Failed task={name} ({coro_name})") from exception + + if runtime_error: + raise runtime_error + + +def run_and_catch_termination_cause( + loop: asyncio.AbstractEventLoop, task: asyncio.Task +) -> None: + try: + loop.run_until_complete(task) + except Exception: + # Report the first exception as cause of termination + msg = traceback.format_exc() + write_termination_log( + msg, os.getenv("TERMINATION_LOG_DIR", "/dev/termination-log") + ) + raise + + +if __name__ == "__main__": + parser = FlexibleArgumentParser("vLLM TGIS GRPC + OpenAI REST api server") + # convert to our custom env var arg parser + parser = EnvVarArgumentParser(parser=make_arg_parser(parser)) + parser = add_tgis_args(parser) + parser = add_chat_detection_params(parser) + args = postprocess_tgis_args(parser.parse_args()) + assert args is not None + + # logger.info("vLLM version %s", f"{vllm.__version__}") + logger.info("args: %s", args) + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + loop = asyncio.new_event_loop() + task = loop.create_task(start_servers(args)) + run_and_catch_termination_cause(loop, task)