From c40cbfbbc547612791efa9f1904cf171b7f12f7b Mon Sep 17 00:00:00 2001 From: xiaobeicn Date: Fri, 27 Sep 2024 11:35:43 +0800 Subject: [PATCH] fix and add --- README.md | 2 +- examples/costorm_examples/run_costorm_gpt.py | 5 ++++ knowledge_storm/encoder.py | 27 +------------------- 3 files changed, 7 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 9df0507e..8cdf03ff 100644 --- a/README.md +++ b/README.md @@ -255,7 +255,7 @@ python examples/storm_examples/run_storm_wiki_gpt.py \ To run Co-STORM with `gpt` family models with default configurations, -1. Add `BING_SEARCH_API_KEY="xxx"`to `secrets.toml` +1. Add `BING_SEARCH_API_KEY="xxx"` and `ENCODER_API_TYPE="xxx"` to `secrets.toml` 2. Run the following command ```bash diff --git a/examples/costorm_examples/run_costorm_gpt.py b/examples/costorm_examples/run_costorm_gpt.py index 66d6bd3c..5f59a80d 100644 --- a/examples/costorm_examples/run_costorm_gpt.py +++ b/examples/costorm_examples/run_costorm_gpt.py @@ -143,6 +143,11 @@ def main(args): with open(os.path.join(args.output_dir, "report.md"), "w") as f: f.write(article) + # Save instance dump + instance_copy = costorm_runner.to_dict() + with open(os.path.join(args.output_dir, "instance_dump.json"), "w") as f: + json.dump(instance_copy, f, indent=2) + # Save logging log_dump = costorm_runner.dump_logging_and_reset() with open(os.path.join(args.output_dir, "log.json"), "w") as f: diff --git a/knowledge_storm/encoder.py b/knowledge_storm/encoder.py index 3d14e63c..01fc9725 100644 --- a/knowledge_storm/encoder.py +++ b/knowledge_storm/encoder.py @@ -7,38 +7,13 @@ class EmbeddingModel: - def __init__(): + def __init__(self): pass def get_embedding(self, text: str) -> Tuple[np.ndarray, int]: raise Exception("Not implemented") -class OpenAIEmbeddingModel(EmbeddingModel): - def __init__(self, model: str = "text-embedding-3-small", api_key: str = None): - if not api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - - self.url = "https://api.openai.com/v1/embeddings" - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - self.model = model - - def get_embedding(self, text: str) -> Tuple[np.ndarray, int]: - data = {"input": text, "model": "text-embedding-3-small"} - - response = requests.post(self.url, headers=self.headers, json=data) - if response.status_code == 200: - data = response.json() - embedding = np.array(data["data"][0]["embedding"]) - token = data["usage"]["prompt_tokens"] - return embedding, token - else: - response.raise_for_status() - - class OpenAIEmbeddingModel(EmbeddingModel): def __init__(self, model: str = "text-embedding-3-small", api_key: str = None): if not api_key: