Skip to content

Commit

Permalink
Use random seeds in Sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
mondaychen authored and brandonwillard committed Sep 11, 2023
1 parent e83cd73 commit cbc7c7f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
34 changes: 27 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ from pydantic import BaseModel, constr
import outlines.models as models
import outlines.text.generate as generate

import torch


class Weapon(str, Enum):
sword = "sword"
Expand All @@ -219,20 +221,38 @@ class Character(BaseModel):
strength: int


model = models.transformers("gpt2")
sequence = generate.json(model, Character)("Give me a character description")
model = models.transformers("gpt2", device="cuda")

# Construct guided sequence generator
generator = generate.json(model, Character, max_tokens=100)

# Draw a sample
rng = torch.Generator(device="cuda")
rng.manual_seed(789001)

sequence = generator("Give me a character description", rng=rng)
print(sequence)
# {
# "name": "clerame",
# "age": 7,
# "armor": "plate",
# "weapon": "mace",
# "strength": 4171
# }

sequence = generator("Give me an interesting character description", rng=rng)
print(sequence)
# {
# "name": "ranbelt",
# "age": 26,
# "name": "piggyback",
# "age": 23,
# "armor": "chainmail",
# "weapon": "bow",
# "strength": 5
# "weapon": "sword",
# "strength": 0
# }

parsed = Character.model_validate_json(sequence)
print(parsed)
# name='ranbelt' age=26 armor=<Armor.chainmail: 'chainmail'> weapon=<Weapon.bow: 'bow'> strength=5
# name='piggyback' age=23 armor=<Armor.chainmail: 'chainmail'> weapon=<Weapon.sword: 'sword'> strength=0
```

The method works with union types, optional types, arrays, nested schemas, etc. Some field constraints are [not supported yet](https://github.com/normal-computing/outlines/issues/215), but everything else should work.
Expand Down
1 change: 1 addition & 0 deletions outlines/text/generate/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def __call__(

if rng is None:
rng = torch.Generator(device=self.device)
rng.seed()

num_prompt_tokens = token_ids.shape[-1]

Expand Down

0 comments on commit cbc7c7f

Please sign in to comment.