Skip to content

Commit

Permalink
feat: Add image generation [DCH-127] (#6)
Browse files Browse the repository at this point in the history
* feat: Add a proxy for cohere-embed-english-v3

* feat: Add a Docker release

* feat: Add CI testing

* fix: Switch directory in testing

* fix: Attempt to fix directory management

* fix: Attempt to fix directory in CI

* feat: Add mypy/ruff dependencies

* fix: Formatting issues

* fix formatting

* fix: ruff and env variable fixes

* feat: Add image generation
  • Loading branch information
ReinderVosDeWael authored Nov 7, 2024
1 parent db66052 commit 8280645
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 6 deletions.
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[mypy]
files = proxy
warn_unused_configs = True
namespace_packages = True
ignore_missing_imports = True
strict = True
2 changes: 1 addition & 1 deletion proxy/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Settings(pydantic_settings.BaseSettings):
AWS_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...)
AWS_SECRET_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...)

LOGGER_VERBOSITY: int = logging.INFO
LOGGER_VERBOSITY: int = logging.DEBUG


@functools.lru_cache
Expand Down
3 changes: 3 additions & 0 deletions proxy/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from app.core import auth, config
from app.routers.embeddings import views as embeddings_views
from app.routers.images import views as images_views

logger = config.get_logger()

Expand All @@ -20,4 +21,6 @@

version_router = fastapi.APIRouter(prefix="/v1")
version_router.include_router(embeddings_views.router)
version_router.include_router(images_views.router)

app.include_router(version_router)
4 changes: 2 additions & 2 deletions proxy/app/routers/embeddings/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def post_embedding(
logger.debug("Running Azure Embedding.")
return _run_aws_embedding(payload)
raise fastapi.HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Unknown model provider.",
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Unknown provider.",
)


Expand Down
4 changes: 1 addition & 3 deletions proxy/app/routers/embeddings/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from app.core import config
from app.routers.embeddings import controller, schemas

router = fastapi.APIRouter(
prefix="/embeddings",
)
router = fastapi.APIRouter(prefix="/embeddings")

logger = config.get_logger()

Expand Down
1 change: 1 addition & 0 deletions proxy/app/routers/images/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Images router."""
75 changes: 75 additions & 0 deletions proxy/app/routers/images/controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Business logic of the images router."""

import json

import boto3
import fastapi
from fastapi import status

from app.core import config
from app.routers.images import schemas

# Settings are resolved once, when this module is first imported.
settings = config.get_settings()


def post_image_generation(
    payload: schemas.PostImageGenerationRequest,
) -> schemas.PostImageGenerationResponse:
    """Generates the requested images with the provider named in the payload.

    Args:
        payload: The request body.

    Returns:
        The generated images, following OpenAI's API specification.

    Raises:
        fastapi.HTTPException: 500 when the payload's provider is not supported.
    """
    # Guard clause: AWS is the only provider handled here.
    if payload.provider != "aws":
        raise fastapi.HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            "Unknown provider.",
        )

    # One provider call per requested image.
    generated = []
    for _ in range(payload.n):
        generated.append(_aws_image_generation(payload))
    return schemas.PostImageGenerationResponse(data=generated)


def _aws_image_generation(
    payload: schemas.PostImageGenerationRequest,
) -> schemas.PostImageGenerationImage:
    """Runs image generation on AWS.

    Args:
        payload: The request body.

    Returns:
        A single image and its (revised) prompt. For the current models, revised and
        input prompts are identical.

    Raises:
        fastapi.HTTPException: 500 when the model response carries no image.
    """
    # The aspect ratio and output format are fixed; only the prompt is taken from
    # the request.
    body = json.dumps(
        {
            "prompt": payload.prompt,
            "mode": "text-to-image",
            "aspect_ratio": "1:1",
            "output_format": "jpeg",
        },
    )

    bedrock = boto3.client(
        service_name="bedrock-runtime",
        region_name=settings.AWS_REGION,
        aws_access_key_id=settings.AWS_ACCESS_KEY.get_secret_value(),
        aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY.get_secret_value(),
    )

    response = bedrock.invoke_model(
        body=body,
        modelId=payload.model_name,
        accept="application/json",
        contentType="application/json",
    )

    response_body = json.loads(response["body"].read())

    # A response without a usable first image (e.g. a filtered or failed
    # generation) is reported explicitly instead of surfacing as an opaque
    # KeyError/IndexError or validation error.
    images = response_body.get("images") or []
    if not images or images[0] is None:
        raise fastapi.HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            "Image generation returned no image.",
        )

    return schemas.PostImageGenerationImage(
        b64_json=images[0],
        revised_prompt=payload.prompt,
    )
