diff --git a/README.md b/README.md index d91dc1356..3c90b75bb 100644 --- a/README.md +++ b/README.md @@ -195,6 +195,8 @@ from pydantic import BaseModel, constr import outlines.models as models import outlines.text.generate as generate +import torch + class Weapon(str, Enum): sword = "sword" @@ -219,20 +221,38 @@ class Character(BaseModel): strength: int -model = models.transformers("gpt2") -sequence = generate.json(model, Character)("Give me a character description") +model = models.transformers("gpt2", device="cuda") + +# Construct guided sequence generator +generator = generate.json(model, Character, max_tokens=100) + +# Draw a sample +rng = torch.Generator(device="cuda") +rng.manual_seed(789001) + +sequence = generator("Give me a character description", rng=rng) +print(sequence) +# { +# "name": "clerame", +# "age": 7, +# "armor": "plate", +# "weapon": "mace", +# "strength": 4171 +# } + +sequence = generator("Give me an interesting character description", rng=rng) print(sequence) # { -# "name": "ranbelt", -# "age": 26, +# "name": "piggyback", +# "age": 23, # "armor": "chainmail", -# "weapon": "bow", -# "strength": 5 +# "weapon": "sword", +# "strength": 0 # } parsed = Character.model_validate_json(sequence) print(parsed) -# name='ranbelt' age=26 armor=<Armor.chainmail: 'chainmail'> weapon=<Weapon.bow: 'bow'> strength=5 +# name='piggyback' age=23 armor=<Armor.chainmail: 'chainmail'> weapon=<Weapon.sword: 'sword'> strength=0 ``` The method works with union types, optional types, arrays, nested schemas, etc. Some field constraints are [not supported yet](https://github.com/normal-computing/outlines/issues/215), but everything else should work. diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py index 77edcfc0a..699f12c5e 100644 --- a/outlines/text/generate/sequence.py +++ b/outlines/text/generate/sequence.py @@ -199,6 +199,7 @@ def __call__( if rng is None: rng = torch.Generator(device=self.device) + rng.seed() num_prompt_tokens = token_ids.shape[-1]