-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add image generation [DCH-127] (#6)
* feat: Add a proxy for cohere-embed-english-v3 * feat: Add a Docker release * feaet: Add CI testing * fix: Switch directory in testing * fix: Attempt to fix directory management * fix: Attempt to fix directory in CI * feat: Add mypy/ruff dependencies * fix: Formatting issues * fix formatting * fix: ruff and env variable fixes * feat: Add image generation
- Loading branch information
1 parent
db66052
commit 8280645
Showing
10 changed files
with
219 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
[mypy] | ||
files = proxy | ||
warn_unused_configs = True | ||
namespace_packages = True | ||
ignore_missing_imports = True | ||
strict = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Images router.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Business logic of the images router.""" | ||
|
||
import json | ||
|
||
import boto3 | ||
import fastapi | ||
from fastapi import status | ||
|
||
from app.core import config | ||
from app.routers.images import schemas | ||
|
||
settings = config.get_settings() | ||
|
||
|
||
def post_image_generation( | ||
payload: schemas.PostImageGenerationRequest, | ||
) -> schemas.PostImageGenerationResponse: | ||
"""Performs the image generation business logic. | ||
Args: | ||
payload: The request body. | ||
Returns: | ||
The images, following OpenAI's API specification. | ||
""" | ||
if payload.provider == "aws": | ||
images = [_aws_image_generation(payload) for _ in range(payload.n)] | ||
else: | ||
raise fastapi.HTTPException( | ||
status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
"Unknown provider.", | ||
) | ||
return schemas.PostImageGenerationResponse(data=images) | ||
|
||
|
||
def _aws_image_generation( | ||
payload: schemas.PostImageGenerationRequest, | ||
) -> schemas.PostImageGenerationImage: | ||
"""Runs image generation on AWS. | ||
Args: | ||
payload: The request body. | ||
Returns: | ||
A single image and its (revised) prompt. For the curent models, revised and | ||
input prompts are identical. | ||
""" | ||
body = json.dumps( | ||
{ | ||
"prompt": payload.prompt, | ||
"mode": "text-to-image", | ||
"aspect_ratio": "1:1", | ||
"output_format": "jpeg", | ||
}, | ||
) | ||
|
||
bedrock = boto3.client( | ||
service_name="bedrock-runtime", | ||
region_name=settings.AWS_REGION, | ||
aws_access_key_id=settings.AWS_ACCESS_KEY.get_secret_value(), | ||
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY.get_secret_value(), | ||
) | ||
|
||
response = bedrock.invoke_model( | ||
body=body, | ||
modelId=payload.model_name, | ||
accept="application/json", | ||
contentType="application/json", | ||
) | ||
|
||
response_body = json.loads(response.get("body").read()) | ||
return schemas.PostImageGenerationImage( | ||
b64_json=response_body["images"][0], | ||
revised_prompt=payload.prompt, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
"""Type definitions used across multiple files in the embeddings router.""" | ||
|
||
import time | ||
from typing import Literal | ||
|
||
import pydantic | ||
|
||
IMAGE_GENERATION_MODELS = Literal["aws/stability.sd3-large-v1:0"] | ||
|
||
|
||
class PostImageGenerationRequest(pydantic.BaseModel): | ||
"""Post request for image generation.""" | ||
|
||
model: IMAGE_GENERATION_MODELS | ||
prompt: str | ||
n: Literal[1] | ||
size: str | ||
response_format: Literal["b64_json"] | ||
|
||
@property | ||
def provider(self) -> str: | ||
"""The model provider.""" | ||
return self.model.split("/")[0] | ||
|
||
@property | ||
def model_name(self) -> str: | ||
"""The model name.""" | ||
return self.model.split("/")[1] | ||
|
||
|
||
class PostImageGenerationImage(pydantic.BaseModel): | ||
"""An image object for the image generation response.""" | ||
|
||
b64_json: str | ||
revised_prompt: str | ||
|
||
|
||
class PostImageGenerationResponse(pydantic.BaseModel): | ||
"""The response for image generation.""" | ||
|
||
data: list[PostImageGenerationImage] | ||
created: int = pydantic.Field(default_factory=lambda: int(time.time())) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""Views for the embeddings router.""" | ||
|
||
import fastapi | ||
|
||
from app.core import config | ||
from app.routers.images import controller, schemas | ||
|
||
router = fastapi.APIRouter(prefix="/images") | ||
|
||
logger = config.get_logger() | ||
|
||
|
||
@router.post( | ||
"/generations", | ||
description="Fetches the images generated for a prompt.", | ||
response_description="The requested images.", | ||
) | ||
def post_images( | ||
payload: schemas.PostImageGenerationRequest, | ||
) -> schemas.PostImageGenerationResponse: | ||
"""Gets the embedding of a string. | ||
Args: | ||
payload: The message body, c.f. the model for details. | ||
Returns: | ||
The embedding response in OpenAI's formatting. | ||
""" | ||
logger.debug("Entering post-images endpoint.") | ||
response = controller.post_image_generation(payload) | ||
logger.debug("Exiting post-images endpoint.") | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Endpoint tests for the embeddings endpoints.""" | ||
|
||
import io | ||
import json | ||
from typing import Any | ||
|
||
import pytest | ||
import pytest_mock | ||
from fastapi import status | ||
from fastapi.testclient import TestClient | ||
|
||
from app.core import auth | ||
from app.main import app | ||
|
||
client = TestClient(app) | ||
|
||
app.dependency_overrides[auth.check_api_key] = lambda: None | ||
|
||
|
||
@pytest.fixture | ||
def valid_post_embedding_payload() -> dict[str, Any]: | ||
"""A valid payload for the POStT /v1/embeddings endpoint.""" | ||
return { | ||
"model": "aws/stability.sd3-large-v1:0", | ||
"prompt": "example text", | ||
"n": 1, | ||
"size": "512x512", | ||
"response_format": "b64_json", | ||
} | ||
|
||
|
||
def test_post_image_generation_endpoint( | ||
valid_post_embedding_payload: dict[str, Any], | ||
mocker: pytest_mock.MockerFixture, | ||
) -> None: | ||
"""Tests the happy-path of the post-embedding endpoint.""" | ||
mock_boto_client = mocker.patch("app.routers.images.controller.boto3.client") | ||
mock_boto_client.return_value.invoke_model.return_value = { | ||
"body": io.StringIO( | ||
json.dumps( | ||
{ | ||
"images": ["abc"], | ||
}, | ||
), | ||
), | ||
} | ||
|
||
response = client.post("/v1/images/generations", json=valid_post_embedding_payload) | ||
response_data = response.json() | ||
|
||
assert response.status_code == status.HTTP_200_OK | ||
assert response_data["data"][0]["b64_json"] == "abc" | ||
assert ( | ||
response_data["data"][0]["revised_prompt"] | ||
== valid_post_embedding_payload["prompt"] | ||
) |