Skip to content

Commit

Permalink
Add deployment for prediction-request-rag mech (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
evangriffiths authored Apr 25, 2024
1 parent 1416bec commit 992e43d
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 7 deletions.
58 changes: 57 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions prediction_market_agent/agents/mech_agent/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,9 @@ class DeployablePredictionOfflineSMEAgent(DeployableMechAgentBase):
def load(self) -> None:
self.local = True
self.tool = MechTool.PREDICTION_OFFLINE_SME


class DeployablePredictionRequestRAGAgent(DeployableMechAgentBase):
def load(self) -> None:
self.local = True
self.tool = MechTool.PREDICTION_REQUEST_RAG
3 changes: 3 additions & 0 deletions prediction_market_agent/run_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DeployablePredictionOfflineSMEAgent,
DeployablePredictionOnlineAgent,
DeployablePredictionOnlineSMEAgent,
DeployablePredictionRequestRAGAgent,
)
from prediction_market_agent.agents.replicate_to_omen_agent.deploy import (
DeployableReplicateToOmenAgent,
Expand All @@ -39,6 +40,7 @@ class RunnableAgent(str, Enum):
mech_prediction_offline = "mech_prediction-offline"
mech_prediction_online_sme = "mech_prediction-online-sme"
mech_prediction_offline_sme = "mech_prediction-offline-sme"
mech_prediction_request_rag = "mech_prediction-request-rag"


RUNNABLE_AGENTS = {
Expand All @@ -50,6 +52,7 @@ class RunnableAgent(str, Enum):
RunnableAgent.mech_prediction_offline: DeployablePredictionOfflineAgent,
RunnableAgent.mech_prediction_online_sme: DeployablePredictionOnlineSMEAgent,
RunnableAgent.mech_prediction_offline_sme: DeployablePredictionOfflineSMEAgent,
RunnableAgent.mech_prediction_request_rag: DeployablePredictionRequestRAGAgent,
}


Expand Down
14 changes: 14 additions & 0 deletions prediction_market_agent/tools/mech/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from prediction_market_agent_tooling.benchmark.utils import OutcomePrediction

from prediction_market_agent.tools.mech.api_keys import MechAPIKeys
from prediction_market_agent.tools.mech.mech.packages.napthaai.customs.prediction_request_rag import (
prediction_request_rag,
)
from prediction_market_agent.tools.mech.mech.packages.nickcom007.customs.prediction_request_sme import (
prediction_request_sme,
)
Expand Down Expand Up @@ -38,6 +41,7 @@ class MechTool(str, Enum):
PREDICTION_OFFLINE = "prediction-offline"
PREDICTION_ONLINE_SME = "prediction-online-sme"
PREDICTION_OFFLINE_SME = "prediction-offline-sme"
PREDICTION_REQUEST_RAG = "prediction-request-rag"


def mech_request(question: str, mech_tool: MechTool) -> OutcomePrediction:
Expand Down Expand Up @@ -94,6 +98,16 @@ def mech_request_local(question: str, mech_tool: MechTool) -> OutcomePrediction:
"google_engine_id": keys.google_search_engine_id.get_secret_value(),
},
)
elif mech_tool == MechTool.PREDICTION_REQUEST_RAG:
response = prediction_request_rag.run(
tool=mech_tool.value,
prompt=question,
api_keys={
"openai": keys.openai_api_key.get_secret_value(),
"google_api_key": keys.google_search_api_key.get_secret_value(),
"google_engine_id": keys.google_search_engine_id.get_secret_value(),
},
)
else:
raise ValueError(f"Mech type '{mech_tool}' not supported")

Expand Down
14 changes: 8 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ isort = "^5.13.2"
markdownify = "^0.11.6"
tavily-python = "^0.3.1"
microchain-python = "^0.3.5"
setuptools = "^69.5.1"
jsonschema = "^4.3.3"
chromadb = "^0.4.24"
spacy = "^3.7.4"
readability-lxml = "^0.8.1"
lxml = {extras = ["html-clean"], version = "^5.2.1"}
setuptools = "^69.5.1" # TODO remove with https://github.com/gnosis/prediction-market-agent/issues/97
jsonschema = "^4.3.3" # TODO remove with https://github.com/gnosis/prediction-market-agent/issues/97
chromadb = "^0.4.24" # TODO remove with https://github.com/gnosis/prediction-market-agent/issues/97
spacy = "^3.7.4" # TODO remove with https://github.com/gnosis/prediction-market-agent/issues/97
readability-lxml = "^0.8.1" # TODO remove with https://github.com/gnosis/prediction-market-agent/issues/97
lxml = {extras = ["html-clean"], version = "^5.2.1"} # TODO remove with https://github.com/gnosis/prediction-market-agent/issues/97
pypdf2 = "^3.0.1" # TODO remove with https://github.com/gnosis/prediction-market-agent/issues/97
faiss-cpu = "^1.8.0" # TODO remove with https://github.com/gnosis/prediction-market-agent/issues/97

[build-system]
requires = ["poetry-core"]
Expand Down

0 comments on commit 992e43d

Please sign in to comment.