From 1c9c92a006bce18ddbe8d74a1348d66209a79287 Mon Sep 17 00:00:00 2001
From: ChengZi
Date: Mon, 4 Dec 2023 10:15:31 +0800
Subject: [PATCH] evaluate fiqa openai

Signed-off-by: ChengZi
---
 .../Evaluation/evaluate_fiqa_openai.ipynb | 319 ++++++++++++++++++
 1 file changed, 319 insertions(+)
 create mode 100644 bootcamp/Evaluation/evaluate_fiqa_openai.ipynb

diff --git a/bootcamp/Evaluation/evaluate_fiqa_openai.ipynb b/bootcamp/Evaluation/evaluate_fiqa_openai.ipynb
new file mode 100644
index 000000000..e13003749
--- /dev/null
+++ b/bootcamp/Evaluation/evaluate_fiqa_openai.ipynb
@@ -0,0 +1,319 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# ! python -m pip install openai beir ragas==0.0.17"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1706\n"
+     ]
+    }
+   ],
+   "source": [
+    "import json\n",
+    "import pandas as pd\n",
+    "import os\n",
+    "from tqdm import tqdm\n",
+    "from datasets import Dataset\n",
+    "from beir import util\n",
+    "\n",
+    "\n",
+    "def prepare_fiqa_without_answer(knowledge_path):\n",
+    "    \"\"\"Download the BEIR FiQA dataset and build, per split, a dataframe of questions and their ground-truth passages.\"\"\"\n",
+    "    dataset_name = \"fiqa\"\n",
+    "\n",
+    "    if not os.path.exists(os.path.join(knowledge_path, f'{dataset_name}.zip')):\n",
+    "        url = (\n",
+    "            \"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip\".format(\n",
+    "                dataset_name\n",
+    "            )\n",
+    "        )\n",
+    "        util.download_and_unzip(url, knowledge_path)\n",
+    "\n",
+    "    data_path = os.path.join(knowledge_path, 'fiqa')\n",
+    "    with open(os.path.join(data_path, \"corpus.jsonl\")) as f:\n",
+    "        cs = [pd.Series(json.loads(l)) for l in f.readlines()]\n",
+    "\n",
+    "    corpus_df = pd.DataFrame(cs)\n",
+    "    corpus_df = corpus_df.rename(columns={\"_id\": \"corpus-id\", \"text\": \"ground_truth\"})\n",
+    "    corpus_df = corpus_df.drop(columns=[\"title\", \"metadata\"])\n",
+    "    corpus_df[\"corpus-id\"] = corpus_df[\"corpus-id\"].astype(int)\n",
+    "\n",
+    "    with open(os.path.join(data_path, \"queries.jsonl\")) as f:\n",
+    "        qs = [pd.Series(json.loads(l)) for l in f.readlines()]\n",
+    "\n",
+    "    queries_df = pd.DataFrame(qs)\n",
+    "    queries_df = queries_df.rename(columns={\"_id\": \"query-id\", \"text\": \"question\"})\n",
+    "    queries_df = queries_df.drop(columns=[\"metadata\"])\n",
+    "    queries_df[\"query-id\"] = queries_df[\"query-id\"].astype(int)\n",
+    "\n",
+    "    splits = [\"dev\", \"test\", \"train\"]\n",
+    "    split_df = {}\n",
+    "    for s in splits:\n",
+    "        split_df[s] = pd.read_csv(os.path.join(data_path, f\"qrels/{s}.tsv\"), sep=\"\\t\").drop(\n",
+    "            columns=[\"score\"]\n",
+    "        )\n",
+    "\n",
+    "    final_split_df = {}\n",
+    "    for split in split_df:\n",
+    "        df = queries_df.merge(split_df[split], on=\"query-id\")\n",
+    "        df = df.merge(corpus_df, on=\"corpus-id\")\n",
+    "        df = df.drop(columns=[\"corpus-id\"])\n",
+    "        # One row per query: the question plus the list of all passages marked relevant for it.\n",
+    "        grouped = df.groupby(\"query-id\").apply(\n",
+    "            lambda x: pd.Series(\n",
+    "                {\n",
+    "                    \"question\": x[\"question\"].sample().values[0],\n",
+    "                    \"ground_truths\": x[\"ground_truth\"].tolist(),\n",
+    "                }\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "        grouped = grouped.reset_index()\n",
+    "        grouped = grouped.drop(columns=\"query-id\")\n",
+    "        final_split_df[split] = grouped\n",
+    "\n",
+    "    return final_split_df\n",
+    "\n",
+    "\n",
+    "knowledge_datas_path = './knowledge_datas'\n",
+    "fiqa_path = os.path.join(knowledge_datas_path, 'fiqa_doc.txt')\n",
+    "\n",
+    "if not os.path.exists(knowledge_datas_path):\n",
+    "    os.mkdir(knowledge_datas_path)\n",
+    "contexts_list = []\n",
+    "answer_list = []\n",
+    "\n",
+    "final_split_df = prepare_fiqa_without_answer(knowledge_datas_path)\n",
+    "\n",
+    "# Flatten the ground-truth passages of the test split into one document list.\n",
+    "docs = []\n",
+    "split = 'test'\n",
+    "for ds in final_split_df[split][\"ground_truths\"]:\n",
+    "    docs.extend(ds)\n",
+    "print(len(docs))\n",
+    "\n",
+    "# Write the passages to a single text file, which becomes the assistant's knowledge file.\n",
+    "docs_str = '\\n'.join(docs)\n",
+    "with open(fiqa_path, 'w') as f:\n",
+    "    f.write(docs_str)\n",
+    "\n",
+    "question_list = final_split_df[split][\"question\"].to_list()\n",
+    "ground_truth_list = final_split_df[split][\"ground_truths\"].to_list()"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "from openai import OpenAI\n",
+    "\n",
+    "client = OpenAI()\n",
+    "\n",
+    "# Set OPENAI_API_KEY in your environment variables\n",
+    "client.api_key = os.getenv('OPENAI_API_KEY')\n",
+    "\n",
+    "\n",
+    "class OpenAITimeoutException(Exception):\n",
+    "    pass\n",
+    "\n",
+    "\n",
+    "def get_content_from_retrieved_message(message):\n",
+    "    # Extract the answer text and the passages quoted in its retrieval annotations.\n",
+    "    message_content = message.content[0].text\n",
+    "    annotations = message_content.annotations\n",
+    "    contexts = []\n",
+    "    for annotation in annotations:\n",
+    "        # Strip the inline citation marker from the answer text.\n",
+    "        message_content.value = message_content.value.replace(annotation.text, '')\n",
+    "        if (file_citation := getattr(annotation, 'file_citation', None)):\n",
+    "            contexts.append(file_citation.quote)\n",
+    "    if len(contexts) == 0:\n",
+    "        contexts = ['empty context.']\n",
+    "    return message_content.value, contexts\n",
+    "\n",
+    "\n",
+    "def try_get_answer_contexts(assistant_id, question, timeout_seconds=120):\n",
+    "    thread = client.beta.threads.create(\n",
+    "        messages=[\n",
+    "            {\n",
+    "                \"role\": \"user\",\n",
+    "                \"content\": question,\n",
+    "            }\n",
+    "        ]\n",
+    "    )\n",
+    "    thread_id = thread.id\n",
+    "    run = client.beta.threads.runs.create(\n",
+    "        thread_id=thread_id,\n",
+    "        assistant_id=assistant_id,\n",
+    "    )\n",
+    "    # Poll the run until it completes; raise OpenAITimeoutException so the caller can retry.\n",
+    "    start_time = time.time()\n",
+    "    while True:\n",
+    "        elapsed_time = time.time() - start_time\n",
+    "        if elapsed_time > timeout_seconds:\n",
+    "            raise OpenAITimeoutException(\"OpenAI retrieving answer Timeout!\")\n",
+    "\n",
+    "        run = client.beta.threads.runs.retrieve(\n",
+    "            thread_id=thread_id,\n",
+    "            run_id=run.id\n",
+    "        )\n",
+    "        if run.status == 'completed':\n",
+    "            break\n",
+    "        time.sleep(1)\n",
+    "    messages = client.beta.threads.messages.list(\n",
+    "        thread_id=thread_id\n",
+    "    )\n",
+    "    assert len(messages.data) > 1\n",
+    "    # messages.data[0] is the newest message in the thread, i.e. the assistant's answer.\n",
+    "    res, contexts = get_content_from_retrieved_message(messages.data[0])\n",
+    "    response = client.beta.threads.delete(thread_id)\n",
+    "    assert response.deleted is True\n",
+    "    return contexts, res\n",
+    "\n",
+    "\n",
+    "def get_answer_contexts_from_assistant(question, assistant_id, timeout_seconds=120, retry_num=6):\n",
+    "    res = 'failed. please retry.'\n",
+    "    contexts = ['failed. please retry.']\n",
+    "    try:\n",
+    "        for _ in range(retry_num):\n",
+    "            try:\n",
+    "                contexts, res = try_get_answer_contexts(assistant_id, question, timeout_seconds)\n",
+    "                break\n",
+    "            except OpenAITimeoutException:\n",
+    "                print('OpenAI retrieving answer Timeout, retry...')\n",
+    "                continue\n",
+    "    except Exception as e:\n",
+    "        print(e)\n",
+    "    return res, contexts"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 0%| | 0/648 [03:45
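
The patch is truncated at this point; judging by the tqdm bar, the remaining cell loops over the 648 test questions and then scores the answers with ragas. A minimal sketch of how the pieces defined above plausibly fit together, assuming an Assistant has already been created with fiqa_doc.txt attached (the assistant_id value below is a hypothetical placeholder) and the ragas==0.0.17 evaluate API installed in the first cell:

    from datasets import Dataset
    from ragas import evaluate  # ragas==0.0.17 API, per the install cell
    from tqdm import tqdm

    assistant_id = 'asst_...'  # hypothetical: id of an Assistant created with fiqa_doc.txt attached

    answer_list = []
    contexts_list = []
    for question in tqdm(question_list):
        # Ask the assistant each test question; keep the answer and its cited passages.
        answer, contexts = get_answer_contexts_from_assistant(question, assistant_id)
        answer_list.append(answer)
        contexts_list.append(contexts)

    # ragas 0.0.x expects these four columns and, when no metrics are passed,
    # scores the dataset with its default metric suite.
    fiqa_ds = Dataset.from_dict({
        'question': question_list,
        'answer': answer_list,
        'contexts': contexts_list,
        'ground_truths': ground_truth_list,
    })
    result = evaluate(fiqa_ds)
    print(result)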