Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Controller session management functions #62

Merged
merged 17 commits into from
Jun 16, 2024
1,540 changes: 1,210 additions & 330 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ version = "0.7.5"
edition = "2021"

[dependencies]
account-sdk = { git = "https://github.com/cartridge-gg/controller", rev = "53a3457" }
katana-primitives = { git = "https://github.com/dojoengine/dojo", tag = "v0.7.0-alpha.3" }
anyhow = "1.0.75"
axum = "0.6"
base64 = "0.21.2"
clap = { version = "4.2", features = ["derive"] }
chrono = "0.4.31"
chrono = "0.4.38"
ctrlc = "3.4.1"
dirs = "5"
env_logger = "0.10"
Expand All @@ -25,13 +26,15 @@ serde_json = "1"
shellexpand = "3.1.0"
thiserror = "1.0.32"
tokio = { version = "1.18.2", features = ["full", "sync"] }
tower-http = "0.4"
tower-http = { version = "0.4", features = ["cors", "trace"] }
tracing = "0.1.34"
urlencoding = "2"
webbrowser = "0.8"
starknet = "0.10.0"
url = "2.2.2"
tempfile = "3.10.1"
hyper = "0.14.27"
tower-layer = "0.3.2"

[[bin]]
name = "slot"
Expand Down
151 changes: 131 additions & 20 deletions src/bin/command/auth/login.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,147 @@
use std::sync::Arc;

use anyhow::Result;
use axum::{
extract::{Query, State},
response::{IntoResponse, Redirect, Response},
routing::get,
Router,
};
use clap::Args;
use slot::{browser::Browser, server::LocalServer};
use tokio::runtime::Runtime;
use graphql_client::GraphQLQuery;
use hyper::StatusCode;
use log::error;
use serde::Deserialize;
use slot::{
api::Client,
browser, constant,
credential::Credentials,
graphql::auth::{
me::{ResponseData, Variables},
Me,
},
server::LocalServer,
};
use tokio::sync::mpsc::Sender;

#[derive(Debug, Args)]
pub struct LoginArgs;

