-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from HumanCompatibleAI/fastapi_example
New example that uses fastapi, pydantic models for the api schema
- Loading branch information
Showing
16 changed files
with
528 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Basic ranker example (nltk) | ||
|
||
This is a toy example that ranks a set of content items based on their sentiment using nltk. | ||
|
||
## Data models | ||
|
||
This example uses [pydantic](https://docs.pydantic.dev/) to validate the schema for requests and responses. | ||
|
||
## Setting up your environment | ||
|
||
1. Create a virtual environment using your preferred method | ||
2. `pip install -r requirements.txt` | ||
|
||
## Running tests | ||
|
||
Just run `pytest` | ||
|
||
## Running the service in development | ||
|
||
```bash | ||
uvicorn ranking_server:app --reload | ||
``` | ||
|
||
This will spin up a server on `http://127.0.0.1:8000` | ||
|
||
## Executing your server outside of a unit test | ||
|
||
You can start the server and then run `caller.py` to send it data, or you can use the interface provided by FastAPI | ||
|
||
## Automatically-generated api docs | ||
|
||
With a running server, visit `http://127.0.0.1:8000/docs`. You can send requests from there too. | ||
|
||
## Running the service in production | ||
|
||
```bash | ||
uvicorn ranking_server:app --host 0.0.0.0 --port 5000 | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import time | ||
from pprint import pprint | ||
|
||
import requests | ||
|
||
import sample_data | ||
|
||
# This is a simple script that sends a POST request to the API and prints the response.
# Your server should be running on localhost:8000

# Give a freshly started FastAPI (uvicorn) server a moment to begin accepting connections
time.sleep(2)

# Send POST request to the API. The timeout keeps this script from hanging
# indefinitely if the server accepts the connection but never responds.
response = requests.post(
    "http://localhost:8000/rank",
    json=sample_data.BASIC_EXAMPLE,
    timeout=10,
)

# Check if the request was successful (status code 200)
if response.status_code == 200:
    try:
        # Attempt to parse the JSON response
        json_response = response.json()
    except requests.exceptions.JSONDecodeError:
        print("Failed to parse JSON response. Response may be empty.")
    else:
        # Only reached when parsing succeeded
        pprint(json_response)
else:
    print(f"Request failed with status code: {response.status_code}")
    print(response.text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
import sys | ||
import inspect | ||
|
||
parentdir = os.path.dirname( # make it possible to import from ../ in a reliable way | ||
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) | ||
) | ||
sys.path.insert(0, parentdir) | ||
|
||
from fastapi import FastAPI | ||
import nltk | ||
from nltk.sentiment.vader import SentimentIntensityAnalyzer | ||
|
||
from models.request import RankingRequest | ||
from models.response import RankingResponse | ||
from fastapi_nltk.sample_data import NEW_POSTS | ||
|
||
# Fetch the VADER lexicon that SentimentIntensityAnalyzer relies on. This runs
# at import time, i.e. whenever the server module (or a test importing it) loads.
nltk.download("vader_lexicon")

# Single module-level analyzer, shared by every call to the /rank handler below.
analyzer = SentimentIntensityAnalyzer()

# The ASGI application object; this is what `uvicorn ranking_server:app` serves.
app = FastAPI(
    title="Prosocial Ranking Challenge nltk example",
    description="Ranks input by sentiment using nltk's VADER sentiment analysis.",
    version="0.1.0",
)
|
||
|
||
@app.post("/rank")
def rank(ranking_request: RankingRequest) -> RankingResponse:
    """Rank the submitted items from most positive to most negative sentiment.

    Every item's ``text`` is scored with the module-level VADER ``analyzer``
    and the items are ordered by their ``compound`` score, highest first. The
    first entry of ``NEW_POSTS`` (not part of the candidate set) is then placed
    at the top of the ranking and described under ``new_items``.

    Args:
        ranking_request: The parsed request; only ``items[*].id`` and
            ``items[*].text`` are read here.

    Returns:
        A dict with ``ranked_ids`` (the inserted post's id followed by the
        request's item ids in ranked order) and ``new_items`` (the ``id`` and
        ``url`` of the inserted post).
    """
    # sorted() evaluates the key once per item and is stable, so items with
    # equal compound scores keep the order they had in the request.
    by_sentiment = sorted(
        ranking_request.items,
        key=lambda item: analyzer.polarity_scores(item.text)["compound"],
        reverse=True,
    )

    # Add a new post (not part of the candidate set) to the top of the result
    new_post = NEW_POSTS[0]
    ranked_ids = [new_post["id"], *(item.id for item in by_sentiment)]

    return {
        "ranked_ids": ranked_ids,
        "new_items": [
            {
                "id": new_post["id"],
                "url": new_post["url"],
            }
        ],
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import json | ||
import os | ||
import sys | ||
import inspect | ||
|
||
parentdir = os.path.dirname( # make it possible to import from ../ in a reliable way | ||
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) | ||
) | ||
sys.path.insert(0, parentdir) | ||
|
||
import pytest | ||
|
||
from fastapi.testclient import TestClient | ||
|
||
from fastapi_nltk import ranking_server | ||
from fastapi_nltk import sample_data | ||
|
||
|
||
@pytest.fixture
def app():
    """Provide the FastAPI application object defined in ``ranking_server``."""
    yield ranking_server.app
|
||
|
||
@pytest.fixture
def client(app):
    """Provide a ``TestClient`` wrapping the ``app`` fixture for issuing requests."""
    test_client = TestClient(app)
    return test_client
|
||
|
||
def test_rank(client):
    """POST the basic example to /rank and check the ranking and the inserted post."""
    # Send POST request to the API
    response = client.post("/rank", json=sample_data.BASIC_EXAMPLE)

    # Check if the request was successful (status code 200). The raw body goes
    # into the assertion message so a failure is diagnosable even when the
    # error response is not valid JSON.
    assert response.status_code == 200, (
        f"Request failed with status code: {response.status_code}\n{response.text}"
    )

    result = response.json()

    # Check if the response contains the expected ids, in the expected order:
    # the inserted post first, then the request's items by descending sentiment
    assert result["ranked_ids"] == [
        "571775f3-2564-4cf5-b01c-f4cb6bab461b",
        "s5ad13266-8abk4-5219-kre5-2811022l7e43dv",
        "a4c08177-8db2-4507-acc1-1298220be98d",
        "de83fc78-d648-444e-b20d-853bf05e4f0e",
    ]

    # check for inserted posts
    assert result["new_items"] == [
        {
            "id": "571775f3-2564-4cf5-b01c-f4cb6bab461b",
            "url": "https://reddit.com/r/PRCExample/comments/1f33ead/example_to_insert",
        }
    ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
fastapi>=0.110.0 | ||
nltk | ||
pytest | ||
requests | ||
httpx | ||
uvicorn[standard] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Sample data containing multiple text items

# Example request body for the /rank endpoint: one session plus three reddit
# items. It is posted by the caller script and by the ranking test, and the
# test asserts on these exact ids, so keep them in sync if you edit them.
BASIC_EXAMPLE = {
    "session": {
        "user_id": "193a9e01-8849-4e1f-a42a-a859fa7f2ad3",
        "user_name_hash": "6511c5688bbb87798128695a283411a26da532df06e6e931a53416e379ddda0e",
        "platform": "reddit",
        "current_time": "2024-01-20 18:41:20",
    },
    "items": [
        # A post (the only item with a "title")
        {
            "id": "de83fc78-d648-444e-b20d-853bf05e4f0e",
            "title": "this is the post title, available only on reddit",
            "text": "this is the worst thing I have ever seen!",
            "author_name_hash": "60b46b7370f80735a06b7aa8c4eb6bd588440816b086d5ef7355cf202a118305",
            "type": "post",
            "created_at": "2023-12-06 17:02:11",
            "engagements": {"upvote": 34, "downvote": 27, "comment": 20, "award": 0},
        },
        # A comment whose post_id is the post above; its parent_id is empty
        {
            "id": "s5ad13266-8abk4-5219-kre5-2811022l7e43dv",
            "post_id": "de83fc78-d648-444e-b20d-853bf05e4f0e",
            "parent_id": "",
            "text": "this is amazing!",
            "author_name_hash": "60b46b7370f80735a06b7aa8c4eb6bd588440816b086d5ef7355cf202a118305",
            "type": "comment",
            "created_at": "2023-12-08 11:32:12",
            "engagements": {"upvote": 15, "downvote": 2, "comment": 22, "award": 2},
        },
        # A comment on the same post whose parent_id is the comment above
        {
            "id": "a4c08177-8db2-4507-acc1-1298220be98d",
            "post_id": "de83fc78-d648-444e-b20d-853bf05e4f0e",
            "parent_id": "s5ad13266-8abk4-5219-kre5-2811022l7e43dv",
            "text": "this thing is ok.",
            "author_name_hash": "60b46b7370f80735a06b7aa8c4eb6bd588440816b086d5ef7355cf202a118305",
            "type": "comment",
            "created_at": "2023-12-08 11:35:00",
            "engagements": {"upvote": 3, "downvote": 5, "comment": 10, "award": 0},
        },
    ],
}
|
||
# some new posts that can be added to the response
# Each entry has the "id" and "url" keys that the ranking server copies into
# the response's "new_items"; the server currently inserts only the first entry.
NEW_POSTS = [
    {
        "id": "571775f3-2564-4cf5-b01c-f4cb6bab461b",
        "url": "https://reddit.com/r/PRCExample/comments/1f33ead/example_to_insert",
    },
    {
        "id": "1fcbb164-f81f-4532-b068-2561941d0f63",
        "url": "https://reddit.com/r/PRCExample/comments/ef56a23/another_example_to_insert",
    },
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# pydantic models for the PRC API schema | ||
|
||
You can use these models in your Python code, both to generate valid data, and to parse incoming data. | ||
|
||
Using the models ensures that your data has been at least somewhat validated. If the schema changes and your code needs an update, you're more likely to be able to tell right away. | ||
|
||
## Parsing a request | ||
|
||
### With FastAPI | ||
|
||
If you're using fastapi, you can use the models right in your server: | ||
|
||
```python | ||
from models.request import RankingRequest | ||
from models.response import RankingResponse | ||
|
||
@app.post("/rank") | ||
def rank(ranking_request: RankingRequest) -> RankingResponse: | ||
... | ||
# You can return a RankingResponse here, or a dict with the correct keys and | ||
# pydantic will figure it out. | ||
``` | ||
|
||
If you specify `RankingResponse` as your return type, you will get validation of your response for free.
|
||
For a complete example, check out `../fastapi_nltk/` | ||
|
||
### Otherwise | ||
|
||
If you'd like to parse a request directly, here is how: | ||
|
||
```python | ||
from models.request import RankingRequest | ||
|
||
loaded_request = RankingRequest.model_validate_json(json_data) | ||
``` | ||
|
||
## Generating fake data | ||
|
||
There is a fake data generator in `fake.py`. If you run it directly, it prints a sample request and a matching response. You can also import it and call `fake_request()` or `fake_response()`. Take a look at the test for a usage example.
|
||
## More | ||
|
||
[The pydantic docs](https://docs.pydantic.dev/latest/) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import hashlib | ||
import inspect | ||
import os | ||
from random import randint | ||
import sys | ||
import time | ||
from uuid import uuid4 | ||
|
||
parentdir = os.path.dirname( # make it possible to import from ../ in a reliable way | ||
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) | ||
) | ||
sys.path.insert(0, parentdir) | ||
|
||
from faker import Faker | ||
|
||
fake = Faker(locale="la") # remove locale to get rid of the fake latin | ||
|
||
from models.request import ContentItem, RankingRequest, Session | ||
from models.response import RankingResponse | ||
|
||
def fake_request(n_items=1):
    """Build a ``RankingRequest`` with a random reddit session and ``n_items`` fake items."""
    user_id = str(uuid4())
    name_digest = hashlib.sha256(fake.name().encode()).hexdigest()
    session = Session(
        user_id=user_id,
        user_name_hash=name_digest,
        platform="reddit",
        current_time=time.time(),
    )

    items = []
    for _ in range(n_items):
        items.append(fake_item())

    return RankingRequest(session=session, items=items)
|
||
def fake_item():
    """Build a ``ContentItem`` of type "post" with random text, author hash and engagement counts."""
    item_id = str(uuid4())
    body = fake.text()
    author_digest = hashlib.sha256(fake.name().encode()).hexdigest()
    created = time.time()

    # One random count in [0, 50] per engagement kind, drawn in this order
    engagement_kinds = ("upvote", "downvote", "comment", "award")
    engagements = {kind: randint(0, 50) for kind in engagement_kinds}

    return ContentItem(
        id=item_id,
        text=body,
        author_name_hash=author_digest,
        type="post",
        created_at=created,
        engagements=engagements,
    )
|
||
def fake_response(ids, n_new_items=1):
    """Build a ``RankingResponse`` ranking ``ids`` followed by ``n_new_items`` freshly generated items."""
    additions = []
    for _ in range(n_new_items):
        additions.append(fake_new_item())

    # The given ids keep their order; the ids of the new items are appended after them
    ranked = [*ids, *(entry["id"] for entry in additions)]

    return RankingResponse(ranked_ids=ranked, new_items=additions)
|
||
def fake_new_item():
    """Return a dict with a fresh UUID string under ``id`` and a fake URL under ``url``."""
    new_id = str(uuid4())
    link = fake.url()
    return {"id": new_id, "url": link}
|
||
# When executed as a script, print one fake request and a response that ranks its items
if __name__ == "__main__":
    sample_request = fake_request(3)
    print("Request:", sample_request.model_dump_json(indent=2), sep="\n")

    # use ids from request
    request_ids = [item.id for item in sample_request.items]
    sample_response = fake_response(request_ids, 2)
    print("\nResponse:", sample_response.model_dump_json(indent=2), sep="\n")
Oops, something went wrong.