Skip to content

Commit

Permalink
automatically include server function handler in .leptos_router()
Browse files Browse the repository at this point in the history
  • Loading branch information
gbj committed Jan 4, 2024
1 parent 6bfaa6b commit b07b84d
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 17 deletions.
Binary file modified examples/todo_app_sqlite_axum/Todos.db
Binary file not shown.
1 change: 0 additions & 1 deletion examples/todo_app_sqlite_axum/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ cfg_if! {

// build our application with a route
let app = Router::new()
.route("/api/*fn_name", post(leptos_axum::handle_server_fns))
.route("/special/:id", get(custom_handler))
.leptos_routes(&leptos_options, routes, || view! { <TodoApp/> } )
.fallback(file_and_error_handler)
Expand Down
30 changes: 28 additions & 2 deletions integrations/axum/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ use axum::{
http::{
header::{self, HeaderName, HeaderValue},
request::Parts,
HeaderMap, Request, Response, StatusCode,
HeaderMap, Method, Request, Response, StatusCode,
},
response::IntoResponse,
routing::{delete, get, patch, post, put},
RequestPartsExt,
};
use futures::{
channel::mpsc::{Receiver, Sender},
Expand Down Expand Up @@ -1540,6 +1539,8 @@ where
IV: IntoView + 'static,
{
let mut router = self;

// register router paths
for listing in paths.iter() {
let path = listing.path();

Expand Down Expand Up @@ -1631,6 +1632,31 @@ where
};
}
}

// register server functions
for (path, method) in server_fn::axum::server_fn_paths() {
let additional_context = additional_context.clone();
let handler = move |req: Request<Body>| async move {
handle_server_fns_with_context(additional_context, req).await
};
router = router.route(
path,
match method {
Method::GET => get(handler),
Method::POST => post(handler),
Method::PUT => put(handler),
Method::DELETE => delete(handler),
Method::PATCH => patch(handler),
_ => {
panic!(
"Unsupported server function HTTP method: \
{method:?}"
);
}
},
);
}

router
}

