-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #61 from DARPA-ASKEM/35-support-for-model-configur…
…ation-editing Add support for model configuration editing with Mira
- Loading branch information
Showing
12 changed files
with
261 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
OPENAI_API_KEY={OPENAI_API_KEY} | ||
DATA_SERVICE_URL=http://localhost:8001 | ||
HMI_SERVER_URL=http://localhost:3000 | ||
HMI_SERVER_USER=user | ||
HMI_SERVER_PASSWORD=pass | ||
JUPYTER_SERVER=http://localhost:8888 | ||
JUPYTER_TOKEN=ebcec7fcf42f28baccfab1cbc07bfb3f | ||
ENABLE_USER_PROMPT=true |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import json | ||
import logging | ||
import re | ||
|
||
import requests | ||
from archytas.react import Undefined | ||
from archytas.tool_utils import AgentRef, LoopControllerRef, tool | ||
|
||
from beaker_kernel.lib.agent import BaseAgent | ||
from beaker_kernel.lib.context import BaseContext | ||
from beaker_kernel.lib.jupyter_kernel_proxy import JupyterMessage | ||
|
||
logging.disable(logging.WARNING) # Disable warnings | ||
logger = logging.Logger(__name__) | ||
|
||
|
||
class MiraConfigEditAgent(BaseAgent): | ||
""" | ||
LLM agent used for working with the Mira Modeling framework ("mira_model" package) in Python 3. | ||
This will be used to find pre-written functions which will be used to edit a model. | ||
A mira model is made up of multiple templates that are merged together like ... | ||
An example mira model will look like this when encoded in json: | ||
``` | ||
{ | ||
"id": "foo", | ||
"bar": "foobar", | ||
... | ||
} | ||
Instead of manipulating the model directly, the agent will always return code that will be run externally in a jupyter notebook. | ||
""" | ||
|
||
def __init__(self, context: BaseContext = None, tools: list = None, **kwargs): | ||
super().__init__(context, tools, **kwargs) | ||
|
||
@tool() | ||
async def get_parameters_initials(self, _type: str, agent: AgentRef, loop: LoopControllerRef): | ||
""" | ||
This tool is used when a user wants to see the names and values of the model configuration's parameters. | ||
Please generate the code as if you were programming inside a Jupyter Notebook and the code is to be executed inside a cell. | ||
You MUST wrap the code with a line containing three backticks (```) before and after the generated code. | ||
No addtional text is needed in the response, just the code block. | ||
Args: | ||
_type (str): either "parameters" or "initials" and determines whether to fetch values of the parameters or the initial conditions | ||
""" | ||
loop.set_state(loop.STOP_SUCCESS) | ||
if _type == "parameters": | ||
code = agent.context.get_code("get_params") | ||
elif _type == "initials": | ||
code = agent.context.get_code("get_initials") | ||
return json.dumps( | ||
{ | ||
"action": "code_cell", | ||
"language": "python3", | ||
"content": code.strip(), | ||
} | ||
) | ||
|
||
@tool() | ||
async def update_parameters(self, parameter_values: dict, agent: AgentRef, loop: LoopControllerRef): | ||
""" | ||
This tool is used when a user wants to update the model configuration parameter values. | ||
It takes in a dictionary where the key is the parameter name and the value is the new value in the form: | ||
``` | ||
{'param1': 10, | ||
'param_n: 2, | ||
...} | ||
``` | ||
Please generate the code as if you were programming inside a Jupyter Notebook and the code is to be executed inside a cell. | ||
You MUST wrap the code with a line containing three backticks (```) before and after the generated code. | ||
No addtional text is needed in the response, just the code block. | ||
Args: | ||
parameter_values (dict): the dictionary of parameter names and the values to update them with | ||
""" | ||
loop.set_state(loop.STOP_SUCCESS) | ||
code = agent.context.get_code("update_params", {"parameter_values": parameter_values}) | ||
return json.dumps( | ||
{ | ||
"action": "code_cell", | ||
"language": "python3", | ||
"content": code.strip(), | ||
} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import copy | ||
import datetime | ||
import json | ||
import logging | ||
import os | ||
from typing import TYPE_CHECKING, Any, Dict, Optional | ||
|
||
import requests | ||
from requests.auth import HTTPBasicAuth | ||
|
||
from beaker_kernel.lib.context import BaseContext | ||
from beaker_kernel.lib.utils import intercept | ||
|
||
from .agent import MiraConfigEditAgent | ||
|
||
if TYPE_CHECKING: | ||
from beaker_kernel.kernel import LLMKernel | ||
from beaker_kernel.lib.subkernels.base import BaseSubkernel | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class MiraConfigEditContext(BaseContext): | ||
|
||
agent_cls = MiraConfigEditAgent | ||
|
||
model_config_id: Optional[str] | ||
model_config_json: Optional[str] | ||
model_config_dict: Optional[dict[str, Any]] | ||
var_name: Optional[str] = "model_config" | ||
|
||
def __init__(self, beaker_kernel: "LLMKernel", subkernel: "BaseSubkernel", config: Dict[str, Any]) -> None: | ||
self.reset() | ||
logger.error("initializing...") | ||
super().__init__(beaker_kernel, subkernel, self.agent_cls, config) | ||
|
||
def reset(self): | ||
pass | ||
|
||
async def post_execute(self, message): | ||
pass | ||
|
||
async def setup(self, config, parent_header): | ||
logger.error(f"performing setup...") | ||
self.config = config | ||
item_id = config["id"] | ||
item_type = config.get("type", "model_config") | ||
logger.error(f"Processing {item_type} AMR {item_id} as a MIRA model") | ||
await self.set_model_config( | ||
item_id, item_type, parent_header=parent_header | ||
) | ||
|
||
async def set_model_config(self, item_id, agent=None, parent_header={}): | ||
self.config_id = item_id | ||
meta_url = f"{os.environ['HMI_SERVER_URL']}/model-configurations/{self.config_id}" | ||
logger.error(f"Meta url: {meta_url}") | ||
self.configuration = requests.get(meta_url, | ||
auth=(os.environ['HMI_SERVER_USER'], | ||
os.environ['HMI_SERVER_PASSWORD']) | ||
).json() | ||
logger.error(f"Succeeded in fetching model configuration, proceeding.") | ||
self.amr = self.configuration.get("configuration") | ||
self.original_amr = copy.deepcopy(self.amr) | ||
if self.amr: | ||
await self.load_mira() | ||
else: | ||
raise Exception(f"Model config '{item_id}' not found.") | ||
await self.send_mira_preview_message(parent_header=parent_header) | ||
|
||
async def load_mira(self): | ||
command = "\n".join( | ||
[ | ||
self.get_code("setup"), | ||
self.get_code("load_model", { | ||
"var_name": self.var_name, | ||
"model": self.amr, | ||
}), | ||
] | ||
) | ||
print(f"Running command:\n-------\n{command}\n---------") | ||
await self.execute(command) | ||
|
||
async def send_mira_preview_message( | ||
self, server=None, target_stream=None, data=None, parent_header={} | ||
): | ||
try: | ||
|
||
preview = await self.evaluate(self.get_code("model_preview"), {"var_name": self.var_name}) | ||
content = preview["return"] | ||
self.beaker_kernel.send_response( | ||
"iopub", "model_preview", content, parent_header=parent_header | ||
) | ||
except Exception as e: | ||
raise | ||
|
||
@intercept() | ||
async def save_model_config_request(self, message): | ||
''' | ||
Updates the model configuration in place. | ||
''' | ||
content = message.content | ||
|
||
new_model: dict = ( | ||
await self.evaluate( | ||
f"template_model_to_petrinet_json({self.var_name})" | ||
) | ||
)["return"] | ||
|
||
model_config = self.configuration | ||
model_config["configuration"] = new_model | ||
|
||
create_req = requests.put( | ||
f"{os.environ['HMI_SERVER_URL']}/model-configurations/{self.config_id}", json=model_config, | ||
auth =(os.environ['HMI_SERVER_USER'], os.environ['HMI_SERVER_PASSWORD']) | ||
) | ||
|
||
if create_req.status_code == 200: | ||
logger.error(f"Successfuly updated model config {self.config_id}") | ||
response_id = create_req.json()["id"] | ||
|
||
content = {"model_configuration_id": response_id} | ||
self.beaker_kernel.send_response( | ||
"iopub", "save_model_response", content, parent_header=message.header | ||
) |
3 changes: 3 additions & 0 deletions
3
src/askem_beaker/contexts/mira_config_edit/default_payload.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"id": "sir-model-id" | ||
} |
1 change: 1 addition & 0 deletions
1
src/askem_beaker/contexts/mira_config_edit/procedures/python3/get_initials.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
model_config.initials |
15 changes: 15 additions & 0 deletions
15
src/askem_beaker/contexts/mira_config_edit/procedures/python3/get_params.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
def get_params(model_config): | ||
params = "" | ||
for kk, vv in model_config.parameters.items(): | ||
if vv.display_name: | ||
display_name = f" ({vv.display_name})" | ||
else: | ||
display_name = "" | ||
if vv.units: | ||
units = f" ({vv.units})" | ||
else: | ||
units = "" | ||
params += f"{kk}{display_name}: {vv.value}{units}\n" | ||
return params | ||
|
||
print(get_params({{ var_name|default("model_config") }})) |
4 changes: 4 additions & 0 deletions
4
src/askem_beaker/contexts/mira_config_edit/procedures/python3/load_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import copy, requests | ||
amr_json = dict({{ model }}) | ||
{{ var_name|default("model_config") }} = model_from_json(amr_json) | ||
_model_orig = copy.deepcopy({{ var_name|default("model_config") }}) |
1 change: 1 addition & 0 deletions
1
src/askem_beaker/contexts/mira_config_edit/procedures/python3/metadata.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
13 changes: 13 additions & 0 deletions
13
src/askem_beaker/contexts/mira_config_edit/procedures/python3/model_preview.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from IPython.core.interactiveshell import InteractiveShell; | ||
from IPython.core import display_functions; | ||
from mira.modeling.amr.petrinet import template_model_to_petrinet_json | ||
|
||
format_dict, md_dict = InteractiveShell.instance().display_formatter.format(GraphicalModel.for_jupyter(model_config)) | ||
result = { | ||
"application/json": template_model_to_petrinet_json(model_config) | ||
} | ||
for key, value in format_dict.items(): | ||
if "image" in key: | ||
result[key] = value | ||
|
||
result |
4 changes: 4 additions & 0 deletions
4
src/askem_beaker/contexts/mira_config_edit/procedures/python3/setup.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import requests; import pandas as pd; import numpy as np; import scipy; | ||
import json; import mira; | ||
import sympy; import itertools; from mira.metamodel import *; from mira.modeling import Model; | ||
from mira.sources.amr import model_from_json; from mira.modeling.viz import GraphicalModel; |
3 changes: 3 additions & 0 deletions
3
src/askem_beaker/contexts/mira_config_edit/procedures/python3/update_params.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
parameter_values = {{ parameter_values['parameter_values'] }} | ||
for kk, vv in parameter_values.items(): | ||
{{ var_name|default("model_config") }}.parameters[kk].value = vv |