Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an encapsulated file stream in axum-extra to make it more conveni… #3047

Merged
merged 22 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ repository = "https://github.com/tokio-rs/axum"
version = "0.10.0-alpha.1"

[features]
default = ["tracing", "multipart"]
default = ["tracing", "multipart", "file-stream"]

async-read-body = ["dep:tokio-util", "tokio-util?/io", "dep:tokio"]
file-stream = ["dep:tokio-util", "tokio-util?/io", "dep:tokio", "tokio?/fs"]
file-stream = ["dep:tokio-util", "tokio-util?/io", "dep:tokio", "tokio?/fs", "tokio?/io-util", "dep:async-stream"]
attachment = ["dep:tracing"]
error_response = ["dep:tracing", "tracing/std"]
cookie = ["dep:cookie"]
Expand Down Expand Up @@ -57,6 +57,7 @@ tower-layer = "0.3"
tower-service = "0.3"

# optional dependencies
async-stream = { version = "0.3", optional = true }
axum-macros = { path = "../axum-macros", version = "0.5.0-alpha.1", optional = true }
cookie = { package = "cookie", version = "0.18.0", features = ["percent-encode"], optional = true }
fastrand = { version = "2.1.0", optional = true }
Expand All @@ -75,6 +76,7 @@ tracing = { version = "0.1.37", default-features = false, optional = true }
typed-json = { version = "0.1.1", optional = true }

[dev-dependencies]
async-stream = "0.3"
axum = { path = "../axum", features = ["macros"] }
axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = "1.0.0"
Expand Down
261 changes: 254 additions & 7 deletions axum-extra/src/response/file_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ use axum::{
BoxError,
};
use bytes::Bytes;
use futures_util::TryStream;
use futures_util::{Stream, TryStream};
use http::{header, StatusCode};
use std::{io, path::PathBuf};
use tokio::fs::File;
use std::{io, path::Path};
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncSeekExt},
};
use tokio_util::io::ReaderStream;

/// Alias for `tokio_util::io::ReaderStream<File>`.
Expand All @@ -27,6 +30,7 @@ pub type AsyncReaderStream = ReaderStream<File>;
/// use axum_extra::response::file_stream::FileStream;
/// use tokio::fs::File;
/// use tokio_util::io::ReaderStream;
///
/// async fn file_stream() -> Result<Response, (StatusCode, String)> {
/// let stream=ReaderStream::new(File::open("test.txt").await.map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))?);
/// let file_stream_resp = FileStream::new(stream)
Expand Down Expand Up @@ -70,6 +74,8 @@ where
/// Create a file stream from a file path.
/// # Examples
/// ```
/// use std::path::Path;
///
jplatte marked this conversation as resolved.
Show resolved Hide resolved
/// use axum::{
/// http::StatusCode,
/// response::{Response, IntoResponse},
Expand All @@ -80,16 +86,17 @@ where
/// use std::path::PathBuf;
jplatte marked this conversation as resolved.
Show resolved Hide resolved
/// use tokio::fs::File;
/// use tokio_util::io::ReaderStream;
///
/// async fn file_stream() -> Response {
/// FileStream::<ReaderStream<File>>::from_path(PathBuf::from("test.txt"))
/// FileStream::<ReaderStream<File>>::from_path(&PathBuf::from("test.txt"))
/// .await
/// .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))
/// .into_response()
/// }
/// let app = Router::new().route("/FileStreamDownload", get(file_stream));
/// # let _: Router = app;
/// ```
pub async fn from_path(path: PathBuf) -> io::Result<FileStream<AsyncReaderStream>> {
pub async fn from_path(path: &Path) -> io::Result<FileStream<AsyncReaderStream>> {
jplatte marked this conversation as resolved.
Show resolved Hide resolved
// open file
let file = File::open(&path).await?;
jplatte marked this conversation as resolved.
Show resolved Hide resolved
let mut content_size = None;
Expand Down Expand Up @@ -126,6 +133,165 @@ where
self.content_size = Some(len.into());
jplatte marked this conversation as resolved.
Show resolved Hide resolved
self
}

/// return a range response
/// range: (start, end, total_size)
/// # Examples
///
jplatte marked this conversation as resolved.
Show resolved Hide resolved
/// ```
/// use axum::{
/// http::StatusCode,
/// response::{Response, IntoResponse},
/// Router,
/// routing::get
/// };
/// use axum_extra::response::file_stream::FileStream;
/// use tokio::fs::File;
/// use tokio_util::io::ReaderStream;
/// use tokio::io::AsyncSeekExt;
///
/// async fn range_response() -> Result<Response, (StatusCode, String)> {
/// let mut file=File::open("test.txt").await.map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))?;
/// let mut file_size=file.metadata().await.map_err(|e| (StatusCode::NOT_FOUND, format!("Get file size: {e}")))?.len();
/// file.seek(std::io::SeekFrom::Start(10)).await.map_err(|e| (StatusCode::NOT_FOUND, format!("File seek error: {e}")))?;
/// let stream=ReaderStream::new(file);
///
/// Ok(FileStream::new(stream).into_range_response(10, file_size-1, file_size))
/// }
/// let app = Router::new().route("/FileStreamRange", get(range_response));
/// # let _: Router = app;
/// ```
pub fn into_range_response(self, start: u64, end: u64, total_size: u64) -> Response {
let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
resp = resp.status(StatusCode::PARTIAL_CONTENT);

resp = resp.header(
header::CONTENT_RANGE,
format!("bytes {}-{}/{}", start, end, total_size),
);

resp.body(body::Body::from_stream(self.stream))
.unwrap_or_else(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("build FileStream responsec error: {}", e),
)
.into_response()
})
}

