Skip to content

Commit

Permalink
update refinement.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chtmp223 committed Jan 26, 2024
1 parent e3e3684 commit 00a3ac6
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 24 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ This repository contains scripts and prompts for our paper ["TopicGPT: Topic Mod
![TopicGPT Pipeline Overview](pipeline.png)

## Updates
- [02/18/23] Second-level topic generation code and refinement code are uploaded.
- [02/11/23] Basic pipeline is uploaded. Refinement and second-level topic generation code are coming soon.
- [11/18/23] Second-level topic generation code and refinement code are uploaded.
- [11/11/23] Basic pipeline is uploaded. Refinement and second-level topic generation code are coming soon.

## Setup
- Install the requirements: `pip install -r requirements.txt`
Expand Down Expand Up @@ -97,5 +97,4 @@ This repository contains scripts and prompts for our paper ["TopicGPT: Topic Mod
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
```
30 changes: 12 additions & 18 deletions script/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -161,20 +161,26 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No updated/merged topics!\n"
]
}
],
"source": [
"# Refinement \n",
"# Run the script multiple times to get a better result\n",
"# Default: 1 run\n",
"for run in range(1): \n",
" if run == 0:\n",
" %run refinement.py --deployment_name gpt-4 \\\n",
"%run refinement.py --deployment_name gpt-4 \\\n",
" --max_tokens 500 --temperature 0.0 --top_p 0.0 \\\n",
" --prompt_file $refinement_prompt \\\n",
" --generation_file $generation_out \\\n",
Expand All @@ -184,18 +190,6 @@
" --updated_file $refinement_updated \\\n",
" --mapping_file $refinement_mapping \\\n",
" --refined_again False \\\n",
" --remove False\n",
" else: \n",
" %run refinement.py --deployment_name gpt-4 \\\n",
" --max_tokens 500 --temperature 0.0 --top_p 0.0 \\\n",
" --prompt_file $refinement_prompt \\\n",
" --generation_file $generation_out \\\n",
" --topic_file $generation_topic \\\n",
" --out_file $refinement_out \\\n",
" --verbose True \\\n",
" --updated_file $refinement_updated \\\n",
" --mapping_file $refinement_mapping \\\n",
" --refined_again True \\\n",
" --remove False"
]
},
Expand Down
6 changes: 4 additions & 2 deletions script/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def main():
"--refined_again",
type=str,
default="refiner",
help="Is this the second time refining?",
help="Is this the second time you run refinement on the topics?",
)

args = parser.parse_args()
Expand Down Expand Up @@ -287,7 +287,7 @@ def main():

if len(responses) > 0:
# Writing updated topics ----
with open(args.topic_file, "w") as f:
with open(args.out_file, "w") as f:
print(tree_view(updated_topics_root), file=f)

# Writing orig-new mapping ----
Expand Down Expand Up @@ -323,6 +323,8 @@ def main():
updated_responses.append("\n".join(sub_list))
df["refined_responses"] = updated_responses
df.to_json(args.updated_file, lines=True, orient="records")
else:
print("No updated/merged topics!")


if __name__ == "__main__":
Expand Down

0 comments on commit 00a3ac6

Please sign in to comment.