Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding new widgets to prod #45

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.terraform.lock.hcl
.terraform/*
.vscode/*
90 changes: 81 additions & 9 deletions handlers/ecs/entryextraction_llm/llm/model_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@

import redis
from box import Box

# from langchain_aws import ChatBedrockConverse
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI
from llm.prompt_utils import (
DESCRIPTION_PILLARS,
DESCRIPTION_ORGANIGRAM,
MAIN_PROMT,
INPUT_PASSAGE,
DESCRIPTION_MATRIX
DESCRIPTION_MATRIX,
)
from llm.utils import (
af_widget_by_id,
combine_properties,
connect_db,
process_primary_tags,
process_multiselect_tags,
process_organigram_tags,
)
from llm.utils import (af_widget_by_id, combine_properties, connect_db,
process_primary_tags)
from pydantic import BaseModel, Field, create_model

# from utils.db import connect_db
Expand Down Expand Up @@ -72,10 +80,24 @@ def __process_widget_2d(self, key: str, properties: dict, version: int):
)

def __process_widget_multiselect(self, key: str, properties: dict, version: int):
raise NotImplementedError

_, rows = process_multiselect_tags(properties)
self.__update(
key=key,
value=rows,
version=version,
widget_type=self.WidgetTypes.multiselectWidget,
)

def __process_widget_organigram(self, key: str, properties: dict, version: int):
raise NotImplementedError

_, rows = process_organigram_tags(properties)
self.__update(
key=key,
value=rows,
version=version,
widget_type=self.WidgetTypes.organigramWidget,
)

def __process_widget_daterange(self, key: str, properties: dict, version: int):
raise NotImplementedError
Expand Down Expand Up @@ -213,10 +235,50 @@ def __process_widget_2d(
)

def __process_widget_multiselect(self, key: str, properties: dict, class_name: str):
raise NotImplementedError

properties, _ = process_multiselect_tags(properties)
multiselect = create_model(
class_name,
__base__=BaseModel,
__doc__=DESCRIPTION_PILLARS.format(class_name.lower()),
**{
k: (bool, Field(title=k, description=v["description"], default=False))
for k, v in properties.items()
},
)
self.__update(
key,
self.Schema(
type=self.mappings_instance.WidgetTypes.multiselectWidget,
prompt=MAIN_PROMT.format(class_name.lower()) + INPUT_PASSAGE,
model=self.__foundation_model_id_selection(schema=multiselect),
properties=properties,
pyd_class=multiselect,
),
)

def __process_widget_organigram(self, key: str, properties: dict, class_name: str):
raise NotImplementedError

properties, _ = process_organigram_tags(properties)
organigram = create_model(
class_name,
__base__=BaseModel,
__doc__=DESCRIPTION_ORGANIGRAM.format(class_name.lower()),
**{
k: (bool, Field(title=k, description=v["description"], default=False))
for k, v in properties.items()
},
)
self.__update(
key,
self.Schema(
type=self.mappings_instance.WidgetTypes.organigramWidget,
prompt=MAIN_PROMT.format(class_name.lower()) + INPUT_PASSAGE,
model=self.__foundation_model_id_selection(schema=organigram),
properties=properties,
pyd_class=organigram,
),
)

def __process_widget_daterange(self, key: str, properties: dict, class_name: str):
raise NotImplementedError
Expand Down Expand Up @@ -252,6 +314,8 @@ class LLMTagsPrediction:
AVAILABLE_WIDGETS: list = [
"matrix2dWidget",
"matrix1dWidget",
"multiselectWidget",
"organigramWidget",
] # it'll be extended to all widget types
AVAILABLE_FOUNDATION_MODELS: list = ["bedrock", "openai"]

Expand Down Expand Up @@ -406,10 +470,18 @@ def __convert_result(self, prediction: dict):
results[k][pillar][subpillar][sector].append(subsector)

elif type_ == self.widgets.mappings_instance.WidgetTypes.multiselectWidget:
raise NotImplementedError

if k not in results.keys():
results.update({k: []})
for c in v:
results[k].append(schema[c]["alias"])

elif type_ == self.widgets.mappings_instance.WidgetTypes.organigramWidget:
raise NotImplementedError

if k not in results.keys():
results.update({k: []})
for c in v:
results[k].append(schema[c]["alias"])

elif type_ == self.widgets.mappings_instance.WidgetTypes.dateRangeWidget:
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions handlers/ecs/entryextraction_llm/llm/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

DESCRIPTION_PILLARS = """Those are the {} labels that you will use to classify text. Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage is not enough clear to be associated to some label, also none of them can be selected. Be sure to not over-classify the passage, but just select what you're sure about.""" # noqa
DESCRIPTION_MATRIX = """Those are the {} labels that you will use to classify text. Each label is a combination of a column label (Column category) and a row label (Row category). Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage can't clearly be associated to some label, none of them must be selected. Be sure to not over-classify the passage, but just select what you're sure about.""" # noqa
DESCRIPTION_ORGANIGRAM = """Those are the {} labels that you will use to classify text. It's an organigram style set of labels, so classify only to the specified level of granularity of the passage. Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage is not enough clear to be associated to some label, also none of them can be selected. Be sure to not over-classify the passage, but just select what you're sure about. Don't infer possible labels, just select in respect of the passage content""" # noqa

MULTI_DESCRIPTION = """Column category: {}. Row category: {}."""

Expand Down
67 changes: 54 additions & 13 deletions handlers/ecs/entryextraction_llm/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import re

import psycopg2
from llm.prompt_utils import (
MULTI_DESCRIPTION
)
from llm.prompt_utils import MULTI_DESCRIPTION

af_widget_by_id = (
"SELECT * from analysis_framework_widget ll WHERE ll.analysis_framework_id={}"
Expand Down Expand Up @@ -67,19 +65,20 @@ def add_element(name, description, alias, main_class=None):
_sanitize_keys_with_uniqueness(name.replace(" ", "_").lower())
if isinstance(name, str)
else _sanitize_keys_with_uniqueness(
[el.replace(" ", "_").lower() for el in name][1]
[el.replace(" ", "_").lower() for el in name][-1]
)
): element
}


def process_primary_tags(ex: list, order="columns", type_="2d", max_length: int = 50):
def get_tooltip(el):
if el.get("tooltip"):
return el.get("tooltip")
else:
return ""

def get_tooltip(el):
if el.get("tooltip"):
return el.get("tooltip")
else:
return ""

def process_primary_tags(ex: list, order="columns", type_="2d", max_length: int = 50):

properties = {}
id_to_info = {}
Expand All @@ -106,9 +105,7 @@ def get_tooltip(el):
properties.update(prop)

elif (
type_ == "2d" and
f"sub{order.title()}" in c.keys() and
c.get(f"sub{order.title()}")
type_ == "2d" and f"sub{order.title()}" in c.keys() and c.get(f"sub{order.title()}")
):
for cc in c[f"sub{order.title()}"]:

Expand Down Expand Up @@ -138,6 +135,50 @@ def get_tooltip(el):
return properties, id_to_info


def process_organigram_tags(ex: dict, properties={}, id_to_info={}, parents=None, i=0):

if i == 0:
properties = {}
id_to_info = {}
ex = ex["options"]["children"]
i = 1

for el in ex:
if parents:
parents = [parents] if isinstance(parents, str) else parents
name = parents + [el["label"].strip()]
else:
name = el["label"].strip()

description = get_tooltip(el)
alias = el["key"]
id_to_info.update({alias: {"label": el["label"], "order": el["order"]}})

properties.update(add_element(name=name, description=description, alias=alias))

if not el.get("children", []):
continue
else:
process_organigram_tags(el["children"], properties, id_to_info, name, i)

return properties, id_to_info


def process_multiselect_tags(ex: dict):

properties = {}
id_to_info = {}
for c in ex["options"]:

name = c["label"].strip()
description = get_tooltip(c)
alias = c["key"]
id_to_info.update({c["key"]: {"label": c["label"], "order": c["order"]}})
properties.update(add_element(name=name, description=description, alias=alias))

return properties, id_to_info


def combine_properties(
properties_row: dict,
properties_columns: dict,
Expand Down
4 changes: 2 additions & 2 deletions main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ terraform {
region = "us-east-1"
# dynamodb_table = "terraform-lock-integration-db"
encrypt = true
#profile = "nlp_tf"
# profile = "nlp_tf"
}
}

provider "aws" {
region = var.aws_region
#profile = var.aws_profile
# profile = var.aws_profile
#shared_credentials_files = ["~/.aws/credentials"]
}

Expand Down
10 changes: 4 additions & 6 deletions modules/server/ecs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ locals {
app_image_url = "${data.aws_caller_identity.current_user.account_id}.dkr.ecr.${var.aws_region}.amazonaws.com/${var.app_image}:latest"
}

data "template_file" "config" {
template = file("./modules/server/templates/ecr_image/image.json")

vars = {
locals {
rendered_config = templatefile("./modules/server/templates/ecr_image/image.json", {
app_image = local.app_image_url
app_port = var.app_port
fargate_cpu = var.fargate_cpu
Expand Down Expand Up @@ -113,7 +111,7 @@ data "template_file" "config" {
reliability_model_version = var.reliability_model_version
# OpenAI API key
ssm_openai_api_key_arn = var.ssm_openai_api_key_arn
}
})
}

resource "aws_ecs_task_definition" "task-def" {
Expand All @@ -124,7 +122,7 @@ resource "aws_ecs_task_definition" "task-def" {
requires_compatibilities = ["FARGATE"]
cpu = var.fargate_cpu
memory = var.fargate_memory
container_definitions = data.template_file.config.rendered
container_definitions = local.rendered_config
}

resource "aws_ecs_service" "service" {
Expand Down
Loading