Skip to content

Commit

Permalink
Add Controller session management functions (#62)
Browse files Browse the repository at this point in the history
* create session

* fix error

* wip

* tests

* remove unused env var

* move test

* remove debug

* docs

* util test

* add more tests

* change type to u64

* simplify local server when handling shutdown

* comment

* fix login server

* comment

* doc

* better error handling
  • Loading branch information
kariy authored Jun 16, 2024
1 parent c85145f commit f5d674c
Show file tree
Hide file tree
Showing 11 changed files with 1,926 additions and 525 deletions.
1,540 changes: 1,210 additions & 330 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ version = "0.7.5"
edition = "2021"

[dependencies]
account-sdk = { git = "https://github.com/cartridge-gg/controller", rev = "53a3457" }
katana-primitives = { git = "https://github.com/dojoengine/dojo", tag = "v0.7.0-alpha.3" }
anyhow = "1.0.75"
axum = "0.6"
base64 = "0.21.2"
clap = { version = "4.2", features = ["derive"] }
chrono = "0.4.31"
chrono = "0.4.38"
ctrlc = "3.4.1"
dirs = "5"
env_logger = "0.10"
Expand All @@ -25,13 +26,15 @@ serde_json = "1"
shellexpand = "3.1.0"
thiserror = "1.0.32"
tokio = { version = "1.18.2", features = ["full", "sync"] }
tower-http = "0.4"
tower-http = { version = "0.4", features = ["cors", "trace"] }
tracing = "0.1.34"
urlencoding = "2"
webbrowser = "0.8"
starknet = "0.10.0"
url = "2.2.2"
tempfile = "3.10.1"
hyper = "0.14.27"
tower-layer = "0.3.2"

[[bin]]
name = "slot"
Expand Down
151 changes: 131 additions & 20 deletions src/bin/command/auth/login.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,147 @@
use std::sync::Arc;

use anyhow::Result;
use axum::{
extract::{Query, State},
response::{IntoResponse, Redirect, Response},
routing::get,
Router,
};
use clap::Args;
use slot::{browser::Browser, server::LocalServer};
use tokio::runtime::Runtime;
use graphql_client::GraphQLQuery;
use hyper::StatusCode;
use log::error;
use serde::Deserialize;
use slot::{
api::Client,
browser, constant,
credential::Credentials,
graphql::auth::{
me::{ResponseData, Variables},
Me,
},
server::LocalServer,
};
use tokio::sync::mpsc::Sender;

#[derive(Debug, Args)]
pub struct LoginArgs;

impl LoginArgs {
pub fn run(&self) -> Result<()> {
let rt = Runtime::new()?;
pub async fn run(&self) -> Result<()> {
let server = Self::callback_server().expect("Failed to create a server");
let port = server.local_addr()?.port();
let callback_uri = format!("http://localhost:{port}/callback");

let handler = std::thread::spawn(move || {
let server = LocalServer::new().expect("Failed to start a server");
let addr = server.local_addr().unwrap();
let url = format!("https://x.cartridge.gg/slot/auth?callback_uri={callback_uri}");

let res = rt.block_on(async { tokio::join!(server.start(), Browser::open(&addr)) });
browser::open(&url)?;
server.start().await?;

match res {
(Err(e), _) => {
eprintln!("Server error: {e}");
}
(_, Err(e)) => {
eprintln!("Browser error: {e}");
}
_ => {
// println!("Success");
Ok(())
}

fn callback_server() -> Result<LocalServer> {
let (tx, rx) = tokio::sync::mpsc::channel::<()>(1);
let shared_state = Arc::new(AppState::new(tx));

let router = Router::new()
.route("/callback", get(handler))
.with_state(shared_state);

Ok(LocalServer::new(router)?.with_shutdown_signal(rx))
}
}

#[derive(Debug, Deserialize)]
struct CallbackPayload {
code: Option<String>,
}

#[derive(Clone)]
struct AppState {
shutdown_tx: Sender<()>,
}

impl AppState {
fn new(shutdown_tx: Sender<()>) -> Self {
Self { shutdown_tx }
}

async fn shutdown(&self) -> Result<()> {
self.shutdown_tx.send(()).await?;
Ok(())
}
}

#[derive(Debug, thiserror::Error)]
enum CallbackError {
#[error(transparent)]
Io(#[from] std::io::Error),

#[error("Api error: {0}")]
Api(#[from] slot::api::Error),

#[error(transparent)]
Other(#[from] anyhow::Error),

#[error(transparent)]
Credentials(#[from] slot::credential::Error),
}

impl IntoResponse for CallbackError {
fn into_response(self) -> Response {
let status = StatusCode::INTERNAL_SERVER_ERROR;
let message = format!("Something went wrong: {self}");
(status, message).into_response()
}
}

async fn handler(
State(state): State<Arc<AppState>>,
Query(payload): Query<CallbackPayload>,
) -> Result<Redirect, CallbackError> {
// 1. Shutdown the server
state.shutdown().await?;

// 2. Get access token using the authorization code
match payload.code {
Some(code) => {
let mut api = Client::new();

let token = api.oauth2(&code).await?;
api.set_token(token.clone());

// fetch the account information
let request_body = Me::build_query(Variables {});
let res: graphql_client::Response<ResponseData> = api.query(&request_body).await?;

// display the errors if any, but still process bcs we have the token
if let Some(errors) = res.errors {
for err in errors {
eprintln!("Error: {}", err.message);
}
}
});

handler.join().unwrap();
let account_info = res.data.map(|data| data.me.expect("should exist"));

Ok(())
// 3. Store the access token locally
Credentials::new(account_info, token).store()?;

println!("You are now logged in!\n");

Ok(Redirect::permanent(&format!(
"{}/slot/auth/success",
constant::CARTRIDGE_KEYCHAIN_URL
)))
}
None => {
error!("User denied consent. Try again.");

Ok(Redirect::permanent(&format!(
"{}/slot/auth/failure",
constant::CARTRIDGE_KEYCHAIN_URL
)))
}
}
}
7 changes: 6 additions & 1 deletion src/bin/command/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,25 @@ use self::{info::InfoArgs, login::LoginArgs};

mod info;
mod login;
mod session;

#[derive(Subcommand, Debug)]
pub enum Auth {
#[command(about = "Login to your Cartridge account.")]
Login(LoginArgs),
#[command(about = "Display info about the authenticated user.")]
Info(InfoArgs),
// Mostly for testing purposes, will eventually turn it into a library call from `sozo`.
#[command(hide = true)]
CreateSession(session::CreateSession),
}

impl Auth {
pub async fn run(&self) -> Result<()> {
match &self {
Auth::Login(args) => args.run(),
Auth::Login(args) => args.run().await,
Auth::Info(args) => args.run().await,
Auth::CreateSession(args) => args.run().await,
}
}
}
49 changes: 49 additions & 0 deletions src/bin/command/auth/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::str::FromStr;

use anyhow::{anyhow, ensure, Result};
use clap::Parser;
use slot::session::{self, Policy};
use starknet::core::types::FieldElement;
use starknet::providers::{jsonrpc::HttpTransport, JsonRpcClient, Provider};
use url::Url;

#[derive(Debug, Parser)]
pub struct CreateSession {
#[arg(long)]
#[arg(value_name = "URL")]
// #[arg(default_value = "http://localhost:5050")]
#[arg(help = "The RPC URL of the network you want to create a session for.")]
rpc_url: String,

#[arg(help = "The session's policies.")]
#[arg(value_parser = parse_policy)]
#[arg(required = true)]
policies: Vec<Policy>,
}

impl CreateSession {
pub async fn run(&self) -> Result<()> {
let url = Url::parse(&self.rpc_url)?;
let chain_id = get_network_chain_id(url.clone()).await?;
let session = session::create(url, &self.policies).await?;
session::store(chain_id, &session)?;
Ok(())
}
}

fn parse_policy(value: &str) -> Result<Policy> {
let mut parts = value.split(',');

let target = parts.next().ok_or(anyhow!("missing target"))?.to_owned();
let target = FieldElement::from_str(&target)?;
let method = parts.next().ok_or(anyhow!("missing method"))?.to_owned();

ensure!(parts.next().is_none(), " bruh");

Ok(Policy { target, method })
}

async fn get_network_chain_id(url: Url) -> Result<FieldElement> {
let provider = JsonRpcClient::new(HttpTransport::new(url));
Ok(provider.chain_id().await?)
}
25 changes: 8 additions & 17 deletions src/browser.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
use anyhow::Result;
use std::net::SocketAddr;
use urlencoding::encode;

pub struct Browser;

impl Browser {
pub async fn open(local_addr: &SocketAddr) -> Result<()> {
let callback_uri = format!("http://{local_addr}/callback").replace("[::1]", "localhost");
let encoded_callback_uri = encode(&callback_uri);
let url = format!("https://x.cartridge.gg/slot/auth?callback_uri={encoded_callback_uri}");

println!("Your browser has been opened to visit: \n\n {url}\n");
webbrowser::open(&url)?;

Ok(())
}
use anyhow::{Context, Result};
use tracing::trace;

pub fn open(url: &str) -> Result<()> {
trace!(%url, "Opening browser.");
webbrowser::open(url).context("Failed to open web browser")?;
println!("Your browser has been opened to visit: \n\n {url}\n");
Ok(())
}
Loading

0 comments on commit f5d674c

Please sign in to comment.