Skip to content

Commit

Permalink
adding new widgets
Browse files Browse the repository at this point in the history
  • Loading branch information
aembryonic committed Dec 19, 2024
1 parent 446d584 commit 0a2d5fe
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 100 deletions.
225 changes: 164 additions & 61 deletions nlp_scripts/model_prediction/llm/model_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@

# Prompt templates / descriptions and tag-processing helpers for the LLM
# widget-prediction pipeline.
from nlp_scripts.model_prediction.llm.prompt_utils import (
    DESCRIPTION_PILLARS,
    DESCRIPTION_ORGANIGRAM,
    MAIN_PROMT,
    INPUT_PASSAGE,
    DESCRIPTION_MATRIX,
)
from nlp_scripts.model_prediction.llm.utils import (
    process_primary_tags,
    combine_properties,
    process_multiselect_tags,
    process_organigram_tags,
)
from utils.db import connect_db
from core_server.env import env
from core.tasks.queries import af_widget_by_id
from core_server.env import env
from core.tasks.queries import af_widget_by_id
Expand All @@ -30,26 +36,27 @@ def __init__(self, selected_widgets: list):
self.mappings = {}
self.selected_widgets = selected_widgets
self.WidgetTypes = make_dataclass(
'WidgetTypes',
[(widget, str, field(default=widget))for widget in list(set([element.widget_id
for element in self.selected_widgets]))]
"WidgetTypes",
[
(widget, str, field(default=widget))
for widget in list(
set([element.widget_id for element in self.selected_widgets])
)
],
)
self.create_mappings()

def __update(self, key: str, value: dict, version: int, widget_type: str):
    """Register *value* under *key* in the mappings, stamping it with the
    framework version and the widget type it came from."""
    self.mappings.update({key: value})
    self.mappings[key].update({"version": version, "widget_id": widget_type})

def __process_widget_1d(self, key: str, properties: dict, version: int):
    """Map a matrix1d widget: extract its row tags and register them."""
    _, rows = process_primary_tags(properties, order="rows", type_="1d")
    self.__update(
        key=key,
        value=rows,
        version=version,
        widget_type=self.WidgetTypes.matrix1dWidget,
    )

def __process_widget_2d(self, key: str, properties: dict, version: int):
Expand All @@ -61,14 +68,28 @@ def __process_widget_2d(self, key: str, properties: dict, version: int):
key=key,
value=rows,
version=version,
widget_type=self.WidgetTypes.matrix2dWidget
widget_type=self.WidgetTypes.matrix2dWidget,
)

def __process_widget_multiselect(self, key: str, properties: dict, version: int):
    """Map a multiselect widget: extract its option tags and register them.

    The stale ``raise NotImplementedError`` from the pre-implementation stub
    is removed — leaving it would make the body below unreachable.
    """
    _, rows = process_multiselect_tags(properties)
    self.__update(
        key=key,
        value=rows,
        version=version,
        widget_type=self.WidgetTypes.multiselectWidget,
    )

def __process_widget_organigram(self, key: str, properties: dict, version: int):
    """Map an organigram widget: extract its node tags and register them.

    The stale ``raise NotImplementedError`` from the pre-implementation stub
    is removed — leaving it would make the body below unreachable.
    """
    _, rows = process_organigram_tags(properties)
    self.__update(
        key=key,
        value=rows,
        version=version,
        widget_type=self.WidgetTypes.organigramWidget,
    )

def __process_widget_daterange(self, key: str, properties: dict, version: int):
    # Date-range widgets are not mapped yet; fail loudly so unsupported
    # widgets surface instead of being silently dropped.
    raise NotImplementedError
Expand Down Expand Up @@ -111,15 +132,15 @@ def __init__(self, selected_widgets: list, model_family: str = "openai"):
self.create_schemas()

def __foundation_model_id_selection(
self,
schema: BaseModel = None,
ln_threshold: int = 30
self, schema: BaseModel = None, ln_threshold: int = 30
):
length = len(schema.schema()["properties"].keys())

