Skip to content

Commit

Permalink
Merge pull request #88 from lamini-ai/update-data-pipeline
Browse files Browse the repository at this point in the history
Update pipeline to use the newest APIs
  • Loading branch information
yx-lamini authored Jul 15, 2024
2 parents c301689 + 07480c1 commit 70accea
Showing 1 changed file with 70 additions and 153 deletions.
223 changes: 70 additions & 153 deletions 05_data_pipeline/generate_data.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from lamini.generation.generation_node import GenerationNode
from lamini.generation.generation_pipeline import GenerationPipeline
from lamini.generation.base_prompt_object import PromptObject

import asyncio
import itertools
import jsonlines
import logging

import itertools
import asyncio
from tqdm import tqdm

from typing import Union, Iterator, AsyncIterator

import logging
from lamini.generation.generation_node import GenerationNode
from lamini.generation.generation_pipeline import GenerationPipeline
from lamini.generation.base_prompt_object import PromptObject

logger = logging.getLogger(__name__)

Expand All @@ -19,24 +16,6 @@
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


async def main():
    """Drive the end-to-end run: load transcripts, generate Q&A, save results."""
    calls = load_earnings_calls()
    pipeline = QuestionAnswerPipeline()
    await save_answers(pipeline.call(calls))


async def load_earnings_calls():
    """Async-generate a PromptObject for the first transcript in the test set.

    Only one record is read (smoke-test sizing); the prompt text is filled
    in later by the pipeline's preprocessing.
    """
    path = "/app/lamini-earnings-sdk/data/test_set_transcripts.jsonl"

    with jsonlines.open(path) as reader:
        first_only = itertools.islice(reader, 1)
        for line in first_only:
            logger.info(f"Loaded earnings call for {line['ticker']}")
            yield PromptObject(prompt="", data=line)


class QuestionAnswerPipeline(GenerationPipeline):
    """Two-stage pipeline: generate three questions per transcript chunk,
    then generate an answer for each question."""

    def __init__(self):
        super(QuestionAnswerPipeline, self).__init__()
        # NOTE(review): part of __init__ was collapsed in the diff view; the
        # question generator must be constructed here because forward() uses
        # self.question_generator — confirm against the full file.
        self.question_generator = QuestionGenerator()
        self.answer_generator = AnswerGenerator()

    def forward(self, x):
        # Bug fix: the merged diff left BOTH the old bare call and the new
        # call in place, which would run question generation twice. Keep only
        # the new API form, where the structured output schema is supplied at
        # call time rather than inside the node.
        x = self.question_generator(x, output_type={
            "question_1": "str",
            "question_2": "str",
            "question_3": "str",
        })
        x = self.answer_generator(x)
        return x

def get_company_info(chunk):
    """Format the chunk's company metadata as a newline-terminated header."""
    fields = [
        ("Company", chunk.data["exchange"]),
        ("Ticker", chunk.data["ticker"]),
        ("Date", chunk.data["date"]),
        ("Quarter", chunk.data["q"]),
    ]
    return "".join(f"{label}: {value}\n" for label, value in fields)

class QuestionGenerator(GenerationNode):
    """Generation node that asks three precise numeric questions about one
    earnings-call transcript chunk."""

    def __init__(self, model_name="meta-llama/Meta-Llama-3-8B-Instruct", max_new_tokens=150):
        # Generalization: model and token budget were hard-coded; the defaults
        # preserve the original behavior while letting callers swap models.
        super(QuestionGenerator, self).__init__(
            model_name=model_name, max_new_tokens=max_new_tokens
        )

def generate(
self,
prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
*args,
**kwargs,
):
prompt = self.add_template(prompt)

results = super(QuestionGenerator, self).generate(
prompt,
output_type={
"question_1": "string",
"question_2": "string",
"question_3": "string",
},
*args,
**kwargs,
)
return results

