diff --git a/src/askem_beaker/contexts/model_configuration/agent.py b/src/askem_beaker/contexts/model_configuration/agent.py index 3436a10..df1cce1 100644 --- a/src/askem_beaker/contexts/model_configuration/agent.py +++ b/src/askem_beaker/contexts/model_configuration/agent.py @@ -31,12 +31,13 @@ async def generate_code( self, query: str, agent: AgentRef, loop: LoopControllerRef ) -> None: """ - Generated code to be run in an interactive Jupyter notebook for the purpose of modifying a model configuration. + Generated code to be run in an interactive Jupyter notebook for the purpose of modifying a model configuration. This may include modifying + the configuration based on an available dataset. If the user mentions a dataset, it will always be a Pandas DataFrame called `dataset`. - Input is a full grammatically correct question about or request for an action to be performed on the loaded model configuration. + Input is a full grammatically correct question about or request for an action to be performed on the loaded model configuration (and optionally a dataset). Args: - query (str): A fully grammatically correct question about the current model configuration. + query (str): A fully grammatically correct question about the current model configuration (and optional dataset). """ prompt = f""" @@ -53,6 +54,8 @@ async def generate_code( The current configuration is: {agent.context.model_config} +The user may ask you to update the model configuration based on a dataset. If they do, you should use the `dataset` DataFrame to update the model configuration. + Please write code that satisfies the user's request below. Please generate the code as if you were programming inside a Jupyter Notebook and the code is to be executed inside a cell. 
diff --git a/src/askem_beaker/contexts/model_configuration/context.py b/src/askem_beaker/contexts/model_configuration/context.py index 82cf293..f7dd4dd 100644 --- a/src/askem_beaker/contexts/model_configuration/context.py +++ b/src/askem_beaker/contexts/model_configuration/context.py @@ -12,6 +12,7 @@ from beaker_kernel.lib.utils import intercept from .agent import ConfigEditAgent +from askem_beaker.utils import get_auth if TYPE_CHECKING: from beaker_kernel.kernel import LLMKernel @@ -41,13 +42,17 @@ async def setup(self, context_info, parent_header): self.config["context_info"] = context_info item_id = self.config["context_info"]["id"] item_type = self.config["context_info"].get("type", "model_config") + self.dataset_id = self.config["context_info"].get("dataset_id", None) logger.error(f"Processing {item_type} {item_id}") + self.auth = get_auth() await self.set_model_config( item_id, item_type, parent_header=parent_header ) + if self.dataset_id: + await self.load_dataset(parent_header=parent_header) async def auto_context(self): - return f"""You are an scientific modeler whose goal is to help the user understand and update a model configuration. + context = f"""You are an scientific modeler whose goal is to help the user understand and update a model configuration. Model configurations are defined by a specific model configuration JSON schema. The schema defines the structure of the model configuration, including the parameters, initial conditions, and other attributes of the model. @@ -63,6 +68,29 @@ async def auto_context(self): If you need to generate code, you should write it in the '{self.subkernel.DISPLAY_NAME}' language for execution in a Jupyter notebook using the '{self.subkernel.KERNEL_NAME}' kernel. """ + if self.dataset_id: + context += f"""\n Additionally, a DataFrame is loaded called `dataset`. + This DataFrame contains the data that can be used to parameterize the model configuration. 
If the user mentions updating the model + configuration based on a dataset, use the `dataset` DataFrame. + + It has the following structure: + {await self.describe_dataset()}\n + + When running or generating code to update the model configuration, make sure to do so based on your knowledge of `dataset`'s structure + and schema. You should never "assume" a dataset in the code that you generate or run, instead you should always use the `dataset` DataFrame. + You'll need to access various values from it based on your knowledge of the model configuration schema and your understanding of the dataset. + + A common convention is that strata are delineated with `_n` where `n` is a number in increasing order. For example, if there are multiple `age` strata + you might expect that `_1` indicates the first (or youngest) strata, `_2` the second, and so on. + + You'll also find that for stratified models, the parameter `referenceId` often indicates the strata to which the parameter applies. For example you may see a + `referenceId` of `beta_old_young` which would correspond with the interaction in the `dataset` between these two strata (if the interaction represents the `beta` parameter). + So if there are three strata, `old`, `middle`, and `young`, you might expect to see `beta_old_middle`, `beta_old_young`, and `beta_middle_young` as + the referenceIds for the interactions between these strata. In the dataset you might see groups delineated with `_1`, `_2`, `_3` to indicate the strata. + The value at `S_3` and `I_2` would correspond with `beta_old_middle` in this case since it reflects the transition from old susceptible population + and middle infected population. 
+ """ + return context async def get_schema(self) -> str: """ @@ -93,17 +121,27 @@ async def get_config(self) -> str: schema = ( await self.evaluate(self.get_code("get_config")) )["return"] - return json.dumps(schema, indent=2) + return json.dumps(schema, indent=2) + + + async def describe_dataset(self) -> str: + """ + Describe structure of provided dataset to assist in model configuration. + + Returns: + str: a description of the dataset structure + """ + schema = ( + await self.evaluate(self.get_code("describe_dataset", {"dataset_name": "dataset"})) + )["return"] + return schema async def set_model_config(self, item_id, agent=None, parent_header={}): self.config_id = item_id meta_url = f"{os.environ['HMI_SERVER_URL']}/model-configurations/{self.config_id}" logger.error(f"Meta url: {meta_url}") - self.model_config = requests.get(meta_url, - auth=(os.environ['AUTH_USERNAME'], - os.environ['AUTH_PASSWORD']) - ).json() + self.model_config = requests.get(meta_url, auth=self.auth.requests_auth()).json() logger.error(f"Succeeded in fetching configured model, proceeding.") await self.load_config() @@ -117,7 +155,30 @@ async def load_config(self): ] ) print(f"Running command:\n-------\n{command}\n---------") - await self.execute(command) + await self.execute(command) + + async def load_dataset(self, parent_header={}): + meta_url = f"{os.environ['HMI_SERVER_URL']}/datasets/{self.dataset_id}" + dataset = requests.get(meta_url, auth=self.auth.requests_auth()) + if dataset.status_code == 404: + raise Exception(f"Dataset '{self.dataset_id}' not found.") + filename = dataset.json().get("fileNames", [])[0] + meta_url = f"{os.environ['HMI_SERVER_URL']}/datasets/{self.dataset_id}" + url = f"{meta_url}/download-url?filename={filename}" + data_url_req = requests.get( + url=url, + auth=self.auth.requests_auth(), + ) + data_url = data_url_req.json().get("url", None) + command = "\n".join( + [ + self.get_code("load_dataset", { + "var_name": "dataset", + "data_url": data_url, + }), + ] 
+ ) + await self.execute(command) async def post_execute(self, message): content = (await self.evaluate(self.get_code("get_config")))["return"] diff --git a/src/askem_beaker/contexts/model_configuration/procedures/python3/describe_dataset.py b/src/askem_beaker/contexts/model_configuration/procedures/python3/describe_dataset.py new file mode 100644 index 0000000..262aca7 --- /dev/null +++ b/src/askem_beaker/contexts/model_configuration/procedures/python3/describe_dataset.py @@ -0,0 +1,25 @@ +import pandas as pd + +def describe_dataframe_structure(df): + # Basic information about the DataFrame + description = f"The DataFrame has {df.shape[0]} rows and {df.shape[1]} columns.\n" + + # Extract column names and types + description += "Column names and types:\n" + for column in df.columns: + description += f"- {column}: {df[column].dtype}\n" + + # Check if the DataFrame has a matrix-like structure + if not df.index.is_integer(): + description += "The DataFrame appears to have a matrix-like structure with row headers.\n" + description += "Row headers:\n" + for idx_name in df.index.names: + description += f"- {idx_name or 'Unnamed index'}: {df.index.get_level_values(idx_name).dtype}\n" + + # Append the head of the DataFrame + description += "\nThe first few rows of the DataFrame (head):\n" + description += df.head().to_string(index=True) + + return description + +describe_dataframe_structure({{ dataset_name }}) \ No newline at end of file diff --git a/src/askem_beaker/contexts/model_configuration/procedures/python3/load_dataset.py b/src/askem_beaker/contexts/model_configuration/procedures/python3/load_dataset.py new file mode 100644 index 0000000..5fc2f82 --- /dev/null +++ b/src/askem_beaker/contexts/model_configuration/procedures/python3/load_dataset.py @@ -0,0 +1,2 @@ +import pandas as pd +{{ var_name }} = pd.read_csv('{{ data_url }}')