42 changes: 42 additions & 0 deletions proxy/app/routers/images/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Type definitions used across multiple files in the images router."""

import time
from typing import Literal

import pydantic

# Closed set of accepted model identifiers, written as "<provider>/<model name>".
IMAGE_GENERATION_MODELS = Literal["aws/stability.sd3-large-v1:0"]


class PostImageGenerationRequest(pydantic.BaseModel):
    """Post request for image generation.

    The ``model`` field has the form ``"<provider>/<model name>"``; its two halves
    are exposed through the ``provider`` and ``model_name`` properties.
    """

    model: IMAGE_GENERATION_MODELS
    prompt: str
    n: Literal[1]
    size: str
    response_format: Literal["b64_json"]

    @property
    def provider(self) -> str:
        """The model provider, i.e. the part of ``model`` before the first slash."""
        return self.model.split("/", 1)[0]

    @property
    def model_name(self) -> str:
        """The model name, i.e. everything in ``model`` after the first slash."""
        # Split only once so a model identifier that itself contains a slash is
        # passed on whole rather than truncated at its second segment.
        return self.model.split("/", 1)[1]


class PostImageGenerationImage(pydantic.BaseModel):
    """An image object for the image generation response."""

    # The generated image as a string; presumably base64-encoded, as the field
    # name suggests.
    b64_json: str
    # The prompt reported back alongside the image.
    revised_prompt: str


class PostImageGenerationResponse(pydantic.BaseModel):
    """The response for image generation."""

    # The generated images.
    data: list[PostImageGenerationImage]
    # Unix timestamp in whole seconds; defaults to the time at which the response
    # object is created.
    created: int = pydantic.Field(default_factory=lambda: int(time.time()))
32 changes: 32 additions & 0 deletions proxy/app/routers/images/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Views for the images router."""

import fastapi

from app.core import config
from app.routers.images import controller, schemas

# All routes registered on this router are served under the "/images" prefix.
router = fastapi.APIRouter(prefix="/images")

logger = config.get_logger()


@router.post(
    "/generations",
    description="Fetches the images generated for a prompt.",
    response_description="The requested images.",
)
def post_images(
    payload: schemas.PostImageGenerationRequest,
) -> schemas.PostImageGenerationResponse:
    """Generates images for a prompt.

    Args:
        payload: The message body, c.f. the model for details.

    Returns:
        The image generation response in OpenAI's formatting.
    """
    logger.debug("Entering post-images endpoint.")
    generated = controller.post_image_generation(payload)
    logger.debug("Exiting post-images endpoint.")
    return generated
56 changes: 56 additions & 0 deletions proxy/tests/endpoint/test_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Endpoint tests for the images endpoints."""

import io
import json
from typing import Any

import pytest
import pytest_mock
from fastapi import status
from fastapi.testclient import TestClient

from app.core import auth
from app.main import app

client = TestClient(app)

# Swap the API key check for a no-op so the tests below can call the endpoints
# without supplying credentials.
app.dependency_overrides[auth.check_api_key] = lambda: None


@pytest.fixture
def valid_post_embedding_payload() -> dict[str, Any]:
    """A valid payload for the POST /v1/images/generations endpoint."""
    request_body: dict[str, Any] = {
        "model": "aws/stability.sd3-large-v1:0",
        "prompt": "example text",
        "n": 1,
        "size": "512x512",
        "response_format": "b64_json",
    }
    return request_body


def test_post_image_generation_endpoint(
    valid_post_embedding_payload: dict[str, Any],
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Tests the happy-path of the image generation endpoint."""
    # Stand-in for the provider response: the controller only reads the "body"
    # stream and parses it as JSON.
    provider_body = io.StringIO(json.dumps({"images": ["abc"]}))
    patched_client = mocker.patch("app.routers.images.controller.boto3.client")
    patched_client.return_value.invoke_model.return_value = {"body": provider_body}

    response = client.post("/v1/images/generations", json=valid_post_embedding_payload)
    parsed = response.json()

    assert response.status_code == status.HTTP_200_OK
    first_image = parsed["data"][0]
    assert first_image["b64_json"] == "abc"
    assert first_image["revised_prompt"] == valid_post_embedding_payload["prompt"]

0 comments on commit 8280645

Please sign in to comment.