async def process_results(self, results):
async for result in results:
logger.debug(f"Generated question for {result}")
if result is None:
continue

if "question_1" not in result.response:
continue

if "question_2" not in result.response:
continue

if "question_3" not in result.response:
continue

questions = (
result.response["question_1"],
result.response["question_2"],
result.response["question_3"],
)
for question in questions:
result = PromptObject(prompt=question, data=result.data.copy())
yield result

async def add_template(self, prompts):
async for prompt in prompts:
chunks = chunk_prompt(prompt)
for chunk in chunks:
chunk.prompt = self.make_prompt(chunk)
logger.info(
f"Generating question for {chunk.data['ticker']}, {chunk.data['q']}"
)
yield chunk

def make_prompt(self, chunk):
prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"

prompt += (
"You are a financial analyst with extensive experience at Goldman Sachs."

def preprocess(self, obj: PromptObject):
obj.prompt = self.make_prompt(obj)
logger.info(f"Generating question for {obj.data['ticker']}, {obj.data['q']}")

def postprocess(self, obj: PromptObject):
response = obj.response
questions = [
response["question_1"],
response["question_2"],
response["question_3"],
]
for question in questions:
ans = PromptObject(prompt=question, data=obj.data.copy())
yield ans


def make_prompt(self, obj):
prompt = (
"<s>[INSTR]You are a financial analyst with extensive experience at Goldman Sachs."
)
prompt += "You are reading the earnings call transcript for the following company:\n\n"
prompt += "====================\n\n"
prompt += get_company_info(chunk) + "\n"
prompt += get_company_info(obj) + "\n"
prompt += "====================\n\n"
prompt += (
"You are reading the following section of the earnings call transcript:\n\n"
)
prompt += "====================\n\n"
prompt += chunk.data["transcript"]
prompt += obj.data["transcript"]
prompt += "====================\n\n"
prompt += "Consider the numbers in the transscript. "
prompt += "Consider the numbers in the transcript. "
prompt += "Ask three questions about the numbers in the transcript that require precise answers. "
prompt += "Only ask questions that can be answered using the transcript."
prompt += "<|eot_id|>"
prompt += "<|start_header_id|>assistant<|end_header_id|>"
prompt +="[/INSTR]"

return prompt


def chunk_prompt(prompt, chunk_size=4096, chunk_step=2048):
    """Split a prompt's transcript into overlapping chunks.

    Args:
        prompt: PromptObject whose data["transcript"] is the full transcript.
        chunk_size: characters per chunk (default preserves the original
            hard-coded 4096).
        chunk_step: stride between chunk starts; a step smaller than
            chunk_size yields overlapping chunks (default 2048 = 50% overlap).

    Yields:
        One PromptObject per chunk, with data copied and data["transcript"]
        replaced by the chunk text.
    """
    transcript = prompt.data["transcript"]

    for start in range(0, len(transcript), chunk_step):
        chunked_data = prompt.data.copy()
        chunked_data["transcript"] = transcript[start : start + chunk_size]
        yield PromptObject(prompt=prompt.prompt, data=chunked_data)


def get_company_info(chunk):
    """Return a newline-terminated summary of the chunk's company metadata."""
    data = chunk.data
    return (
        f"Company: {data['exchange']}\n"
        f"Ticker: {data['ticker']}\n"
        f"Date: {data['date']}\n"
        f"Quarter: {data['q']}\n"
    )


class AnswerGenerator(GenerationNode):
    """Generation node that answers one question about a transcript chunk."""

    def __init__(self, model_name="meta-llama/Meta-Llama-3-8B-Instruct", max_new_tokens=150):
        # Generalization: defaults preserve the original hard-coded model and
        # token budget while allowing callers to configure them.
        super(AnswerGenerator, self).__init__(
            model_name=model_name, max_new_tokens=max_new_tokens
        )

def generate(
self,
prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
*args,
**kwargs,
):
prompt = self.add_template(prompt)
results = super(AnswerGenerator, self).generate(prompt, output_type={"answer" : "str"}, *args, **kwargs)
return results

async def process_results(self, results):
async for result in results:
logger.info(f"Generated answer for {result}")
if result is None:
continue
yield result

async def add_template(self, prompts):
async for prompt in prompts:
logger.info(
f"Generating answer for {prompt.data['ticker']}, {prompt.data['q']}, {prompt.prompt}"
)
prompt.data["question"] = prompt.prompt
prompt.prompt = self.make_prompt(prompt)
yield prompt

def make_prompt(self, chunk):
prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
    def postprocess(self, obj: PromptObject):
        # Log-only hook: no transformation of the result here.
        # NOTE(review): presumably the framework passes obj through unchanged
        # when postprocess returns None — confirm against GenerationNode docs.
        logger.info(f"Generated answer for {obj}")

prompt += (
"You are a financial analyst with extensive experience at Goldman Sachs."
    def preprocess(self, obj: PromptObject):
        # Save the generated question before overwriting obj.prompt; order
        # matters because make_prompt embeds obj.prompt (the question) into
        # the answer prompt it builds.
        obj.data["question"] = obj.prompt
        obj.prompt = self.make_prompt(obj)

def make_prompt(self, obj: PromptObject):
prompt = (
"<s>[INSTR] You are a financial analyst with extensive experience at Goldman Sachs."
)
prompt += "You are reading the earnings call transcript for the following company:\n\n"
prompt += "====================\n\n"
prompt += get_company_info(chunk)
prompt += get_company_info(obj)
prompt += "====================\n\n"
prompt += (
"You are reading the following section of the earnings call transcript:\n\n"
)
prompt += "====================\n\n"
prompt += chunk.data["transcript"] + "\n"
prompt += obj.data["transcript"] + "\n"
prompt += "====================\n\n"
prompt += "Consider the numbers in the transcript. "
prompt += "If the answer to the question cannot be found in the transcript, reply that you do not know. "
prompt += "Answer the following questions about the numbers in the transcript. "
prompt += chunk.prompt
prompt += "<|eot_id|>"
prompt += "<|start_header_id|>assistant<|end_header_id|>"
prompt += obj.prompt
prompt += "[/INSTR]"

return prompt


async def load_earnings_calls(limit=1):
    """Async-generate one PromptObject per transcript record.

    Args:
        limit: maximum number of records to read. The default of 1 preserves
            the original hard-coded smoke-test behavior; pass None to stream
            every record in the file.

    Yields:
        PromptObject with an empty prompt (filled in by preprocessing) and
        the raw JSONL record as data.
    """
    path = "/app/lamini-earnings-sdk/data/test_set_transcripts.jsonl"

    with jsonlines.open(path) as reader:
        for line in itertools.islice(reader, limit):
            logger.info(f"Loaded earnings call for {line['ticker']}")
            yield PromptObject(prompt="", data=line)

async def save_answers(answers):
path = "/app/lamini-earnings-sdk/data/results/generated_q_a.jsonl"
Expand All @@ -228,10 +139,16 @@ async def save_answers(answers):
"transcript": answer.data["transcript"],
"prompt": answer.prompt,
"question": answer.data["question"],
"answer": answer.response["answer"],
"answer": answer.response["output"],
}
writer.write(answer)
pbar.update()


asyncio.run(main())
async def run_pipeline():
    """Top-level driver: load transcripts, run the Q&A pipeline, save answers."""
    earnings_calls = load_earnings_calls()
    answers = QuestionAnswerPipeline().call(earnings_calls)
    await save_answers(answers)


if __name__ == "__main__":
    # Bug fix: the unguarded asyncio.run() kicked off a full generation run
    # as a side effect of merely importing this module. Guard the entry
    # point so the module is importable (e.g. for tests) without running.
    asyncio.run(run_pipeline())

0 comments on commit 70accea

Please sign in to comment.