Expand Down
5 changes: 2 additions & 3 deletions server_fn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ multer = { version = "3", optional = true }
# serde
serde_json = "1"
futures = "0.3"
http = { version = "1", optional = true }
http = { version = "1" }
ciborium = { version = "0.2", optional = true }
hyper = { version = "1", optional = true }
bytes = "1"
Expand Down Expand Up @@ -67,7 +67,6 @@ reqwest = { version = "0.11", default-features = false, optional = true, feature
actix = ["dep:actix-web", "dep:send_wrapper"]
axum = [
"dep:axum",
"dep:http",
"dep:hyper",
"dep:http-body-util",
"dep:tower",
Expand All @@ -88,5 +87,5 @@ cbor = ["dep:ciborium"]
rkyv = ["dep:rkyv"]
default-tls = ["reqwest/default-tls"]
rustls = ["reqwest/rustls-tls"]
reqwest = ["dep:http", "dep:reqwest"]
reqwest = ["dep:reqwest"]
ssr = ["inventory"]
2 changes: 2 additions & 0 deletions server_fn/src/codec/cbor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ use crate::{
response::{ClientRes, Res},
};
use bytes::Bytes;
use http::Method;
use serde::{de::DeserializeOwned, Serialize};

/// Pass arguments and receive responses using `cbor` in a `POST` request.
pub struct Cbor;

impl Encoding for Cbor {
const CONTENT_TYPE: &'static str = "application/cbor";
const METHOD: Method = Method::POST;
}

impl<CustErr, T, Request> IntoReq<CustErr, Request, Cbor> for T
Expand Down
2 changes: 2 additions & 0 deletions server_fn/src/codec/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ use crate::{
response::{ClientRes, Res},
IntoReq, IntoRes,
};
use http::Method;
use serde::{de::DeserializeOwned, Serialize};
/// Pass arguments and receive responses as JSON in the body of a `POST` request.
pub struct Json;

impl Encoding for Json {
const CONTENT_TYPE: &'static str = "application/json";
const METHOD: Method = Method::POST;
}

impl<CustErr, T, Request> IntoReq<CustErr, Request, Json> for T
Expand Down
2 changes: 2 additions & 0 deletions server_fn/src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod cbor;
pub use cbor::*;
#[cfg(feature = "json")]
mod json;
use http::Method;
#[cfg(feature = "json")]
pub use json::*;
#[cfg(feature = "rkyv")]
Expand Down Expand Up @@ -59,6 +60,7 @@ pub trait IntoRes<CustErr, Response, Encoding> {

pub trait Encoding {
const CONTENT_TYPE: &'static str;
const METHOD: Method;
}

pub trait FormDataEncoding<Client, CustErr, Request>
Expand Down
2 changes: 2 additions & 0 deletions server_fn/src/codec/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ use crate::{
IntoReq,
};
use futures::StreamExt;
use http::Method;
use multer::Multipart;
use web_sys::FormData;

pub struct MultipartFormData;

impl Encoding for MultipartFormData {
const CONTENT_TYPE: &'static str = "multipart/form-data";
const METHOD: Method = Method::POST;
}

#[derive(Debug)]
Expand Down
2 changes: 2 additions & 0 deletions server_fn/src/codec/rkyv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
response::{ClientRes, Res},
};
use bytes::Bytes;
use http::Method;
use rkyv::{
de::deserializers::SharedDeserializeMap, ser::serializers::AllocSerializer,
validation::validators::DefaultValidator, Archive, CheckBytes, Deserialize,
Expand All @@ -16,6 +17,7 @@ pub struct Rkyv;

impl Encoding for Rkyv {
const CONTENT_TYPE: &'static str = "application/rkyv";
const METHOD: Method = Method::POST;
}

impl<CustErr, T, Request> IntoReq<CustErr, Request, Rkyv> for T
Expand Down
3 changes: 3 additions & 0 deletions server_fn/src/codec/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ use crate::{
};
use bytes::Bytes;
use futures::{Stream, StreamExt};
use http::Method;
use std::pin::Pin;

pub struct Streaming;

impl Encoding for Streaming {
const CONTENT_TYPE: &'static str = "application/octet-stream";
const METHOD: Method = Method::POST;
}

/* impl<CustErr, T, Request> IntoReq<CustErr, Request, ByteStream> for T
Expand Down Expand Up @@ -81,6 +83,7 @@ pub struct StreamingText;

impl Encoding for StreamingText {
const CONTENT_TYPE: &'static str = "text/plain";
const METHOD: Method = Method::POST;
}

pub struct TextStream<CustErr = NoCustomError>(
Expand Down
3 changes: 3 additions & 0 deletions server_fn/src/codec/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
error::ServerFnError,
request::{ClientReq, Req},
};
use http::Method;
use serde::{de::DeserializeOwned, Serialize};

/// Pass arguments as a URL-encoded query string of a `GET` request.
Expand All @@ -13,6 +14,7 @@ pub struct PostUrl;

impl Encoding for GetUrl {
const CONTENT_TYPE: &'static str = "application/x-www-form-urlencoded";
const METHOD: Method = Method::GET;
}

impl<CustErr, T, Request> IntoReq<CustErr, Request, GetUrl> for T
Expand Down Expand Up @@ -46,6 +48,7 @@ where

impl Encoding for PostUrl {
const CONTENT_TYPE: &'static str = "application/x-www-form-urlencoded";
const METHOD: Method = Method::POST;
}

impl<CustErr, T, Request> IntoReq<CustErr, Request, PostUrl> for T
Expand Down
39 changes: 29 additions & 10 deletions server_fn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub use const_format;
use dashmap::DashMap;
pub use error::ServerFnError;
use error::ServerFnErrorSerde;
use http::Method;
use middleware::{Layer, Service};
use once_cell::sync::Lazy;
use request::Req;
Expand Down Expand Up @@ -155,26 +156,29 @@ macro_rules! initialize_server_fn_map {
once_cell::sync::Lazy::new(|| {
$crate::inventory::iter::<ServerFnTraitObj<$req, $res>>
.into_iter()
.map(|obj| (obj.path(), *obj))
.map(|obj| (obj.path(), obj.clone()))
.collect()
})
};
}

pub struct ServerFnTraitObj<Req, Res> {
path: &'static str,
method: Method,
handler: fn(Req) -> Pin<Box<dyn Future<Output = Res> + Send>>,
middleware: fn() -> Vec<Arc<dyn Layer<Req, Res>>>,
}

impl<Req, Res> ServerFnTraitObj<Req, Res> {
pub const fn new(
path: &'static str,
method: Method,
handler: fn(Req) -> Pin<Box<dyn Future<Output = Res> + Send>>,
middleware: fn() -> Vec<Arc<dyn Layer<Req, Res>>>,
) -> Self {
Self {
path,
method,
handler,
middleware,
}
Expand All @@ -183,6 +187,10 @@ impl<Req, Res> ServerFnTraitObj<Req, Res> {
pub fn path(&self) -> &'static str {
self.path
}

pub fn method(&self) -> Method {
self.method.clone()
}
}

impl<Req, Res> Service<Req, Res> for ServerFnTraitObj<Req, Res>
Expand All @@ -198,12 +206,15 @@ where

impl<Req, Res> Clone for ServerFnTraitObj<Req, Res> {
fn clone(&self) -> Self {
*self
Self {
path: self.path,
method: self.method.clone(),
handler: self.handler,
middleware: self.middleware,
}
}
}

impl<Req, Res> Copy for ServerFnTraitObj<Req, Res> {}

type LazyServerFnMap<Req, Res> =
Lazy<DashMap<&'static str, ServerFnTraitObj<Req, Res>>>;

Expand All @@ -212,10 +223,10 @@ type LazyServerFnMap<Req, Res> =
pub mod axum {
use crate::{
middleware::{BoxedService, Layer, Service},
LazyServerFnMap, ServerFn, ServerFnTraitObj,
Encoding, LazyServerFnMap, ServerFn, ServerFnTraitObj,
};
use axum::body::Body;
use http::{Request, Response, StatusCode};
use http::{Method, Request, Response, StatusCode};

inventory::collect!(ServerFnTraitObj<Request<Body>, Response<Body>>);

Expand All @@ -235,12 +246,19 @@ pub mod axum {
T::PATH,
ServerFnTraitObj::new(
T::PATH,
T::InputEncoding::METHOD,
|req| Box::pin(T::run_on_server(req)),
T::middlewares,
),
);
}

pub fn server_fn_paths() -> impl Iterator<Item = (&'static str, Method)> {
REGISTERED_SERVER_FUNCTIONS
.iter()
.map(|item| (item.path(), item.method()))
}

pub async fn handle_server_fn(req: Request<Body>) -> Response<Body> {
let path = req.uri().path();

Expand Down Expand Up @@ -268,7 +286,7 @@ pub mod axum {
) -> Option<BoxedService<Request<Body>, Response<Body>>> {
REGISTERED_SERVER_FUNCTIONS.get(path).map(|server_fn| {
let middleware = (server_fn.middleware)();
let mut service = BoxedService::new(*server_fn);
let mut service = BoxedService::new(server_fn.clone());
for middleware in middleware {
service = middleware.layer(service);
}
Expand All @@ -282,11 +300,10 @@ pub mod axum {
pub mod actix {
use crate::{
middleware::BoxedService, request::actix::ActixRequest,
response::actix::ActixResponse, LazyServerFnMap, ServerFn,
response::actix::ActixResponse, Encoding, LazyServerFnMap, ServerFn,
ServerFnTraitObj,
};
use actix_web::{HttpRequest, HttpResponse};
use send_wrapper::SendWrapper;

inventory::collect!(ServerFnTraitObj<ActixRequest, ActixResponse>);

Expand All @@ -306,6 +323,7 @@ pub mod actix {
T::PATH,
ServerFnTraitObj::new(
T::PATH,
T::InputEncoding::METHOD,
|req| Box::pin(T::run_on_server(req)),
T::middlewares,
),
Expand All @@ -316,7 +334,8 @@ pub mod actix {
let path = req.uri().path();
if let Some(server_fn) = REGISTERED_SERVER_FUNCTIONS.get(path) {
let middleware = (server_fn.middleware)();
let mut service = BoxedService::new(*server_fn);
// http::Method is the only non-Copy type here
let mut service = BoxedService::new(server_fn.clone());
for middleware in middleware {
service = middleware.layer(service);
}
Expand Down
3 changes: 2 additions & 1 deletion server_fn_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,10 @@ pub fn server_macro_impl(
let inventory = if cfg!(feature = "ssr") {
quote! {
#server_fn_path::inventory::submit! {{
use #server_fn_path::ServerFn;
use #server_fn_path::{ServerFn, codec::Encoding};
#server_fn_path::ServerFnTraitObj::new(
#struct_name::PATH,
<#struct_name as ServerFn>::InputEncoding::METHOD,
|req| {
Box::pin(#struct_name::run_on_server(req))
},
Expand Down

0 comments on commit b07b84d

Please sign in to comment.