/// Attempts to return RANGE requests directly from the file path
/// # Arguments
/// * `file_path` - The path of the file to be streamed
jplatte marked this conversation as resolved.
Show resolved Hide resolved
/// * `start` - The start position of the range, if start > file size or start > end return Range Not Satisfiable
/// * `end` - The end position of the range if end == 0 end = file size - 1
/// * `buffer_size` - The buffer size of the range
/// # Examples
/// ```
jplatte marked this conversation as resolved.
Show resolved Hide resolved
/// use axum::{
/// http::StatusCode,
/// response::{Response, IntoResponse},
/// Router,
/// routing::get
/// };
/// use std::path::Path;
/// use axum_extra::response::file_stream::FileStream;
/// use tokio::fs::File;
/// use tokio_util::io::ReaderStream;
/// use tokio::io::AsyncSeekExt;
/// use axum_extra::response::AsyncReaderStream;
///
/// async fn range_stream() -> Response {
/// let range_start = 0;
/// let range_end = 1024;
/// let buffer_size = 1024;
///
/// FileStream::<AsyncReaderStream>::try_range_response(Path::new("CHANGELOG.md"),range_start,range_end,buffer_size).await
jplatte marked this conversation as resolved.
Show resolved Hide resolved
/// .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))
/// .into_response()
///
/// }
/// let app = Router::new().route("/FileStreamRange", get(range_stream));
/// # let _: Router = app;
/// ```
pub async fn try_range_response(
file_path: &Path,
jplatte marked this conversation as resolved.
Show resolved Hide resolved
start: u64,
mut end: u64,
buffer_size: usize,
) -> io::Result<Response> {
// open file
let file = File::open(file_path).await?;

// get file metadata
let metadata = file.metadata().await?;
let total_size = metadata.len();

if end == 0 {
end = total_size - 1;
}

// range check
if start > total_size {
return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
}
if start > end {
return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
}
if end >= total_size {
return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
}

// get file stream
let stream = try_stream(file, start, end, buffer_size).await?;
let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
resp = resp.status(StatusCode::PARTIAL_CONTENT);

resp = resp.header(
header::CONTENT_RANGE,
format!("bytes {}-{}/{}", start, end, total_size),
jplatte marked this conversation as resolved.
Show resolved Hide resolved
);

Ok(resp
.body(body::Body::from_stream(stream))
.unwrap_or_else(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("build FileStream responsec error: {}", e),
jplatte marked this conversation as resolved.
Show resolved Hide resolved
)
.into_response()
}))
}
}

