From 00a3ac67c88c5fc24db0eb5fed04c053418d7459 Mon Sep 17 00:00:00 2001 From: Chau Pham Date: Fri, 26 Jan 2024 17:53:13 +0700 Subject: [PATCH] update refinement.py --- README.md | 7 +++---- script/example.ipynb | 30 ++++++++++++------------------ script/refinement.py | 6 ++++-- 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 8def8f9..2cdd92f 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ This repository contains scripts and prompts for our paper ["TopicGPT: Topic Mod ![TopicGPT Pipeline Overview](pipeline.png) ## Updates -- [02/18/23] Second-level topic generation code and refinement code are uploaded. -- [02/11/23] Basic pipeline is uploaded. Refinement and second-level topic generation code are coming soon. +- [11/18/23] Second-level topic generation code and refinement code are uploaded. +- [11/11/23] Basic pipeline is uploaded. Refinement and second-level topic generation code are coming soon. ## Setup - Install the requirements: `pip install -r requirements.txt` @@ -97,5 +97,4 @@ This repository contains scripts and prompts for our paper ["TopicGPT: Topic Mod archivePrefix={arXiv}, primaryClass={cs.CL} } -``` - +``` \ No newline at end of file diff --git a/script/example.ipynb b/script/example.ipynb index ce28ab2..6be33dd 100644 --- a/script/example.ipynb +++ b/script/example.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -161,20 +161,26 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": { "vscode": { "languageId": "shellscript" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No updated/merged topics!\n" + ] + } + ], "source": [ "# Refinement \n", "# Run the script multiple times to get a better result\n", "# Default: 1 runs\n", - "for run in range(1): \n", - " if run == 0:\n", - " %run refinement.py --deployment_name gpt-4 \\\n", + "%run refinement.py --deployment_name gpt-4 \\\n", " --max_tokens 500 --temperature 0.0 --top_p 0.0 \\\n", " --prompt_file $refinement_prompt \\\n", " --generation_file $generation_out \\\n", @@ -184,18 +190,6 @@ " --updated_file $refinement_updated \\\n", " --mapping_file $refinement_mapping \\\n", " --refined_again False \\\n", - " --remove False\n", - " else: \n", - " %run refinement.py --deployment_name gpt-4 \\\n", - " --max_tokens 500 --temperature 0.0 --top_p 0.0 \\\n", - " --prompt_file $refinement_prompt \\\n", - " --generation_file $generation_out \\\n", - " --topic_file $generation_topic \\\n", - " --out_file $refinement_out \\\n", - " --verbose True \\\n", - " --updated_file $refinement_updated \\\n", - " --mapping_file $refinement_mapping \\\n", - " --refined_again True \\\n", " --remove False" ] }, diff --git a/script/refinement.py b/script/refinement.py index 224f324..9568ac8 100644 --- a/script/refinement.py +++ b/script/refinement.py @@ -250,7 +250,7 @@ def main(): "--refined_again", type=str, default="refiner", - help="Is this the second time refining?", + help="Is this the second time you run refinement on the topics?", ) args = parser.parse_args() @@ -287,7 +287,7 @@ def main(): if len(responses) > 0: # Writing updated topics ---- - with open(args.topic_file, "w") as f: + with open(args.out_file, "w") as f: print(tree_view(updated_topics_root), file=f) # Writing orig-new mapping ---- @@ -323,6 +323,8 @@ def main(): updated_responses.append("\n".join(sub_list)) df["refined_responses"] = updated_responses df.to_json(args.updated_file, lines=True, orient="records") + else: + print("No updated/merged topics!") if __name__ == "__main__":