if self.model_family == "bedrock":
model_id_main = env("BEDROCK_MAIN_MODEL")
model_id_small = env("BEDROCK_SMALL_MODEL") # haiku model overclassify a lot for some reason
model_id_small = env(
"BEDROCK_SMALL_MODEL"
) # haiku model overclassify a lot for some reason
return model_id_main if length <= ln_threshold else model_id_main
elif self.model_family == "openai":
model_id_main = env("OPENAI_MAIN_MODEL")
Expand All @@ -130,54 +151,109 @@ def __foundation_model_id_selection(
def __update(self, key: str, value: Schema):
    """Register *value* as the schema bundle for widget *key*."""
    # Direct item assignment — equivalent to dict.update with a one-entry dict.
    self.schemas[key] = value

def __process_widget_1d(
    self, key: str, properties: dict, class_name: str = "Pillars"
):
    """Build a boolean-field pydantic model for a matrix1d widget and
    register it, together with its prompt and chosen foundation model."""
    properties, _ = process_primary_tags(properties, order="rows", type_="1d")
    # dynamic pydantic class creation: one bool field per tag, default False
    pillars = create_model(
        class_name,
        __base__=BaseModel,
        __doc__=DESCRIPTION_PILLARS.format(class_name.lower()),
        **{
            k: (bool, Field(title=k, description=v["description"], default=False))
            for k, v in properties.items()
        },
    )
    self.__update(
        key,
        self.Schema(
            type=self.mappings_instance.WidgetTypes.matrix1dWidget,
            prompt=MAIN_PROMT.format(class_name) + INPUT_PASSAGE,
            model=self.__foundation_model_id_selection(schema=pillars),
            properties=properties,
            pyd_class=pillars,
        ),
    )

def __process_widget_2d(
    self, key: str, properties: dict, class_name: str = "Matrix"
):
    """Build a boolean-field pydantic model for a matrix2d widget (row x
    column label combinations) and register it with prompt and model id."""
    properties_row, _ = process_primary_tags(properties, order="rows")
    properties_columns, _ = process_primary_tags(properties, order="columns")
    properties = combine_properties(
        properties_columns=properties_columns,
        properties_row=properties_row,
        max_length=self.max_widget_length,
        reduce_on_length=True,
    )  # setting the description reduction
    # dynamic pydantic class creation: one bool field per combined tag
    matrix = create_model(
        class_name,
        __base__=BaseModel,
        __doc__=DESCRIPTION_MATRIX.format(class_name.lower()),
        **{
            k: (bool, Field(title=k, description=v["description"], default=False))
            for k, v in properties.items()
        },
    )

    self.__update(
        key,
        self.Schema(
            type=self.mappings_instance.WidgetTypes.matrix2dWidget,
            prompt=MAIN_PROMT.format(class_name) + INPUT_PASSAGE,
            model=self.__foundation_model_id_selection(schema=matrix),
            properties=properties,
            pyd_class=matrix,
        ),
    )

def __process_widget_multiselect(self, key: str, properties: dict, class_name: str):
    """Build a boolean-field pydantic model for a multiselect widget and
    register it.

    The stale ``raise NotImplementedError`` from the pre-implementation stub
    is removed — leaving it would make the body below unreachable.
    """
    properties, _ = process_multiselect_tags(properties)
    # dynamic pydantic class creation: one bool field per option, default False
    multiselect = create_model(
        class_name,
        __base__=BaseModel,
        __doc__=DESCRIPTION_PILLARS.format(class_name.lower()),
        **{
            k: (bool, Field(title=k, description=v["description"], default=False))
            for k, v in properties.items()
        },
    )
    self.__update(
        key,
        self.Schema(
            type=self.mappings_instance.WidgetTypes.multiselectWidget,
            prompt=MAIN_PROMT.format(class_name.lower()) + INPUT_PASSAGE,
            model=self.__foundation_model_id_selection(schema=multiselect),
            properties=properties,
            pyd_class=multiselect,
        ),
    )

def __process_widget_organigram(self, key: str, properties: dict, class_name: str):
    """Build a boolean-field pydantic model for an organigram widget and
    register it.

    The stale ``raise NotImplementedError`` from the pre-implementation stub
    is removed — leaving it would make the body below unreachable.
    """
    properties, _ = process_organigram_tags(properties)
    # dynamic pydantic class creation: one bool field per node, default False
    organigram = create_model(
        class_name,
        __base__=BaseModel,
        __doc__=DESCRIPTION_ORGANIGRAM.format(class_name.lower()),
        **{
            k: (bool, Field(title=k, description=v["description"], default=False))
            for k, v in properties.items()
        },
    )
    self.__update(
        key,
        self.Schema(
            type=self.mappings_instance.WidgetTypes.organigramWidget,
            prompt=MAIN_PROMT.format(class_name.lower()) + INPUT_PASSAGE,
            model=self.__foundation_model_id_selection(schema=organigram),
            properties=properties,
            pyd_class=organigram,
        ),
    )

def __process_widget_daterange(self, key: str, properties: dict, class_name: str):
    # Schema generation for date-range widgets is not implemented yet;
    # fail loudly rather than producing an empty/incorrect schema.
    raise NotImplementedError
Expand All @@ -204,14 +280,21 @@ def create_schemas(self):

class LLMTagsPrediction:

# Widget types the LLM tagging pipeline currently supports;
# it'll be extended to all widget types.
AVAILABLE_WIDGETS: list = [
    "matrix2dWidget",
    "matrix1dWidget",
    "multiselectWidget",
    "organigramWidget",
]
# Foundation-model backends the pipeline can route predictions through.
AVAILABLE_FOUNDATION_MODELS: list = ["bedrock", "openai"]

def __init__(self, analysis_framework_id: int, model_family: str = "openai"):
    """Load the widgets of *analysis_framework_id* and prepare prediction
    against the chosen foundation-model family ("openai" or "bedrock")."""
    self.af_id = analysis_framework_id
    self.model_family = model_family

    # NOTE(review): assert is stripped under `python -O`; a plain
    # `raise ValueError(...)` would be safer, but changing it would alter
    # the exception type callers see, so it is kept as-is.
    assert self.model_family in self.AVAILABLE_FOUNDATION_MODELS, ValueError(
        "Selected model family not implemented"
    )

    # self.cursor = self.__get_deep_db_connection().cursor
    self.selected_widgets = self.__get_framework_widgets()
Expand All @@ -221,9 +304,7 @@ def __get_deep_db_connection(self):
return connect_db()

def __get_elasticache(self, port: int = 6379):
    """Return a Redis client for the ElastiCache host named in REDIS_HOST.

    decode_responses=True so cached values come back as str, not bytes.
    """
    return redis.Redis(host=env("REDIS_HOST"), port=port, decode_responses=True)

def __get_framework_widgets(self, expire_time: int = 1200):
# let's get or save the af_id widget original data on elasticache for 20 minutes
Expand All @@ -237,14 +318,21 @@ def __get_framework_widgets(self, expire_time: int = 1200):
self.cursor.execute(af_widget_by_id.format(self.af_id))
fetch = self.cursor.fetchall()
if not fetch:
raise ValueError(f"Not possible to retrieve framework widgets: {self.af_id}")
raise ValueError(
f"Not possible to retrieve framework widgets: {self.af_id}"
)
else:
afw = [Box(dict(zip([c.name for c in self.cursor.description], row))) for row in fetch]
afw = [element for element in afw if element.widget_id in self.AVAILABLE_WIDGETS]
afw = [
Box(dict(zip([c.name for c in self.cursor.description], row)))
for row in fetch
]
afw = [
element
for element in afw
if element.widget_id in self.AVAILABLE_WIDGETS
]
self.redis.set(
name=f"af_id:{self.af_id}",
ex=expire_time,
value=json.dumps(afw)
name=f"af_id:{self.af_id}", ex=expire_time, value=json.dumps(afw)
)
return afw

Expand All @@ -261,13 +349,17 @@ def select_model_instance(model_name: str):

def create_chain(prompt: str, llm: str, pydantic_model: BaseModel):
    """Build a langchain pipeline: prompt template piped into the selected
    model constrained to emit *pydantic_model* via structured output."""
    tagging_prompt = ChatPromptTemplate.from_template(prompt)
    _llm = select_model_instance(model_name=llm).with_structured_output(
        pydantic_model
    )
    return tagging_prompt | _llm

# running the excerpt tagging in a parallel way across all the widgets of the framework
parallel_tasks = RunnableParallel(
**{k: create_chain(v.prompt, v.model, v.pyd_class)
for k, v in self.widgets.schemas.items()}
**{
k: create_chain(v.prompt, v.model, v.pyd_class)
for k, v in self.widgets.schemas.items()
}
)

results = parallel_tasks.invoke({"input": excerpt})
Expand Down Expand Up @@ -327,10 +419,21 @@ def __convert_result(self, prediction: dict):
if sector not in results[k][pillar][subpillar].keys():
results[k][pillar][subpillar].update({sector: []})
results[k][pillar][subpillar][sector].append(subsector)

elif type_ == self.widgets.mappings_instance.WidgetTypes.multiselectWidget:
raise NotImplementedError

if k not in results.keys():
results.update({k: []})
for c in v:
results[k].append(schema[c]["alias"])

elif type_ == self.widgets.mappings_instance.WidgetTypes.organigramWidget:
raise NotImplementedError

if k not in results.keys():
results.update({k: []})
for c in v:
results[k].append(schema[c]["alias"])

elif type_ == self.widgets.mappings_instance.WidgetTypes.dateRangeWidget:
raise NotImplementedError
elif type_ == self.widgets.mappings_instance.WidgetTypes.scaleWidget:
Expand Down
2 changes: 1 addition & 1 deletion nlp_scripts/model_prediction/llm/prompt_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@

# Base instruction sent to the LLM for every widget classification call;
# "{}" is filled with the generated pydantic class name (e.g. "Pillars").
# NOTE(review): the name keeps the historical typo "PROMT"; renaming it
# would break existing importers.
MAIN_PROMT = """
You are an humanitarian analyst. Extract the desired information from the following passage.
Only extract the properties mentioned in the {} function. It's a multi-label text classification task, try to assign a boolean value for each label considering the content of the passage
""" # noqa

# Docstring template for flat (1d pillar / multiselect) widget classes; "{}" is the class name.
DESCRIPTION_PILLARS = """Those are the {} labels that you will use to classify text. Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage is not enough clear to be associated to some label, also none of them can be selected. Be sure to not over-classify the passage, but just select what you're sure about.""" # noqa
# Docstring template for 2d matrix widget classes (row x column combinations).
DESCRIPTION_MATRIX = """Those are the {} labels that you will use to classify text. Each label is a combination of a column label (Column category) and a row label (Row category). Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage can't clearly be associated to some label, none of them must be selected. Be sure to not over-classify the passage, but just select what you're sure about.""" # noqa
# Docstring template for organigram (hierarchical) widget classes.
DESCRIPTION_ORGANIGRAM = """Those are the {} labels that you will use to classify text. It's an organigram style set of labels, so classify only to the specified level of granularity of the passage. Each element can be selected as True or False. It's a multi-label classification task, so not just one label can be inferred as True. If the passage is not enough clear to be associated to some label, also none of them can be selected. Be sure to not over-classify the passage, but just select what you're sure about. Don't infer possible labels, just select in respect of the passage content""" # noqa

# Per-field description for matrix2d labels: column category + row category.
MULTI_DESCRIPTION = """Column category: {}. Row category: {}."""

Expand Down
Loading

0 comments on commit 0a2d5fe

Please sign in to comment.