/// More complex manipulation of files and conversion to a stream
async fn try_stream(
mut file: File,
start: u64,
end: u64,
buffer_size: usize,
) -> Result<impl Stream<Item = Result<Vec<u8>, io::Error>>, io::Error> {
file.seek(std::io::SeekFrom::Start(start)).await?;

let mut buffer = vec![0; buffer_size];

let stream = async_stream::try_stream! {
let mut total_read = 0;

while total_read < end {
let bytes_to_read = std::cmp::min(buffer_size as u64, end - total_read);
let n = file.read(&mut buffer[..bytes_to_read as usize]).await.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, e)
})?;
if n == 0 {
break; // EOF
}
total_read += n as u64;
yield buffer[..n].to_vec();

jplatte marked this conversation as resolved.
Show resolved Hide resolved
}
};
Ok(stream)
YanHeDoki marked this conversation as resolved.
Show resolved Hide resolved
jplatte marked this conversation as resolved.
Show resolved Hide resolved
}

impl<S> IntoResponse for FileStream<S>
Expand All @@ -152,7 +318,7 @@ where
.unwrap_or_else(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("build FileStream responsec error:{}", e),
format!("build FileStream responsec error: {}", e),
jplatte marked this conversation as resolved.
Show resolved Hide resolved
)
.into_response()
})
Expand All @@ -164,6 +330,7 @@ mod tests {
use super::*;
use axum::{extract::Request, routing::get, Router};
use body::Body;
use http::HeaderMap;
use http_body_util::BodyExt;
use std::io::Cursor;
use tokio_util::io::ReaderStream;
Expand Down Expand Up @@ -342,7 +509,7 @@ mod tests {
let app = Router::new().route(
"/from_path",
get(move || async move {
FileStream::<AsyncReaderStream>::from_path("CHANGELOG.md".into())
FileStream::<AsyncReaderStream>::from_path(Path::new("CHANGELOG.md"))
.await
.unwrap()
.into_response()
Expand Down Expand Up @@ -388,4 +555,84 @@ mod tests {
);
Ok(())
}

#[tokio::test]
async fn response_range_file() -> Result<(), Box<dyn std::error::Error>> {
let app = Router::new().route("/range_response", get(range_stream));

// Simulating a GET request
let response = app
.oneshot(
Request::builder()
.uri("/range_response")
.header(header::RANGE, "bytes=20-1000")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();

// Validate Response Status Code
assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

// Validate Response Headers
assert_eq!(
response.headers().get("content-type").unwrap(),
"application/octet-stream"
);

let file = File::open("CHANGELOG.md").await.unwrap();
// get file size
let content_length = file.metadata().await.unwrap().len();

assert_eq!(
response
.headers()
.get("content-range")
.unwrap()
.to_str()
.unwrap(),
format!("bytes 20-1000/{}", content_length)
);
Ok(())
}

async fn range_stream(headers: HeaderMap) -> Response {
let range_header = headers
.get(header::RANGE)
.and_then(|value| value.to_str().ok());

let (start, end) = if let Some(range) = range_header {
if let Some(range) = parse_range_header(range) {
range
} else {
return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range").into_response();
}
} else {
(0, 0) // default range end = 0, if end = 0 end == file size - 1
};

FileStream::<AsyncReaderStream>::try_range_response(
Path::new("CHANGELOG.md"),
start,
end,
1024,
)
.await
.unwrap()
}

fn parse_range_header(range: &str) -> Option<(u64, u64)> {
let range = range.strip_prefix("bytes=")?;
let mut parts = range.split('-');
let start = parts.next()?.parse::<u64>().ok()?;
let end = parts
.next()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
if start > end {
return None;
}
Some((start, end))
}
}
Loading