Skip to content

Commit

Permalink
Merge pull request #1211 from zc277584121/master
Browse files Browse the repository at this point in the history
evaluate fiqa openai
  • Loading branch information
jaelgu authored Dec 4, 2023
2 parents d7d45be + 1c9c92a commit 9b35ae6
Showing 1 changed file with 319 additions and 0 deletions.
319 changes: 319 additions & 0 deletions bootcamp/Evaluation/evaluate_fiqa_openai.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# ! python -m pip install openai beir ragas==0.0.17"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1706\n"
]
}
],
"source": [
"import json\n",
"import pandas as pd\n",
"import os\n",
"from tqdm import tqdm\n",
"from datasets import Dataset\n",
"from beir import util\n",
"\n",
"\n",
"def prepare_fiqa_without_answer(knowledge_path):\n",
" dataset_name = \"fiqa\"\n",
"\n",
" if not os.path.exists(os.path.join(knowledge_path, f'{dataset_name}.zip')):\n",
" url = (\n",
" \"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip\".format(\n",
" dataset_name\n",
" )\n",
" )\n",
" util.download_and_unzip(url, knowledge_path)\n",
"\n",
" data_path = os.path.join(knowledge_path, 'fiqa')\n",
" with open(os.path.join(data_path, \"corpus.jsonl\")) as f:\n",
" cs = [pd.Series(json.loads(l)) for l in f.readlines()]\n",
"\n",
" corpus_df = pd.DataFrame(cs)\n",
"\n",
" corpus_df = corpus_df.rename(columns={\"_id\": \"corpus-id\", \"text\": \"ground_truth\"})\n",
" corpus_df = corpus_df.drop(columns=[\"title\", \"metadata\"])\n",
" corpus_df[\"corpus-id\"] = corpus_df[\"corpus-id\"].astype(int)\n",
" corpus_df.head()\n",
"\n",
" with open(os.path.join(data_path, \"queries.jsonl\")) as f:\n",
" qs = [pd.Series(json.loads(l)) for l in f.readlines()]\n",
"\n",
" queries_df = pd.DataFrame(qs)\n",
" queries_df = queries_df.rename(columns={\"_id\": \"query-id\", \"text\": \"question\"})\n",
" queries_df = queries_df.drop(columns=[\"metadata\"])\n",
" queries_df[\"query-id\"] = queries_df[\"query-id\"].astype(int)\n",
" queries_df.head()\n",
"\n",
" splits = [\"dev\", \"test\", \"train\"]\n",
" split_df = {}\n",
" for s in splits:\n",
" split_df[s] = pd.read_csv(os.path.join(data_path, f\"qrels/{s}.tsv\"), sep=\"\\t\").drop(\n",
" columns=[\"score\"]\n",
" )\n",
"\n",
" final_split_df = {}\n",
" for split in split_df:\n",
" df = queries_df.merge(split_df[split], on=\"query-id\")\n",
" df = df.merge(corpus_df, on=\"corpus-id\")\n",
" df = df.drop(columns=[\"corpus-id\"])\n",
" grouped = df.groupby(\"query-id\").apply(\n",
" lambda x: pd.Series(\n",
" {\n",
" \"question\": x[\"question\"].sample().values[0],\n",
" \"ground_truths\": x[\"ground_truth\"].tolist(),\n",
" }\n",
" )\n",
" )\n",
"\n",
" grouped = grouped.reset_index()\n",
" grouped = grouped.drop(columns=\"query-id\")\n",
" final_split_df[split] = grouped\n",
"\n",
" return final_split_df\n",
"\n",
"\n",
"knowledge_datas_path = './knowledge_datas'\n",
"fiqa_path = os.path.join(knowledge_datas_path, 'fiqa_doc.txt')\n",
"\n",
"if not os.path.exists(knowledge_datas_path):\n",
" os.mkdir(knowledge_datas_path)\n",
"contexts_list = []\n",
"answer_list = []\n",
"\n",
"final_split_df = prepare_fiqa_without_answer(knowledge_datas_path)\n",
"\n",
"docs = []\n",
"\n",
"split = 'test'\n",
"for ds in final_split_df[split][\"ground_truths\"]:\n",
" docs.extend([d for d in ds])\n",
"print(len(docs))\n",
"\n",
"docs_str = '\\n'.join(docs)\n",
"with open(fiqa_path, 'w') as f:\n",
" f.write(docs_str)\n",
"\n",
"split = 'test'\n",
"question_list = final_split_df[split][\"question\"].to_list()\n",
"ground_truth_list = final_split_df[split][\"ground_truths\"].to_list()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"import time\n",
"from openai import OpenAI\n",
"\n",
"client = OpenAI()\n",
"\n",
"# Set OPENAI_API_KEY in your environment value\n",
"client.api_key = os.getenv('OPENAI_API_KEY')\n",
"\n",
"\n",
"class OpenAITimeoutException(Exception):\n",
" pass\n",
"\n",
"\n",
"def get_content_from_retrieved_message(message):\n",
" # Extract the message content\n",
" message_content = message.content[0].text\n",
" annotations = message_content.annotations\n",
" contexts = []\n",
" for annotation in annotations:\n",
" message_content.value = message_content.value.replace(annotation.text, f'')\n",
" if (file_citation := getattr(annotation, 'file_citation', None)):\n",
" contexts.append(file_citation.quote)\n",
" if len(contexts) == 0:\n",
" contexts = ['empty context.']\n",
" return message_content.value, contexts\n",
"\n",
"\n",
"def try_get_answer_contexts(assistant_id, question, timeout_seconds=120):\n",
" thread = client.beta.threads.create(\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": question,\n",
" }\n",
" ]\n",
" )\n",
" thread_id = thread.id\n",
" run = client.beta.threads.runs.create(\n",
" thread_id=thread_id,\n",
" assistant_id=assistant_id,\n",
" )\n",
" start_time = time.time()\n",
" while True:\n",
" elapsed_time = time.time() - start_time\n",
" if elapsed_time > timeout_seconds:\n",
" raise Exception(\"OpenAI retrieving answer Timeout!\")\n",
"\n",
" run = client.beta.threads.runs.retrieve(\n",
" thread_id=thread_id,\n",
" run_id=run.id\n",
" )\n",
" if run.status == 'completed':\n",
" break\n",
" messages = client.beta.threads.messages.list(\n",
" thread_id=thread_id\n",
" )\n",
" assert len(messages.data) > 1\n",
" res, contexts = get_content_from_retrieved_message(messages.data[0])\n",
" response = client.beta.threads.delete(thread_id)\n",
" assert response.deleted is True\n",
" return contexts, res\n",
"\n",
"\n",
"def get_answer_contexts_from_assistant(question, assistant_id, timeout_seconds=120, retry_num=6):\n",
" res = 'failed. please retry.'\n",
" contexts = ['failed. please retry.']\n",
" try:\n",
" for _ in range(retry_num):\n",
" try:\n",
" contexts, res = try_get_answer_contexts(assistant_id, question, timeout_seconds)\n",
" break\n",
" except OpenAITimeoutException as e:\n",
" print('OpenAI retrieving answer Timeout, retry...')\n",
" continue\n",
" except Exception as e:\n",
" print(e)\n",
" return res, contexts"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/648 [03:45<?, ?it/s]\n",
"\n",
"KeyboardInterrupt\n",
"\n"
]
}
],
"source": [
"file = client.files.create(\n",
" file=open(fiqa_path, \"rb\"),\n",
" purpose='assistants'\n",
")\n",
"\n",
"# Add the file to the assistant\n",
"assistant = client.beta.assistants.create(\n",
" instructions=\"You are a customer support chatbot. You must use your retrieval tool to retrieve relevant knowledge to best respond to customer queries.\",\n",
" model=\"gpt-4-1106-preview\",\n",
" tools=[{\"type\": \"retrieval\"}],\n",
" file_ids=[file.id]\n",
")\n",
"\n",
"for question in tqdm(question_list):\n",
" answer, contexts = get_answer_contexts_from_assistant(question, assistant.id)\n",
" # print(f'answer = {answer}')\n",
" # print(f'contexts = {contexts}')\n",
" # print('=' * 80)\n",
" answer_list.append(answer)\n",
" contexts_list.append(contexts)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from ragas import evaluate\n",
"from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision, answer_similarity\n",
"\n",
"ds = Dataset.from_dict({\"question\": question_list,\n",
" \"contexts\": contexts_list,\n",
" \"answer\": answer_list,\n",
" \"ground_truths\": ground_truth_list})\n",
"\n",
"result = evaluate(\n",
" ds,\n",
" metrics=[\n",
" context_precision,\n",
" # context_recall,\n",
" # faithfulness,\n",
" # answer_relevancy,\n",
" # answer_similarity,\n",
" # answer_correctness,\n",
" ],\n",
"\n",
")\n",
"print(result)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 9b35ae6

Please sign in to comment.