From bd690c6f00e34d6a0408c4e9cc6ae642af8316b0 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 31 Oct 2024 10:26:17 +0000 Subject: [PATCH] feat: basic KB creation & search --- cli/Cargo.toml | 8 ++- cli/src/main.rs | 142 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 147 insertions(+), 3 deletions(-) diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 297baea..53b9dce 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -4,12 +4,18 @@ version = "0.1.0" edition = "2021" [dependencies] -griptape = { version = "2023.9.19", path = "../griptape" } +griptape = { path = "../griptape" } tokio = { version = "1.41.0", features = ["full"] } reqwest = { version = "^0.12", features = ["json", "multipart"] } serde = { version = "^1.0", features = ["derive"] } serde_with = { version = "^3.8", default-features = false, features = ["base64", "std", "macros"] } serde_json = "^1.0" +clap = "4.5.20" +clap_derive = "4.5.18" +env_logger = "0.11.5" +log = "0.4.22" +toml = "0.8.19" +anyhow = "1.0.91" [[bin]] name = "griptape" diff --git a/cli/src/main.rs b/cli/src/main.rs index 3db307f..59b67b0 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -1,3 +1,141 @@ -pub fn main() { - println!("Hello Griptape!"); +use clap_derive::{Parser, Subcommand}; +use anyhow::{Context, Result}; +use clap::Parser; +use tokio; +use std::env; +use griptape::{apis::{configuration::Configuration, knowledge_bases_api::*, data_connectors_api::*}, + models::{self, CreateDataConnectorRequestContent, DataConnectorConfigInputUnion, Webscraper}}; +use env_logger; +use log::{debug, info}; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Cli { + #[command(subcommand)] + command: Commands, } + +#[derive(Subcommand)] +enum Commands { + Query { + #[arg(long, short)] + name: String, + question: String, + }, + Record { + #[arg(long, short)] + name: String, + #[arg(long, short)] + url: Vec, + }, +} + +trait ConnectorLookup { + + /// Find a DataConnector by name + fn get_by_name(&self, name: &str) -> Option; +} + +impl ConnectorLookup for Vec { + fn get_by_name(&self, name: &str) -> Option { + self.iter().find(|c| c.name.to_uppercase() == name.to_uppercase()).cloned() + } +} + +#[tokio::main] +async fn main() -> Result<()> { + + let api_key = if let Ok(v) = env::var("GT_CLOUD_API_KEY") { + v + } else { + eprintln!("Missing GT_CLOUD_API_KEY env. variable."); + std::process::exit(1); + }; + + env_logger::init(); + + let cli = Cli::parse(); + + let mut config = Configuration::new(); + config.bearer_access_token = Some(api_key); + + match &cli.command { + Commands::Query { name, question } => { + let answer = ask(&config, &name, &question).await.context("Couldn't ask the question")?; + println!("Question: {question}"); + println!("Answer:\n{answer}"); + } + Commands::Record { name, url } => { + let _ = record(&config, &name, url.to_vec()).await; + } + } + + Ok(()) +} + +// Stubbed async function for `query` +async fn ask(config: &Configuration, name: &str, question: &str) -> Result { + let kbs = list_knowledge_bases(&config, None, None).await + .context("Couldn't list knowledge bases")? + .knowledge_bases.context("No knowledge bases found")?; + let kb = kbs.iter().find(|k| k.name.to_uppercase() == name.to_uppercase()).context("Couldn't find knowledge base")?; + + let req = models::SearchKnowledgeBaseRequestContent::new(question.to_string()); + let res = search_knowledge_base(&config, &kb.knowledge_base_id, req).await.context("Couldn't query knowledge base")?; + let answer = get_knowledge_base_search(&config, &res.knowledge_base_search_id).await.context("Couldn't get answer")?; + + debug!("Answer is {answer:?}"); + + Ok(answer.result) +} + +// Record the given URLs identified by name +async fn record(config: &Configuration, name: &str, url: Vec) -> Result<()> { + // call to another module would go here + //FIXME support pagination + let connectors = list_data_connectors(&config, None, None).await + .context("Couldn't download data connectors from Griptape Cloud")? + .data_connectors.context("No data connectors found")?; + + debug!("Found {} connectors", connectors.len()); + let dc = connectors.get_by_name(name); + let dc_id = if let Some(dc) = dc { + debug!("Found connector {dc:?} for {name}"); + dc.data_connector_id + } else { + debug!("Connector {name} not found. Creating"); + //FIXME support schedule, with default + let req = CreateDataConnectorRequestContent::new( + name.to_string(), + DataConnectorConfigInputUnion::Webscraper( + Box::new(Webscraper::new(models::WebscraperInput::new( + url.to_vec(), + ))), + ), + "webscraper".to_string(), + ); + //FIXME fails with "missing field `data_job_id` - raise bug + let res = create_data_connector(&config, req).await.context("Couldn't create Data Connector")?; + res.data_connector_id + }; + + //FIXME support pagination + let kbs = list_knowledge_bases(&config, None, None).await + .context("Couldn't list knowledge bases")? + .knowledge_bases.context("No knowledge bases found")?; + let kb = kbs.iter().find(|k| k.name.to_uppercase() == name.to_uppercase()).context("Couldn't find knowledge base").ok(); + + if let Some(kb) = kb { + debug!("Knowledge base exists {}", kb.knowledge_base_id) + } else { + debug!("Creating knowledge base {name}"); + let req = models::CreateKnowledgeBaseRequestContent::new( + name.to_string(), + vec![dc_id] + ); + let res = create_knowledge_base(&config, req).await.context("Couldn't create knowledge base")?; + debug!("Created knowlege base {name} {}", res.knowledge_base_id); + } + Ok(()) +} +