Skip to content

Commit

Permalink
add receipt cookbook
Browse files Browse the repository at this point in the history
  • Loading branch information
cpfiffer committed Nov 7, 2024
1 parent 5f39ded commit 679aa54
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 0 deletions.
Binary file added docs/cookbook/images/trader-joes-receipt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
283 changes: 283 additions & 0 deletions docs/cookbook/receipt-digitization.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
# Receipt Data Extraction with VLMs

## Setup

You'll need to install the dependencies:

```bash
pip install outlines torch==2.4.0 transformers accelerate pillow rich
```

## Import libraries

Load all the necessary libraries:

```python
# LLM stuff
import outlines
import torch
from transformers import AutoProcessor
from pydantic import BaseModel, Field
from typing import Literal, Optional, List

# Image stuff
from PIL import Image
import requests

# Rich for pretty printing
from rich import print
```

## Choose a model

This example has been tested with `mistral-community/pixtral-12b` ([HF link](https://huggingface.co/mistral-community/pixtral-12b)) and `Qwen/Qwen2-VL-7B-Instruct` ([HF link](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)).

We recommend Qwen-2-VL as we have found it to be more accurate than Pixtral.

If you want to use Qwen-2-VL, you can do the following:

```python
# To use Qwen-2-VL:
from transformers import Qwen2VLForConditionalGeneration
model_name = "Qwen/Qwen2-VL-7B-Instruct"
model_class = Qwen2VLForConditionalGeneration
```

If you want to use Pixtral, you can do the following:

```python
# To use Pixtral:
from transformers import LlavaForConditionalGeneration
model_name="mistral-community/pixtral-12b"
model_class=LlavaForConditionalGeneration
```

## Load the model

Load the model into memory:

```python
model = outlines.models.transformers_vision(
model_name,
model_class=model_class,
model_kwargs={
"device_map": "auto",
"torch_dtype": torch.bfloat16,
},
processor_kwargs={
"device": "cuda", # set to "cpu" if you don't have a GPU
},
)
```

## Image processing

Images can be quite large. In GPU-poor environments, you may need to resize the image to a smaller size.

Here's a helper function to do that:

```python
def load_and_resize_image(image_path, max_size=1024):
"""
Load and resize an image while maintaining aspect ratio
Args:
image_path: Path to the image file
max_size: Maximum dimension (width or height) of the output image
Returns:
PIL Image: Resized image
"""
image = Image.open(image_path)

# Get current dimensions
width, height = image.size

# Calculate scaling factor
scale = min(max_size / width, max_size / height)

# Only resize if image is larger than max_size
if scale < 1:
new_width = int(width * scale)
new_height = int(height * scale)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

return image
```

You can change the resolution of the image by changing the `max_size` argument. Small max sizes will make the image more blurry, but processing will be faster and require less memory.

## Load an image

Load an image and resize it. We've provided a sample image of a Trader Joe's receipt, but you can use any image you'd like.

Here's what the image looks like:

![Trader Joe's receipt](./images/trader-joes-receipt.jpg)

```python
# Path to the image
image_path = "https://dottxt-ai.github.io/outlines/main/cookbook/images/trader-joes-receipt.png"

# Download the image
response = requests.get(image_path)
with open("receipt.png", "wb") as f:
f.write(response.content)

# Load + resize the image
image = load_and_resize_image("receipt.png")
```

## Define the output structure

We'll define a Pydantic model to describe the data we want to extract from the image. After processing the image, the LLM will output data in this format -- in this case, we'll have a list of items, a store name, address, and so on.

```python
class Item(BaseModel):
name: str
quantity: Optional[int]
price_per_unit: Optional[float]
total_price: Optional[float]

class ReceiptSummary(BaseModel):
store_name: str
store_address: str
store_number: Optional[int]
items: List[Item]
tax: Optional[float]
total: Optional[float]
# Date is in the format YYYY-MM-DD. We can apply a regex pattern to ensure it's formatted correctly.
date: Optional[str] = Field(pattern=r'\d{4}-\d{2}-\d{2}', description="Date in the format YYYY-MM-DD")
payment_method: Literal["cash", "credit", "debit", "check", "other"]
```

## Prepare the prompt

We'll use the `AutoProcessor` to convert the image and the text prompt into a format that the model can understand. Practically,
this is the code that adds user, system, assistant, and image tokens to the prompt.

```python
# Set up the content you want to send to the model
messages = [
{
"role": "user",
"content": [
{
# The image is provided as a PIL Image object
"type": "image",
"image": image,
},
{
"type": "text",
"text": f"""You are an expert at extracting information from receipts.
Please extract the information from the receipt. Be as detailed as possible --
missing or misreporting information is a crime.
Return the information in the following JSON schema:
{ReceiptSummary.model_json_schema()}
"""},
],
}
]

# Convert the messages to the final prompt
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
```

If you are curious, the final prompt that is sent to the model looks (roughly) like this:

```
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>
You are an expert at extracting information from receipts.
Please extract the information from the receipt. Be as detailed as
possible -- missing or misreporting information is a crime.
Return the information in the following JSON schema:
<JSON SCHEMA OMITTED>
<|im_end|>
<|im_start|>assistant
```

## Run the model

```python
# Prepare a function to process receipts
receipt_summary_generator = outlines.generate.json(
model,
ReceiptSummary,

# Greedy sampling is a good idea for numeric
# data extraction -- no randomness.
sampler=outlines.samplers.greedy()
)

# Generate the receipt summary
result = receipt_summary_generator(prompt, [image])
print(result)
```

## Output

The output should look like this:

```
ReceiptSummary(
store_name="Trader Joe's",
store_address='401 Bay Street, San Francisco, CA 94133',
store_number=0,
items=[
Item(name='BANANA EACH', quantity=7, price_per_unit=0.23, total_price=1.61),
Item(name='BAREBELLS CHOCOLATE DOUG', quantity=1, price_per_unit=2.29, total_price=2.29),
Item(name='BAREBELLS CREAMY CRISP', quantity=1, price_per_unit=2.29, total_price=2.29),
Item(name='BAREBELLS CHOCOLATE DOUG', quantity=1, price_per_unit=2.29, total_price=2.29),
Item(name='BAREBELLS CARAMEL CASHEW', quantity=2, price_per_unit=2.29, total_price=4.58),
Item(name='BAREBELLS CREAMY CRISP', quantity=1, price_per_unit=2.29, total_price=2.29),
Item(name='SPINDRIFT ORANGE MANGO 8', quantity=1, price_per_unit=7.49, total_price=7.49),
Item(name='Bottle Deposit', quantity=8, price_per_unit=0.05, total_price=0.4),
Item(name='MILK ORGANIC GALLON WHOL', quantity=1, price_per_unit=6.79, total_price=6.79),
Item(name='CLASSIC GREEK SALAD', quantity=1, price_per_unit=3.49, total_price=3.49),
Item(name='COBB SALAD', quantity=1, price_per_unit=5.99, total_price=5.99),
Item(name='PEPPER BELL RED XL EACH', quantity=1, price_per_unit=1.29, total_price=1.29),
Item(name='BAG FEE.', quantity=1, price_per_unit=0.25, total_price=0.25),
Item(name='BAG FEE.', quantity=1, price_per_unit=0.25, total_price=0.25)
],
tax=0.68,
total=41.98,
date='2023-11-04',
payment_method='debit',
)
```

Voila! You've successfully extracted information from a receipt using an LLM.

## Bonus: roasting the user for their receipt

You can roast the user for their receipt by adding a `roast` field to the end of the `ReceiptSummary` model.

```python
class ReceiptSummary(BaseModel):
...
roast: str
```

which gives you a result like

```
ReceiptSummary(
...
roast="You must be a fan of Trader Joe's because you bought enough
items to fill a small grocery bag and still had to pay for a bag fee.
Maybe you should start using reusable bags to save some money and the
environment."
)
```

Qwen is not particularly funny, but worth a shot.

0 comments on commit 679aa54

Please sign in to comment.