Move towards grammar-driven generation to improve gameplay. #11

22 changes: 15 additions & 7 deletions README.md
@@ -22,7 +22,9 @@ This demo has been recorded on an M1 Mac with 32 GB RAM. The demo is in real tim

## Running

Download an LLM to your local machine. Recommended model is [LLaMA2-13B-Psyfighter2-GGUF](https://huggingface.co/KoboldAI/LLaMA2-13B-Psyfighter2-GGUF), the Q4_K_M variant. Why is it recommended, you may ask? Because this is the model I've made the game with 😂
Download an LLM to your local machine. The recommended model is [Big-Tiger-Gemma-27B-v1-GGUF](https://huggingface.co/TheDrummer/Big-Tiger-Gemma-27B-v1-GGUF); even the Q3_K_M variant is fine.

This game was initially made with LLaMA2-13B-Psyfighter2-GGUF, but after the GBNF rewrite and some testing, Gemma turned out to be far more capable.

### Mac M*

@@ -48,10 +48,10 @@ If you don't like poetry, you can also install the dependencies manually using `

```bash
python3 -m cherryberry \
--model ../models/LLaMA2-13B-Psyfighter2.Q4_K_M.gguf \
--n_ctx 4096 \
--model ../models/Big-Tiger-Gemma-27B-v1c-Q3_K_M.gguf \
--n_ctx 8192 \
--n_batch 512 \
-ngl 1 \
-ngl 200 \
--threads 4
```
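
Note: `--n_ctx` is the context window size in tokens. `-ngl` follows the llama.cpp convention of setting how many model layers are offloaded to the GPU; a value larger than the model's layer count (like `200` here) simply offloads all of them.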

@@ -81,8 +83,8 @@ If you don't like poetry, you can also install the dependencies manually using `

```bash
python3 -m cherryberry \
--model ../models/LLaMA2-13B-Psyfighter2.Q4_K_M.gguf \
--n_ctx 4096 \
--model ../models/Big-Tiger-Gemma-27B-v1c-Q3_K_M.gguf \
--n_ctx 8192 \
--n_batch 512 \
-ngl 200 \
--threads 4
@@ -112,4 +114,10 @@ To run the game in the debug mode:
textual console
# In another
textual run --dev cherryberry.py --debug --model ...
```
```

## FAQ

**Q: Can this game support [my favourite LLM server] instead of llama.cpp?**

A: I'm afraid not. This game relies heavily on grammar-driven generation using [GBNF](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md), and it would fall apart if that were removed. As far as I know, there is currently no standard support for grammars in LLM servers.
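
For the curious, here is a minimal sketch of what grammar-constrained generation looks like with `llama-cpp-python`, using the same `LlamaGrammar`/`create_completion` pattern as `model.py`. The model path and prompt below are placeholders:

```python
from llama_cpp import Llama, LlamaGrammar

# Placeholder model path; any local GGUF model will do.
llm = Llama(model_path="../models/Big-Tiger-Gemma-27B-v1c-Q3_K_M.gguf", n_ctx=8192)

# One of the grammars shipped with this change.
grammar = LlamaGrammar.from_file("language_model/grammars/10_find_exits.gbnf")

# The grammar masks token sampling, so the completion can only ever be
# text that the grammar accepts -- this is what makes the downstream
# JSON parsing reliable.
out = llm.create_completion(
    prompt="Describe the exits from the cellar as JSON:",  # placeholder prompt
    max_tokens=512,
    grammar=grammar,
)
print(out["choices"][0]["text"])
```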
22 changes: 22 additions & 0 deletions language_model/grammars/10_find_exits.gbnf
@@ -0,0 +1,22 @@
root ::= array "\n" "`"

val ::= "{" ws "\"exit_name\"" ":" ws shortstring "," ws "\"transition_text\"" ":" ws youstring "," ws "\"new_location_name\"" ":" ws shortstring "," ws "\"new_location_description\"" ":" ws shortstring "}" ws

array ::= "[" ws ( val ( "," ws val )? ( "," ws val )? )? "]"

string ::=
"\"" (
[a-zA-Z, ]
)+ "\"" ws

shortstring ::=
"\"" (
[a-zA-Z ]
)+ "\"" ws

youstring ::=
"\"You " (
[a-zA-Z, ]
)+ "\"" ws

ws ::= [ \t\n]{0,10}
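
For illustration, the grammar above only admits a JSON array of up to three exit objects with these exact keys, so a conforming completion parses cleanly with `orjson`, just as `find_exits` in `model.py` expects. The values below are invented:

```python
import orjson

# Invented sample that 10_find_exits.gbnf would allow: the short strings
# contain only letters and spaces, and "transition_text" must start with
# "You " (commas are also permitted there).
sample = (
    '[{"exit_name": "Oak Door",'
    ' "transition_text": "You push the heavy oak door open",'
    ' "new_location_name": "Candlelit Hallway",'
    ' "new_location_description": "A narrow hall lined with old portraits"}]'
)

exits = orjson.loads(sample)
print(exits[0]["new_location_name"])  # Candlelit Hallway
```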
25 changes: 25 additions & 0 deletions language_model/grammars/15_find_items.gbnf
@@ -0,0 +1,25 @@
root ::= array "\n" "`"

val ::= "{" ws "\"name\"" ":" ws shortstring "," ws "\"obtain_action\"" ":" ws youstring "," ws "\"description\"" ":" ws string "}" ws

array ::= "[" ws ( val ( "," ws val )? ( "," ws val )? )? "]"

string ::=
"\"" (
[a-zA-Z ]
)+ "\"" ws

shortstring ::=
"\"" (
[a-zA-Z ]
)+ "\"" ws

youstring ::=
"\"You " (
[a-zA-Z, ]
)+ "\"" ws

number ::=
[1-9][0-9]* ws

ws ::= [ \t\n]{0,10}
8 changes: 8 additions & 0 deletions language_model/grammars/50_inventory_updates.gbnf
@@ -0,0 +1,8 @@
root ::= object ws

object ::= "*" ws ( "Add " | "Remove " ) ws string ( nl "*" ws ( "Add " | "Remove " ) ws string )*

string ::= ( [a-zA-Z0-9 ] )*

ws ::= ([ ])*
nl ::= [\n]
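
The grammar above forces output of the form `* Add rusty key` / `* Remove torch`, one operation per line. Here is a rough sketch of how such text could be split into operations; the PR itself hands the generated text to a follow-up prompt (`55_update_inventory`) rather than parsing it in code:

```python
import re

# Invented sample conforming to 50_inventory_updates.gbnf.
sample = "* Add rusty key\n* Remove torch"

for line in sample.splitlines():
    match = re.match(r"\*\s*(Add|Remove)\s+([a-zA-Z0-9 ]*)", line)
    if match:
        op, item = match.groups()
        print(op, item.strip())
```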
1 change: 1 addition & 0 deletions language_model/grammars/general_text.gbnf
@@ -0,0 +1 @@
root ::= [a-zA-Z ,:;-]* "." [a-zA-Z ,:;-]* "." [a-zA-Z ,:;-]* "."
109 changes: 86 additions & 23 deletions language_model/model.py
@@ -1,4 +1,4 @@
from llama_cpp import Llama, LlamaCache, LlamaDiskCache, llama_log_set
from llama_cpp import Llama, LlamaCache, LlamaDiskCache, llama_log_set, LlamaGrammar
from jinja2 import Environment, PackageLoader
import orjson
import re
@@ -49,7 +49,7 @@ def printg(self, message="", end="\n", flush=False):
def clearg(self):
self.queue.put(GenerateCleared(), block=False)

def generate_location(self, setting, history, requirements):
def generate_location(self, setting, history, loc_name, loc_description):
self.printb("[grey46][Generating location...][/]")

htext = ""
@@ -59,13 +59,17 @@ def generate_location(self, setting, history, requirements):
hist += [f"* {i}"]
htext = "The story so far, for context only:\n\n" + "\n".join(hist)

prompt = self.tmpl.get_template("00_generate_location.txt").render(
prompt = self.tmpl.get_template("00_generate_location.md").render(
{
"setting": setting,
"history": htext,
"requirements": requirements,
"loc_name": loc_name,
"loc_description": loc_description,
}
)

g = LlamaGrammar.from_file("language_model/grammars/general_text.gbnf")

if self.debug:
logging.debug(prompt)
stream = self.llm.create_completion(
@@ -77,6 +81,7 @@ def generate_location(self, setting, history, requirements):
top_k=200,
stop=["\n\n", "#"],
stream=True,
grammar=g,
)
out = ""
for output in stream:
@@ -90,7 +95,7 @@ def generate_location(self, setting, history, requirements):
return out

def generate_location_from_exit(
self, setting, history, previous, exit_name, exit_description
self, setting, history, previous, loc_name, loc_description
):
self.printb("[grey46][Generating location from exit...][/]")

@@ -101,15 +106,18 @@ def generate_location_from_exit(
hist += [f"* {i}"]
htext = "The story so far, for context only:\n\n" + "\n".join(hist)

prompt = self.tmpl.get_template("05_generate_location_from_exit.txt").render(
prompt = self.tmpl.get_template("05_generate_location_from_exit.md").render(
{
"setting": setting,
"history": htext,
"previous": previous,
"exit_name": exit_name,
"exit_description": exit_description,
"loc_name": loc_name,
"loc_description": loc_description,
}
)

g = LlamaGrammar.from_file("language_model/grammars/general_text.gbnf")

if self.debug:
logging.debug(prompt)
stream = self.llm.create_completion(
@@ -121,6 +129,7 @@ def generate_location_from_exit(
top_k=200,
stop=["\n\n", "#"],
stream=True,
grammar=g,
)
out = ""
for output in stream:
@@ -144,7 +153,7 @@ def action_items(self, description, inventory, action):
inv += [f"* {i}"]
inv = "\n".join(inv)

prompt = self.tmpl.get_template("20_action_items.txt").render(
prompt = self.tmpl.get_template("20_action_items.md").render(
{
"description": description,
"inventory": inv,
@@ -209,7 +218,7 @@ def consequences(self, setting, history, description, inventory, action):
inv += [f"* {i}"]
inv = "\n".join(inv)

prompt = self.tmpl.get_template("30_consequences.txt").render(
prompt = self.tmpl.get_template("30_consequences.md").render(
{
"setting": setting,
"history": htext,
@@ -244,7 +253,7 @@ def consequences(self, setting, history, description, inventory, action):
def update_description(self, setting, description, action, consequences):
self.printb("[grey46][Generating update description...][/]")

prompt = self.tmpl.get_template("40_update_description.txt").render(
prompt = self.tmpl.get_template("40_update_description.md").render(
{
"setting": setting,
"description": description,
@@ -261,7 +270,7 @@ def update_description(self, setting, description, action, consequences):
repeat_penalty=1.1,
top_p=0.95,
top_k=40,
stop=["#"],
stop=["\n\n", "#"],
stream=True,
)

@@ -287,7 +296,7 @@ def update_inventory(self, inventory, description, action, consequences):
inv += [f"* {i}"]
inv = "\n".join(inv)

prompt = self.tmpl.get_template("50_inventory_updates.txt").render(
prompt = self.tmpl.get_template("50_inventory_updates.md").render(
{
"inventory": inv,
"description": description,
@@ -297,6 +306,9 @@ def update_inventory(self, inventory, description, action, consequences):
)
if self.debug:
logging.debug(prompt)

g = LlamaGrammar.from_file("language_model/grammars/50_inventory_updates.gbnf")

stream = self.llm.create_completion(
prompt=prompt,
max_tokens=2048,
@@ -306,8 +318,9 @@ def update_inventory(self, inventory, description, action, consequences):
top_k=40,
stop=["#"],
stream=True,
grammar=g,
)
out = "1. "
out = ""
for output in stream:
out += output["choices"][0]["text"]
self.printg(output["choices"][0]["text"], end="", flush=True)
@@ -319,7 +332,7 @@ def update_inventory(self, inventory, description, action, consequences):
updates = out

self.printb("[grey46][Generating update inventory][/]")
prompt = self.tmpl.get_template("55_update_inventory.txt").render(
prompt = self.tmpl.get_template("55_update_inventory.md").render(
{
"inventory": inv,
"description": description,
@@ -373,24 +386,27 @@ def update_inventory(self, inventory, description, action, consequences):
def find_exits(self, setting, location_description):
self.printb("[grey46][Generating find exits...][/]")

prompt = self.tmpl.get_template("10_find_exits.txt").render(
prompt = self.tmpl.get_template("10_find_exits.md").render(
{"setting": setting, "description": location_description}
)
if self.debug:
logging.debug(prompt)

g = LlamaGrammar.from_file("language_model/grammars/10_find_exits.gbnf")

stream = self.llm.create_completion(
prompt=prompt,
max_tokens=512,
temperature=0.8,
repeat_penalty=1.1,
top_p=0.95,
top_k=40,
stop=["`", "#"],
stop=["`", "\n\n"],
stream=True,
grammar=g,
)

out = '{\n "'
out = ""
self.printg(out, end="")
for output in stream:
out += output["choices"][0]["text"]
@@ -402,6 +418,7 @@ def find_exits(self, setting, location_description):
logging.debug(out)

out = re.sub("`.*", "", out, flags=re.M)  # strip trailing backtick fence; re.M must go via flags= (4th positional arg is count)
try:
obj = orjson.loads(out)
self.clearg()
@@ -420,23 +437,69 @@ def find_exits(self, setting, location_description):
obj = orjson.loads(out + "} }")
self.clearg()
return obj
except:
pass
except Exception as exc:
self.clearg()
raise Exception(f"Unable to parse: {out}") from exc

def find_items(self, setting, location_description):
self.printb("[grey46][Generating find items...][/]")

prompt = self.tmpl.get_template("15_find_items.md").render(
{"setting": setting, "description": location_description}
)
if self.debug:
logging.debug(prompt)

g = LlamaGrammar.from_file("language_model/grammars/15_find_items.gbnf")

stream = self.llm.create_completion(
prompt=prompt,
max_tokens=512,
temperature=0.8,
repeat_penalty=1.1,
top_p=0.95,
top_k=40,
stop=["`", "\n\n"],
stream=True,
grammar=g,
)

self.printg("[Attempting to fix JSON...]")
out = self.json_fixer(out)
out = ""
self.printg(out, end="")
for output in stream:
out += output["choices"][0]["text"]
self.printg(output["choices"][0]["text"], end="", flush=True)
self.printg()

out = out.strip()
if self.debug:
logging.debug(out)

out = re.sub("`.*", "", out, flags=re.M)
try:
obj = orjson.loads(out)
self.clearg()
return obj
except orjson.JSONDecodeError:
pass

try:
obj = orjson.loads(out + "}")
self.clearg()
return obj
except orjson.JSONDecodeError:
pass

try:
obj = orjson.loads(out + "} }")
self.clearg()
return obj
except Exception as exc:
self.clearg()
raise Exception(f"Unable to parse: {out}") from exc

def json_fixer(self, json_str):
prompt = self.tmpl.get_template("99_json_fixer.txt").render(
prompt = self.tmpl.get_template("99_json_fixer.md").render(
{
"json": json_str,
}