impl LoginArgs {
pub fn run(&self) -> Result<()> {
let rt = Runtime::new()?;
pub async fn run(&self) -> Result<()> {
let server = Self::callback_server().expect("Failed to create a server");
let port = server.local_addr()?.port();
let callback_uri = format!("http://localhost:{port}/callback");

let handler = std::thread::spawn(move || {
let server = LocalServer::new().expect("Failed to start a server");
let addr = server.local_addr().unwrap();
let url = format!("https://x.cartridge.gg/slot/auth?callback_uri={callback_uri}");

let res = rt.block_on(async { tokio::join!(server.start(), Browser::open(&addr)) });
browser::open(&url)?;
server.start().await?;

match res {
(Err(e), _) => {
eprintln!("Server error: {e}");
}
(_, Err(e)) => {
eprintln!("Browser error: {e}");
}
_ => {
// println!("Success");
Ok(())
}

fn callback_server() -> Result<LocalServer> {
let (tx, rx) = tokio::sync::mpsc::channel::<()>(1);
let shared_state = Arc::new(AppState::new(tx));

let router = Router::new()
.route("/callback", get(handler))
.with_state(shared_state);

Ok(LocalServer::new(router)?.with_shutdown_signal(rx))
}
}

#[derive(Debug, Deserialize)]
struct CallbackPayload {
code: Option<String>,
}

#[derive(Clone)]
struct AppState {
shutdown_tx: Sender<()>,
}

impl AppState {
fn new(shutdown_tx: Sender<()>) -> Self {
Self { shutdown_tx }
}

async fn shutdown(&self) -> Result<()> {
self.shutdown_tx.send(()).await?;
Ok(())
}
}

#[derive(Debug, thiserror::Error)]
enum CallbackError {
#[error(transparent)]
Io(#[from] std::io::Error),

#[error("Api error: {0}")]
Api(#[from] slot::api::Error),

#[error(transparent)]
Other(#[from] anyhow::Error),

#[error(transparent)]
Credentials(#[from] slot::credential::Error),
}

impl IntoResponse for CallbackError {
fn into_response(self) -> Response {
let status = StatusCode::INTERNAL_SERVER_ERROR;
let message = format!("Something went wrong: {self}");
(status, message).into_response()
}
}

async fn handler(
State(state): State<Arc<AppState>>,
Query(payload): Query<CallbackPayload>,
) -> Result<Redirect, CallbackError> {
// 1. Shutdown the server
state.shutdown().await?;

// 2. Get access token using the authorization code
match payload.code {
Some(code) => {
let mut api = Client::new();

let token = api.oauth2(&code).await?;
api.set_token(token.clone());

// fetch the account information
let request_body = Me::build_query(Variables {});
let res: graphql_client::Response<ResponseData> = api.query(&request_body).await?;

// display the errors if any, but still process bcs we have the token
if let Some(errors) = res.errors {
for err in errors {
eprintln!("Error: {}", err.message);
}
}
});

handler.join().unwrap();
let account_info = res.data.map(|data| data.me.expect("should exist"));

Ok(())
// 3. Store the access token locally
Credentials::new(account_info, token).store()?;

println!("You are now logged in!\n");

Ok(Redirect::permanent(&format!(
"{}/slot/auth/success",
constant::CARTRIDGE_KEYCHAIN_URL
)))
}
None => {
error!("User denied consent. Try again.");

Ok(Redirect::permanent(&format!(
"{}/slot/auth/failure",
constant::CARTRIDGE_KEYCHAIN_URL
)))
}
}
}
7 changes: 6 additions & 1 deletion src/bin/command/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,25 @@ use self::{info::InfoArgs, login::LoginArgs};

mod info;
mod login;
mod session;

#[derive(Subcommand, Debug)]
pub enum Auth {
#[command(about = "Login to your Cartridge account.")]
Login(LoginArgs),
#[command(about = "Display info about the authenticated user.")]
Info(InfoArgs),
// Mostly for testing purposes, will eventually turn it into a library call from `sozo`.
#[command(hide = true)]
CreateSession(session::CreateSession),
}

impl Auth {
pub async fn run(&self) -> Result<()> {
match &self {
Auth::Login(args) => args.run(),
Auth::Login(args) => args.run().await,
Auth::Info(args) => args.run().await,
Auth::CreateSession(args) => args.run().await,
}
}
}
49 changes: 49 additions & 0 deletions src/bin/command/auth/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::str::FromStr;

use anyhow::{anyhow, ensure, Result};
use clap::Parser;
use slot::session::{self, Policy};
use starknet::core::types::FieldElement;
use starknet::providers::{jsonrpc::HttpTransport, JsonRpcClient, Provider};
use url::Url;

#[derive(Debug, Parser)]
pub struct CreateSession {
#[arg(long)]
#[arg(value_name = "URL")]
// #[arg(default_value = "http://localhost:5050")]
#[arg(help = "The RPC URL of the network you want to create a session for.")]
rpc_url: String,

#[arg(help = "The session's policies.")]
#[arg(value_parser = parse_policy)]
#[arg(required = true)]
policies: Vec<Policy>,
}

impl CreateSession {
pub async fn run(&self) -> Result<()> {
let url = Url::parse(&self.rpc_url)?;
let chain_id = get_network_chain_id(url.clone()).await?;
let session = session::create(url, &self.policies).await?;
session::store(chain_id, &session)?;
Ok(())
}
}

fn parse_policy(value: &str) -> Result<Policy> {
let mut parts = value.split(',');

let target = parts.next().ok_or(anyhow!("missing target"))?.to_owned();
let target = FieldElement::from_str(&target)?;
let method = parts.next().ok_or(anyhow!("missing method"))?.to_owned();

ensure!(parts.next().is_none(), " bruh");

Ok(Policy { target, method })
}

async fn get_network_chain_id(url: Url) -> Result<FieldElement> {
let provider = JsonRpcClient::new(HttpTransport::new(url));
Ok(provider.chain_id().await?)
}
25 changes: 8 additions & 17 deletions src/browser.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
use anyhow::Result;
use std::net::SocketAddr;
use urlencoding::encode;

pub struct Browser;

impl Browser {
pub async fn open(local_addr: &SocketAddr) -> Result<()> {
let callback_uri = format!("http://{local_addr}/callback").replace("[::1]", "localhost");
let encoded_callback_uri = encode(&callback_uri);
let url = format!("https://x.cartridge.gg/slot/auth?callback_uri={encoded_callback_uri}");

println!("Your browser has been opened to visit: \n\n {url}\n");
webbrowser::open(&url)?;

Ok(())
}
use anyhow::{Context, Result};
use tracing::trace;

pub fn open(url: &str) -> Result<()> {
trace!(%url, "Opening browser.");
webbrowser::open(url).context("Failed to open web browser")?;
println!("Your browser has been opened to visit: \n\n {url}\n");
Ok(())
}
Loading
Loading