evaluate fiqa openai #1211

Merged
merged 1 commit into from
Dec 4, 2023
319 changes: 319 additions & 0 deletions bootcamp/Evaluation/evaluate_fiqa_openai.ipynb
@@ -0,0 +1,319 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# ! python -m pip install openai beir ragas==0.0.17"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
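{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Optional sanity check (a minimal sketch, not part of the evaluation itself):\n",
"# ragas changed its API across releases, so confirming the pinned version is\n",
"# installed avoids confusing import errors later.\n",
"from importlib.metadata import version\n",
"\n",
"print(version('ragas'))  # expected: 0.0.17"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},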
{
"cell_type": "code",
"execution_count": 1,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1706\n"
]
}
],
"source": [
"import json\n",
"import pandas as pd\n",
"import os\n",
"from tqdm import tqdm\n",
"from datasets import Dataset\n",
"from beir import util\n",
"\n",
"\n",
"def prepare_fiqa_without_answer(knowledge_path):\n",
" dataset_name = \"fiqa\"\n",
"\n",
" if not os.path.exists(os.path.join(knowledge_path, f'{dataset_name}.zip')):\n",
" url = (\n",
" \"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip\".format(\n",
" dataset_name\n",
" )\n",
" )\n",
" util.download_and_unzip(url, knowledge_path)\n",
"\n",
" data_path = os.path.join(knowledge_path, 'fiqa')\n",
" with open(os.path.join(data_path, \"corpus.jsonl\")) as f:\n",
" cs = [pd.Series(json.loads(l)) for l in f.readlines()]\n",
"\n",
" corpus_df = pd.DataFrame(cs)\n",
"\n",
" corpus_df = corpus_df.rename(columns={\"_id\": \"corpus-id\", \"text\": \"ground_truth\"})\n",
" corpus_df = corpus_df.drop(columns=[\"title\", \"metadata\"])\n",
" corpus_df[\"corpus-id\"] = corpus_df[\"corpus-id\"].astype(int)\n",
" corpus_df.head()\n",
"\n",
" with open(os.path.join(data_path, \"queries.jsonl\")) as f:\n",
" qs = [pd.Series(json.loads(l)) for l in f.readlines()]\n",
"\n",
" queries_df = pd.DataFrame(qs)\n",
" queries_df = queries_df.rename(columns={\"_id\": \"query-id\", \"text\": \"question\"})\n",
" queries_df = queries_df.drop(columns=[\"metadata\"])\n",
" queries_df[\"query-id\"] = queries_df[\"query-id\"].astype(int)\n",
" queries_df.head()\n",
"\n",
" splits = [\"dev\", \"test\", \"train\"]\n",
" split_df = {}\n",
" for s in splits:\n",
" split_df[s] = pd.read_csv(os.path.join(data_path, f\"qrels/{s}.tsv\"), sep=\"\\t\").drop(\n",
" columns=[\"score\"]\n",
" )\n",
"\n",
" final_split_df = {}\n",
" for split in split_df:\n",
" df = queries_df.merge(split_df[split], on=\"query-id\")\n",
" df = df.merge(corpus_df, on=\"corpus-id\")\n",
" df = df.drop(columns=[\"corpus-id\"])\n",
" grouped = df.groupby(\"query-id\").apply(\n",
" lambda x: pd.Series(\n",
" {\n",
" \"question\": x[\"question\"].sample().values[0],\n",
" \"ground_truths\": x[\"ground_truth\"].tolist(),\n",
" }\n",
" )\n",
" )\n",
"\n",
" grouped = grouped.reset_index()\n",
" grouped = grouped.drop(columns=\"query-id\")\n",
" final_split_df[split] = grouped\n",
"\n",
" return final_split_df\n",
"\n",
"\n",
"knowledge_datas_path = './knowledge_datas'\n",
"fiqa_path = os.path.join(knowledge_datas_path, 'fiqa_doc.txt')\n",
"\n",
"if not os.path.exists(knowledge_datas_path):\n",
" os.mkdir(knowledge_datas_path)\n",
"contexts_list = []\n",
"answer_list = []\n",
"\n",
"final_split_df = prepare_fiqa_without_answer(knowledge_datas_path)\n",
"\n",
"docs = []\n",
"\n",
"split = 'test'\n",
"for ds in final_split_df[split][\"ground_truths\"]:\n",
" docs.extend([d for d in ds])\n",
"print(len(docs))\n",
"\n",
"docs_str = '\\n'.join(docs)\n",
"with open(fiqa_path, 'w') as f:\n",
" f.write(docs_str)\n",
"\n",
"split = 'test'\n",
"question_list = final_split_df[split][\"question\"].to_list()\n",
"ground_truth_list = final_split_df[split][\"ground_truths\"].to_list()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
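{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# A quick look at the prepared test split, using only variables defined\n",
"# above: each question maps to the list of ground-truth passages linked to\n",
"# it by the FiQA qrels file.\n",
"print(f'questions: {len(question_list)}')\n",
"print(f'example question: {question_list[0]}')\n",
"print(f'ground-truth passages for it: {len(ground_truth_list[0])}')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},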
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"import time\n",
"from openai import OpenAI\n",
"\n",
"client = OpenAI()\n",
"\n",
"# Set OPENAI_API_KEY in your environment value\n",
"client.api_key = os.getenv('OPENAI_API_KEY')\n",
"\n",
"\n",
"class OpenAITimeoutException(Exception):\n",
" pass\n",
"\n",
"\n",
"def get_content_from_retrieved_message(message):\n",
" # Extract the message content\n",
" message_content = message.content[0].text\n",
" annotations = message_content.annotations\n",
" contexts = []\n",
" for annotation in annotations:\n",
" message_content.value = message_content.value.replace(annotation.text, f'')\n",
" if (file_citation := getattr(annotation, 'file_citation', None)):\n",
" contexts.append(file_citation.quote)\n",
" if len(contexts) == 0:\n",
" contexts = ['empty context.']\n",
" return message_content.value, contexts\n",
"\n",
"\n",
"def try_get_answer_contexts(assistant_id, question, timeout_seconds=120):\n",
" thread = client.beta.threads.create(\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": question,\n",
" }\n",
" ]\n",
" )\n",
" thread_id = thread.id\n",
" run = client.beta.threads.runs.create(\n",
" thread_id=thread_id,\n",
" assistant_id=assistant_id,\n",
" )\n",
" start_time = time.time()\n",
" while True:\n",
" elapsed_time = time.time() - start_time\n",
" if elapsed_time > timeout_seconds:\n",
" raise Exception(\"OpenAI retrieving answer Timeout!\")\n",
"\n",
" run = client.beta.threads.runs.retrieve(\n",
" thread_id=thread_id,\n",
" run_id=run.id\n",
" )\n",
" if run.status == 'completed':\n",
" break\n",
" messages = client.beta.threads.messages.list(\n",
" thread_id=thread_id\n",
" )\n",
" assert len(messages.data) > 1\n",
" res, contexts = get_content_from_retrieved_message(messages.data[0])\n",
" response = client.beta.threads.delete(thread_id)\n",
" assert response.deleted is True\n",
" return contexts, res\n",
"\n",
"\n",
"def get_answer_contexts_from_assistant(question, assistant_id, timeout_seconds=120, retry_num=6):\n",
" res = 'failed. please retry.'\n",
" contexts = ['failed. please retry.']\n",
" try:\n",
" for _ in range(retry_num):\n",
" try:\n",
" contexts, res = try_get_answer_contexts(assistant_id, question, timeout_seconds)\n",
" break\n",
" except OpenAITimeoutException as e:\n",
" print('OpenAI retrieving answer Timeout, retry...')\n",
" continue\n",
" except Exception as e:\n",
" print(e)\n",
" return res, contexts"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
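{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Single-question smoke test for the helpers above, commented out like the\n",
"# pip cell. The assistant id here is a hypothetical placeholder; in practice\n",
"# you would pass assistant.id from the next cell.\n",
"# answer, contexts = get_answer_contexts_from_assistant(\n",
"#     question_list[0], 'asst_...', timeout_seconds=60)\n",
"# print(answer)\n",
"# print(contexts)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},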
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/648 [03:45<?, ?it/s]\n",
"\n",
"KeyboardInterrupt\n",
"\n"
]
}
],
"source": [
"file = client.files.create(\n",
" file=open(fiqa_path, \"rb\"),\n",
" purpose='assistants'\n",
")\n",
"\n",
"# Add the file to the assistant\n",
"assistant = client.beta.assistants.create(\n",
" instructions=\"You are a customer support chatbot. You must use your retrieval tool to retrieve relevant knowledge to best respond to customer queries.\",\n",
" model=\"gpt-4-1106-preview\",\n",
" tools=[{\"type\": \"retrieval\"}],\n",
" file_ids=[file.id]\n",
")\n",
"\n",
"for question in tqdm(question_list):\n",
" answer, contexts = get_answer_contexts_from_assistant(question, assistant.id)\n",
" # print(f'answer = {answer}')\n",
" # print(f'contexts = {contexts}')\n",
" # print('=' * 80)\n",
" answer_list.append(answer)\n",
" contexts_list.append(contexts)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
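{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# The loop above is long-running and may be interrupted (as the recorded\n",
"# output shows), so persisting whatever was collected is a cheap safeguard.\n",
"# A minimal sketch; the file name is an assumption, not part of the\n",
"# original notebook.\n",
"results_path = os.path.join(knowledge_datas_path, 'fiqa_openai_results.json')\n",
"with open(results_path, 'w') as f:\n",
"    json.dump({'question': question_list[:len(answer_list)],\n",
"               'answer': answer_list,\n",
"               'contexts': contexts_list}, f)\n",
"print(f'saved {len(answer_list)} results to {results_path}')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},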
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from ragas import evaluate\n",
"from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision, answer_similarity\n",
"\n",
"ds = Dataset.from_dict({\"question\": question_list,\n",
" \"contexts\": contexts_list,\n",
" \"answer\": answer_list,\n",
" \"ground_truths\": ground_truth_list})\n",
"\n",
"result = evaluate(\n",
" ds,\n",
" metrics=[\n",
" context_precision,\n",
" # context_recall,\n",
" # faithfulness,\n",
" # answer_relevancy,\n",
" # answer_similarity,\n",
" # answer_correctness,\n",
" ],\n",
"\n",
")\n",
"print(result)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
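,
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Per-sample scores are often more informative than the aggregate above.\n",
"# A minimal sketch, commented out because it assumes this ragas release\n",
"# exposes Result.to_pandas(); if it does not, inspect `result` directly.\n",
"# df = result.to_pandas()\n",
"# df.head()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}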
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}