From 6f9718435ec2aaf8d22284266f0acd32fdb4c5c8 Mon Sep 17 00:00:00 2001 From: aembryonic Date: Wed, 18 Dec 2024 15:50:55 +0100 Subject: [PATCH 1/2] organigram and multiselect --- .gitignore | 1 + .../llm/model_prediction.py | 90 +++++++++++++++++-- .../entryextraction_llm/llm/prompt_utils.py | 1 + handlers/ecs/entryextraction_llm/llm/utils.py | 67 +++++++++++--- 4 files changed, 137 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 8ba1d70..becdff7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .terraform.lock.hcl .terraform/* +.vscode/* \ No newline at end of file diff --git a/handlers/ecs/entryextraction_llm/llm/model_prediction.py b/handlers/ecs/entryextraction_llm/llm/model_prediction.py index 388382e..06ccfc5 100644 --- a/handlers/ecs/entryextraction_llm/llm/model_prediction.py +++ b/handlers/ecs/entryextraction_llm/llm/model_prediction.py @@ -5,18 +5,26 @@ import redis from box import Box + # from langchain_aws import ChatBedrockConverse from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableParallel from langchain_openai import ChatOpenAI from llm.prompt_utils import ( DESCRIPTION_PILLARS, + DESCRIPTION_ORGANIGRAM, MAIN_PROMT, INPUT_PASSAGE, - DESCRIPTION_MATRIX + DESCRIPTION_MATRIX, +) +from llm.utils import ( + af_widget_by_id, + combine_properties, + connect_db, + process_primary_tags, + process_multiselect_tags, + process_organigram_tags, ) -from llm.utils import (af_widget_by_id, combine_properties, connect_db, - process_primary_tags) from pydantic import BaseModel, Field, create_model # from utils.db import connect_db @@ -72,10 +80,24 @@ def __process_widget_2d(self, key: str, properties: dict, version: int): ) def __process_widget_multiselect(self, key: str, properties: dict, version: int): - raise NotImplementedError + + _, rows = process_multiselect_tags(properties) + self.__update( + key=key, + value=rows, + version=version, + widget_type=self.WidgetTypes.multiselectWidget, + ) def __process_widget_organigram(self, key: str, properties: dict, version: int): - raise NotImplementedError + + _, rows = process_organigram_tags(properties) + self.__update( + key=key, + value=rows, + version=version, + widget_type=self.WidgetTypes.organigramWidget, + ) def __process_widget_daterange(self, key: str, properties: dict, version: int): raise NotImplementedError @@ -213,10 +235,50 @@ def __process_widget_2d( ) def __process_widget_multiselect(self, key: str, properties: dict, class_name: str): - raise NotImplementedError + + properties, _ = process_multiselect_tags(properties) + multiselect = create_model( + class_name, + __base__=BaseModel, + __doc__=DESCRIPTION_PILLARS.format(class_name.lower()), + **{ + k: (bool, Field(title=k, description=v["description"], default=False)) + for k, v in properties.items() + }, + ) + self.__update( + key, + self.Schema( + type=self.mappings_instance.WidgetTypes.multiselectWidget, + prompt=MAIN_PROMT.format(class_name.lower()) + INPUT_PASSAGE, + model=self.__foundation_model_id_selection(schema=multiselect), + properties=properties, + pyd_class=multiselect, + ), + ) def __process_widget_organigram(self, key: str, properties: dict, class_name: str): - raise NotImplementedError + + properties, _ = process_organigram_tags(properties) + organigram = create_model( + class_name, + __base__=BaseModel, + __doc__=DESCRIPTION_ORGANIGRAM.format(class_name.lower()), + **{ + k: (bool, Field(title=k, description=v["description"], default=False)) + for k, v in properties.items() + }, + ) + self.__update( + key, + self.Schema( + type=self.mappings_instance.WidgetTypes.organigramWidget, + prompt=MAIN_PROMT.format(class_name.lower()) + INPUT_PASSAGE, + model=self.__foundation_model_id_selection(schema=organigram), + properties=properties, + pyd_class=organigram, + ), + ) def __process_widget_daterange(self, key: str, properties: dict, class_name: str): raise NotImplementedError @@ -252,6 +314,8 @@ class LLMTagsPrediction: AVAILABLE_WIDGETS: list = [ "matrix2dWidget", "matrix1dWidget", + "multiselectWidget", + "organigramWidget", ] # it'll be extended to all widget types AVAILABLE_FOUNDATION_MODELS: list = ["bedrock", "openai"] @@ -406,10 +470,18 @@ def __convert_result(self, prediction: dict): results[k][pillar][subpillar][sector].append(subsector) elif type_ == self.widgets.mappings_instance.WidgetTypes.multiselectWidget: - raise NotImplementedError + + if k not in results.keys(): + results.update({k: []}) + for c in v: + results[k].append(schema[c]["alias"]) elif type_ == self.widgets.mappings_instance.WidgetTypes.organigramWidget: - raise NotImplementedError + + if k not in results.keys(): + results.update({k: []}) + for c in v: + results[k].append(schema[c]["alias"]) elif type_ == self.widgets.mappings_instance.WidgetTypes.dateRangeWidget: raise NotImplementedError diff --git a/handlers/ecs/entryextraction_llm/llm/prompt_utils.py b/handlers/ecs/entryextraction_llm/llm/prompt_utils.py index 98c056b..278531c 100644 --- a/handlers/ecs/entryextraction_llm/llm/prompt_utils.py +++ b/handlers/ecs/entryextraction_llm/llm/prompt_utils.py @@ -5,6 +5,7 @@ DESCRIPTION_PILLARS = """Those are the {} labels that you will use to classify text. Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage is not enough clear to be associated to some label, also none of them can be selected. Be sure to not over-classify the passage, but just select what you're sure about.""" # noqa DESCRIPTION_MATRIX = """Those are the {} labels that you will use to classify text. Each label is a combination of a column label (Column category) and a row label (Row category). Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage can't clearly be associated to some label, none of them must be selected. Be sure to not over-classify the passage, but just select what you're sure about.""" # noqa +DESCRIPTION_ORGANIGRAM = """Those are the {} labels that you will use to classify text. It's an organigram style set of labels, so classify only to the specified level of granularity of the passage. Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage is not enough clear to be associated to some label, also none of them can be selected. Be sure to not over-classify the passage, but just select what you're sure about. Don't infer possible labels, just select in respect of the passage content""" # noqa MULTI_DESCRIPTION = """Column category: {}. Row category: {}.""" diff --git a/handlers/ecs/entryextraction_llm/llm/utils.py b/handlers/ecs/entryextraction_llm/llm/utils.py index 6051818..3ae593f 100644 --- a/handlers/ecs/entryextraction_llm/llm/utils.py +++ b/handlers/ecs/entryextraction_llm/llm/utils.py @@ -2,9 +2,7 @@ import re import psycopg2 -from llm.prompt_utils import ( - MULTI_DESCRIPTION -) +from llm.prompt_utils import MULTI_DESCRIPTION af_widget_by_id = ( "SELECT * from analysis_framework_widget ll WHERE ll.analysis_framework_id={}" @@ -67,19 +65,20 @@ def add_element(name, description, alias, main_class=None): _sanitize_keys_with_uniqueness(name.replace(" ", "_").lower()) if isinstance(name, str) else _sanitize_keys_with_uniqueness( - [el.replace(" ", "_").lower() for el in name][1] + [el.replace(" ", "_").lower() for el in name][-1] ) ): element } -def process_primary_tags(ex: list, order="columns", type_="2d", max_length: int = 50): +def get_tooltip(el): + if el.get("tooltip"): + return el.get("tooltip") + else: + return "" - def get_tooltip(el): - if el.get("tooltip"): - return el.get("tooltip") - else: - return "" + +def process_primary_tags(ex: list, order="columns", type_="2d", max_length: int = 50): properties = {} id_to_info = {} @@ -106,9 +105,7 @@ def get_tooltip(el): properties.update(prop) elif ( - type_ == "2d" and - f"sub{order.title()}" in c.keys() and - c.get(f"sub{order.title()}") + type_ == "2d" and f"sub{order.title()}" in c.keys() and c.get(f"sub{order.title()}") ): for cc in c[f"sub{order.title()}"]: @@ -138,6 +135,50 @@ def get_tooltip(el): return properties, id_to_info +def process_organigram_tags(ex: dict, properties={}, id_to_info={}, parents=None, i=0): + + if i == 0: + properties = {} + id_to_info = {} + ex = ex["options"]["children"] + i = 1 + + for el in ex: + if parents: + parents = [parents] if isinstance(parents, str) else parents + name = parents + [el["label"].strip()] + else: + name = el["label"].strip() + + description = get_tooltip(el) + alias = el["key"] + id_to_info.update({alias: {"label": el["label"], "order": el["order"]}}) + + properties.update(add_element(name=name, description=description, alias=alias)) + + if not el.get("children", []): + continue + else: + process_organigram_tags(el["children"], properties, id_to_info, name, i) + + return properties, id_to_info + + +def process_multiselect_tags(ex: dict): + + properties = {} + id_to_info = {} + for c in ex["options"]: + + name = c["label"].strip() + description = get_tooltip(c) + alias = c["key"] + id_to_info.update({c["key"]: {"label": c["label"], "order": c["order"]}}) + properties.update(add_element(name=name, description=description, alias=alias)) + + return properties, id_to_info + + def combine_properties( properties_row: dict, properties_columns: dict, From c0e154aff7423e4818e7ee46fbf1863f45213248 Mon Sep 17 00:00:00 2001 From: aembryonic Date: Thu, 19 Dec 2024 11:52:44 +0100 Subject: [PATCH 2/2] remove aws profile --- main.tf | 4 ++-- modules/server/ecs.tf | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/main.tf b/main.tf index 59d1d61..1671d3a 100644 --- a/main.tf +++ b/main.tf @@ -12,13 +12,13 @@ terraform { region = "us-east-1" # dynamodb_table = "terraform-lock-integration-db" encrypt = true - #profile = "nlp_tf" + # profile = "nlp_tf" } } provider "aws" { region = var.aws_region - #profile = var.aws_profile + # profile = var.aws_profile #shared_credentials_files = ["~/.aws/credentials"] } diff --git a/modules/server/ecs.tf b/modules/server/ecs.tf index cf6ecfb..88cfd64 100644 --- a/modules/server/ecs.tf +++ b/modules/server/ecs.tf @@ -8,10 +8,8 @@ locals { app_image_url = "${data.aws_caller_identity.current_user.account_id}.dkr.ecr.${var.aws_region}.amazonaws.com/${var.app_image}:latest" } -data "template_file" "config" { - template = file("./modules/server/templates/ecr_image/image.json") - - vars = { +locals { + rendered_config = templatefile("./modules/server/templates/ecr_image/image.json", { app_image = local.app_image_url app_port = var.app_port fargate_cpu = var.fargate_cpu @@ -113,7 +111,7 @@ data "template_file" "config" { reliability_model_version = var.reliability_model_version # OpenAI API key ssm_openai_api_key_arn = var.ssm_openai_api_key_arn - } + }) } resource "aws_ecs_task_definition" "task-def" { @@ -124,7 +122,7 @@ resource "aws_ecs_task_definition" "task-def" { requires_compatibilities = ["FARGATE"] cpu = var.fargate_cpu memory = var.fargate_memory - container_definitions = data.template_file.config.rendered + container_definitions = local.rendered_config } resource "aws_ecs_service" "service" {