-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Controller session management functions (#62)
* create session * fix error * wip * tests * remove unused env var * move test * remove debug * docs * util test * add more tests * change type to u64 * simplify local server when handling shutdown * comment * fix login server * comment * doc * better error handling
- Loading branch information
Showing
11 changed files
with
1,926 additions
and
525 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,147 @@ | ||
use std::sync::Arc; | ||
|
||
use anyhow::Result; | ||
use axum::{ | ||
extract::{Query, State}, | ||
response::{IntoResponse, Redirect, Response}, | ||
routing::get, | ||
Router, | ||
}; | ||
use clap::Args; | ||
use slot::{browser::Browser, server::LocalServer}; | ||
use tokio::runtime::Runtime; | ||
use graphql_client::GraphQLQuery; | ||
use hyper::StatusCode; | ||
use log::error; | ||
use serde::Deserialize; | ||
use slot::{ | ||
api::Client, | ||
browser, constant, | ||
credential::Credentials, | ||
graphql::auth::{ | ||
me::{ResponseData, Variables}, | ||
Me, | ||
}, | ||
server::LocalServer, | ||
}; | ||
use tokio::sync::mpsc::Sender; | ||
|
||
#[derive(Debug, Args)] | ||
pub struct LoginArgs; | ||
|
||
impl LoginArgs { | ||
pub fn run(&self) -> Result<()> { | ||
let rt = Runtime::new()?; | ||
pub async fn run(&self) -> Result<()> { | ||
let server = Self::callback_server().expect("Failed to create a server"); | ||
let port = server.local_addr()?.port(); | ||
let callback_uri = format!("http://localhost:{port}/callback"); | ||
|
||
let handler = std::thread::spawn(move || { | ||
let server = LocalServer::new().expect("Failed to start a server"); | ||
let addr = server.local_addr().unwrap(); | ||
let url = format!("https://x.cartridge.gg/slot/auth?callback_uri={callback_uri}"); | ||
|
||
let res = rt.block_on(async { tokio::join!(server.start(), Browser::open(&addr)) }); | ||
browser::open(&url)?; | ||
server.start().await?; | ||
|
||
match res { | ||
(Err(e), _) => { | ||
eprintln!("Server error: {e}"); | ||
} | ||
(_, Err(e)) => { | ||
eprintln!("Browser error: {e}"); | ||
} | ||
_ => { | ||
// println!("Success"); | ||
Ok(()) | ||
} | ||
|
||
fn callback_server() -> Result<LocalServer> { | ||
let (tx, rx) = tokio::sync::mpsc::channel::<()>(1); | ||
let shared_state = Arc::new(AppState::new(tx)); | ||
|
||
let router = Router::new() | ||
.route("/callback", get(handler)) | ||
.with_state(shared_state); | ||
|
||
Ok(LocalServer::new(router)?.with_shutdown_signal(rx)) | ||
} | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
struct CallbackPayload { | ||
code: Option<String>, | ||
} | ||
|
||
#[derive(Clone)] | ||
struct AppState { | ||
shutdown_tx: Sender<()>, | ||
} | ||
|
||
impl AppState { | ||
fn new(shutdown_tx: Sender<()>) -> Self { | ||
Self { shutdown_tx } | ||
} | ||
|
||
async fn shutdown(&self) -> Result<()> { | ||
self.shutdown_tx.send(()).await?; | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
enum CallbackError { | ||
#[error(transparent)] | ||
Io(#[from] std::io::Error), | ||
|
||
#[error("Api error: {0}")] | ||
Api(#[from] slot::api::Error), | ||
|
||
#[error(transparent)] | ||
Other(#[from] anyhow::Error), | ||
|
||
#[error(transparent)] | ||
Credentials(#[from] slot::credential::Error), | ||
} | ||
|
||
impl IntoResponse for CallbackError { | ||
fn into_response(self) -> Response { | ||
let status = StatusCode::INTERNAL_SERVER_ERROR; | ||
let message = format!("Something went wrong: {self}"); | ||
(status, message).into_response() | ||
} | ||
} | ||
|
||
async fn handler( | ||
State(state): State<Arc<AppState>>, | ||
Query(payload): Query<CallbackPayload>, | ||
) -> Result<Redirect, CallbackError> { | ||
// 1. Shutdown the server | ||
state.shutdown().await?; | ||
|
||
// 2. Get access token using the authorization code | ||
match payload.code { | ||
Some(code) => { | ||
let mut api = Client::new(); | ||
|
||
let token = api.oauth2(&code).await?; | ||
api.set_token(token.clone()); | ||
|
||
// fetch the account information | ||
let request_body = Me::build_query(Variables {}); | ||
let res: graphql_client::Response<ResponseData> = api.query(&request_body).await?; | ||
|
||
// display the errors if any, but still process bcs we have the token | ||
if let Some(errors) = res.errors { | ||
for err in errors { | ||
eprintln!("Error: {}", err.message); | ||
} | ||
} | ||
}); | ||
|
||
handler.join().unwrap(); | ||
let account_info = res.data.map(|data| data.me.expect("should exist")); | ||
|
||
Ok(()) | ||
// 3. Store the access token locally | ||
Credentials::new(account_info, token).store()?; | ||
|
||
println!("You are now logged in!\n"); | ||
|
||
Ok(Redirect::permanent(&format!( | ||
"{}/slot/auth/success", | ||
constant::CARTRIDGE_KEYCHAIN_URL | ||
))) | ||
} | ||
None => { | ||
error!("User denied consent. Try again."); | ||
|
||
Ok(Redirect::permanent(&format!( | ||
"{}/slot/auth/failure", | ||
constant::CARTRIDGE_KEYCHAIN_URL | ||
))) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use std::str::FromStr; | ||
|
||
use anyhow::{anyhow, ensure, Result}; | ||
use clap::Parser; | ||
use slot::session::{self, Policy}; | ||
use starknet::core::types::FieldElement; | ||
use starknet::providers::{jsonrpc::HttpTransport, JsonRpcClient, Provider}; | ||
use url::Url; | ||
|
||
#[derive(Debug, Parser)] | ||
pub struct CreateSession { | ||
#[arg(long)] | ||
#[arg(value_name = "URL")] | ||
// #[arg(default_value = "http://localhost:5050")] | ||
#[arg(help = "The RPC URL of the network you want to create a session for.")] | ||
rpc_url: String, | ||
|
||
#[arg(help = "The session's policies.")] | ||
#[arg(value_parser = parse_policy)] | ||
#[arg(required = true)] | ||
policies: Vec<Policy>, | ||
} | ||
|
||
impl CreateSession { | ||
pub async fn run(&self) -> Result<()> { | ||
let url = Url::parse(&self.rpc_url)?; | ||
let chain_id = get_network_chain_id(url.clone()).await?; | ||
let session = session::create(url, &self.policies).await?; | ||
session::store(chain_id, &session)?; | ||
Ok(()) | ||
} | ||
} | ||
|
||
fn parse_policy(value: &str) -> Result<Policy> { | ||
let mut parts = value.split(','); | ||
|
||
let target = parts.next().ok_or(anyhow!("missing target"))?.to_owned(); | ||
let target = FieldElement::from_str(&target)?; | ||
let method = parts.next().ok_or(anyhow!("missing method"))?.to_owned(); | ||
|
||
ensure!(parts.next().is_none(), " bruh"); | ||
|
||
Ok(Policy { target, method }) | ||
} | ||
|
||
async fn get_network_chain_id(url: Url) -> Result<FieldElement> { | ||
let provider = JsonRpcClient::new(HttpTransport::new(url)); | ||
Ok(provider.chain_id().await?) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,9 @@ | ||
use anyhow::Result; | ||
use std::net::SocketAddr; | ||
use urlencoding::encode; | ||
|
||
pub struct Browser; | ||
|
||
impl Browser { | ||
pub async fn open(local_addr: &SocketAddr) -> Result<()> { | ||
let callback_uri = format!("http://{local_addr}/callback").replace("[::1]", "localhost"); | ||
let encoded_callback_uri = encode(&callback_uri); | ||
let url = format!("https://x.cartridge.gg/slot/auth?callback_uri={encoded_callback_uri}"); | ||
|
||
println!("Your browser has been opened to visit: \n\n {url}\n"); | ||
webbrowser::open(&url)?; | ||
|
||
Ok(()) | ||
} | ||
use anyhow::{Context, Result}; | ||
use tracing::trace; | ||
|
||
pub fn open(url: &str) -> Result<()> { | ||
trace!(%url, "Opening browser."); | ||
webbrowser::open(url).context("Failed to open web browser")?; | ||
println!("Your browser has been opened to visit: \n\n {url}\n"); | ||
Ok(()) | ||
} |
Oops, something went wrong.