Skip to content

Commit

Permalink
fix embedding model bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxu0307 committed Jan 2, 2024
1 parent 05fe755 commit d9fdd18
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 5 deletions.
4 changes: 0 additions & 4 deletions taskweaver/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ def __init__(self, config: LLMModuleConfig, injector: Injector) -> None:
self.embedding_service = PlaceholderEmbeddingService(
"Azure ML does not support embeddings yet. Please configure a different embedding API.",
)
elif self.config.embedding_api_type == "qwen":
self.embedding_service = PlaceholderEmbeddingService(
"QWen does not support embeddings yet. Please configure a different embedding API.",
)
else:
raise ValueError(
f"Embedding API type {self.config.embedding_api_type} is not supported",
Expand Down
2 changes: 1 addition & 1 deletion taskweaver/memory/experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def summarize_experience_in_batch(
exp_embeddings = self.llm_api.get_embedding_list([exp.experience_text for exp in self.experience_list])
for i, session_id in enumerate(session_ids):
self.experience_list[i].embedding = exp_embeddings[i]
self.experience_list[i].embedding_model = self.llm_api.config.embedding_model
self.experience_list[i].embedding_model = self.llm_api.embedding_service.config.embedding_model
self.logger.info("Experience embeddings created. Embeddings number: {}".format(len(exp_embeddings)))

for exp in self.experience_list:
Expand Down

0 comments on commit d9fdd18

Please sign in to comment.