Skip to content

Commit

Permalink
Merge pull request #255 from runpod/fix-pydantic
Browse files Browse the repository at this point in the history
Fix pydantic
  • Loading branch information
justinmerrell authored Dec 14, 2023
2 parents 3059ed0 + 2bb5cfa commit c22b523
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Change Log

## Release 1.4.2 (12/14/23)

### Fixed

- Added defaults for optional parameters in `rp_fastapi` to be compatible with pydantic.

## Release 1.4.1 (12/13/23)

### Added
Expand Down
43 changes: 31 additions & 12 deletions runpod/serverless/modules/rp_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import os
import uuid
from dataclasses import dataclass
from typing import Union, Optional, Dict, Any

import uvicorn
from fastapi import FastAPI, APIRouter
from fastapi.encoders import jsonable_encoder
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from .rp_handler import is_generator
from .rp_job import run_job, run_job_generator
Expand Down Expand Up @@ -39,40 +39,45 @@


# ------------------------------- Input Objects ------------------------------ #
class Job(BaseModel):
@dataclass
class Job:
''' Represents a job. '''
id: str
input: Union[dict, list, str, int, float, bool]


class TestJob(BaseModel):
@dataclass
class TestJob:
''' Represents a test job.
input can be any type of data.
'''
id: Optional[str]
input: Optional[Union[dict, list, str, int, float, bool]]
id: Optional[str] = None
input: Optional[Union[dict, list, str, int, float, bool]] = None


class DefaultInput(BaseModel):
@dataclass
class DefaultInput:
""" Represents a test input. """
input: Dict[str, Any]


# ------------------------------ Output Objects ------------------------------ #
class JobOutput(BaseModel):
@dataclass
class JobOutput:
''' Represents the output of a job. '''
id: str
status: str
output: Optional[Union[dict, list, str, int, float, bool]]
error: Optional[str]
output: Optional[Union[dict, list, str, int, float, bool]] = None
error: Optional[str] = None


class StreamOutput(BaseModel):
@dataclass
class StreamOutput:
""" Stream representation of a job. """
id: str
status: str = "IN_PROGRESS"
stream: Optional[Union[dict, list, str, int, float, bool]]
error: Optional[str]
stream: Optional[Union[dict, list, str, int, float, bool]] = None
error: Optional[str] = None


# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -191,6 +196,13 @@ async def _sim_runsync(self, job_input: DefaultInput) -> JobOutput:
else:
job_output = await run_job(self.config["handler"], job.__dict__)

if job_output.get('error', None):
return jsonable_encoder({
"id": job.id,
"status": "FAILED",
"error": job_output['error']
})

return jsonable_encoder({
"id": job.id,
"status": "COMPLETED",
Expand Down Expand Up @@ -253,6 +265,13 @@ async def _sim_status(self, job_id: str) -> JobOutput:

job_list.remove_job(job.id)

if job_output.get('error', None):
return jsonable_encoder({
"id": job_id,
"status": "FAILED",
"error": job_output['error']
})

return jsonable_encoder({
"id": job_id,
"status": "COMPLETED",
Expand Down
16 changes: 16 additions & 0 deletions tests/test_serverless/test_modules/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def setUp(self) -> None:
self.handler = Mock()
self.handler.return_value = {"result": "success"}

self.error_handler = Mock()
self.error_handler.side_effect = Exception("test error")

def test_start_serverless_with_realtime(self):
'''
Tests the start_serverless() method with the realtime option.
Expand Down Expand Up @@ -139,6 +142,12 @@ def generator_handler(job):
"output": [{"result": "success"}]
}

# Test with error handler
error_worker_api = rp_fastapi.WorkerAPI({"handler": self.error_handler})
error_runsync_return = asyncio.run(
error_worker_api._sim_runsync(default_input_object))
assert "error" in error_runsync_return

loop.close()

@pytest.mark.asyncio
Expand Down Expand Up @@ -243,4 +252,11 @@ def generator_handler(job):
"output": [{"result": "success"}]
}

# Test with error handler
error_worker_api = rp_fastapi.WorkerAPI({"handler": self.error_handler})
asyncio.run(error_worker_api._sim_run(default_input_object))
error_status_return = asyncio.run(
error_worker_api._sim_status("test-123"))
assert "error" in error_status_return

loop.close()

0 comments on commit c22b523

Please sign in to comment.