Skip to content

Commit

Permalink
Merge pull request #156 from akadan47/fix-hls-compatibility-security
Browse files Browse the repository at this point in the history
Content types for HLS playlist & segments.
  • Loading branch information
harlanc authored Nov 5, 2024
2 parents ad709e0 + 81dcc8e commit 1057281
Showing 1 changed file with 152 additions and 52 deletions.
204 changes: 152 additions & 52 deletions protocol/hls/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,81 +14,135 @@ use {

type GenericError = Box<dyn std::error::Error + Send + Sync>;
type Result<T> = std::result::Result<T, GenericError>;

static NOTFOUND: &[u8] = b"Not Found";
static UNAUTHORIZED: &[u8] = b"Unauthorized";

async fn handle_connection(State(auth): State<Option<Auth>>, req: Request<Body>) -> Response<Body> {
let path = req.uri().path();
#[derive(Debug)]
enum HlsFileType {
Playlist,
Segment,
}

let query_string: Option<String> = req.uri().query().map(|s| s.to_string());
let mut file_path: String = String::from("");

if path.ends_with(".m3u8") {
//http://127.0.0.1/app_name/stream_name/stream_name.m3u8
let m3u8_index = path.find(".m3u8").unwrap();

if m3u8_index > 0 {
let (left, _) = path.split_at(m3u8_index);
let rv: Vec<_> = left.split('/').collect();

let app_name = String::from(rv[1]);
let stream_name = String::from(rv[2]);

if let Some(auth_val) = auth {
if auth_val
.authenticate(
&stream_name,
&query_string.map(SecretCarrier::Query),
true,
)
.is_err()
{
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(UNAUTHORIZED.into())
.unwrap();
}
}

file_path = format!("./{app_name}/{stream_name}/{stream_name}.m3u8");
impl HlsFileType {
const CONTENT_TYPE_PLAYLIST: &'static str = "application/vnd.apple.mpegurl";
const CONTENT_TYPE_SEGMENT: &'static str = "video/mp2t";

fn content_type(&self) -> &str {
match self {
Self::Playlist => Self::CONTENT_TYPE_PLAYLIST,
Self::Segment => Self::CONTENT_TYPE_SEGMENT,
}
} else if path.ends_with(".ts") {
//http://127.0.0.1/app_name/stream_name/ts_name.m3u8
let ts_index = path.find(".ts").unwrap();
}
}

if ts_index > 0 {
let (left, _) = path.split_at(ts_index);
#[derive(Debug)]
struct HlsPath {
app_name: String,
stream_name: String,
file_name: String,
file_type: HlsFileType,
}

let rv: Vec<_> = left.split('/').collect();
impl HlsPath {
const M3U8_EXT: &'static str = "m3u8";
const TS_EXT: &'static str = "ts";

let app_name = String::from(rv[1]);
let stream_name = String::from(rv[2]);
let ts_name = String::from(rv[3]);
fn parse(path: &str) -> Option<Self> {
if path.is_empty() || path.contains("..") {
return None;
}

let mut parts = path[1..].split('/');
let app_name = parts.next()?;
let stream_name = parts.next()?;
let file_part = parts.next()?;
if parts.next().is_some() {
return None;
}

file_path = format!("./{app_name}/{stream_name}/{ts_name}.ts");
let (file_name, ext) = file_part.rsplit_once('.')?;
if file_name.is_empty() {
return None;
}

let file_type = match ext {
Self::M3U8_EXT => HlsFileType::Playlist,
Self::TS_EXT => HlsFileType::Segment,
_ => return None,
};

Some(Self {
app_name: app_name.into(),
stream_name: stream_name.into(),
file_name: file_name.into(),
file_type,
})
}

fn to_file_path(&self) -> String {
let ext = match self.file_type {
HlsFileType::Playlist => Self::M3U8_EXT,
HlsFileType::Segment => Self::TS_EXT,
};
format!(
"./{}/{}/{}.{}",
self.app_name, self.stream_name, self.file_name, ext
)
}
simple_file_send(file_path.as_str()).await
}

/// HTTP status code 404
fn not_found() -> Response<Body> {
fn response_unauthorized() -> Response<Body> {
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(UNAUTHORIZED.into())
.unwrap()
}

fn response_not_found() -> Response<Body> {
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(NOTFOUND.into())
.unwrap()
}

async fn simple_file_send(filename: &str) -> Response<Body> {
// Serve a file by asynchronously reading it by chunks using tokio-util crate.
async fn response_file(hls_path: &HlsPath) -> Response<Body> {
let file_path = hls_path.to_file_path();

if let Ok(file) = File::open(&file_path).await {
let builder = Response::builder().header("Content-Type", hls_path.file_type.content_type());

if let Ok(file) = File::open(filename).await {
// Serve a file by asynchronously reading it by chunks using tokio-util crate.
let stream = FramedRead::new(file, BytesCodec::new());
let body = Body::from_stream(stream);
return Response::new(body);
return builder.body(Body::from_stream(stream)).unwrap();
}

response_not_found()
}

async fn handle_connection(State(auth): State<Option<Auth>>, req: Request<Body>) -> Response<Body> {
let path = req.uri().path();
let query_string = req.uri().query().map(|s| s.to_string());

let hls_path = match HlsPath::parse(path) {
Some(p) => p,
None => return response_not_found(),
};

if let (Some(auth_val), HlsFileType::Playlist) = (auth.as_ref(), &hls_path.file_type) {
if auth_val
.authenticate(
&hls_path.stream_name,
&query_string.map(SecretCarrier::Query),
true,
)
.is_err()
{
return response_unauthorized();
}
}

not_found()
response_file(&hls_path).await
}

pub async fn run(port: usize, auth: Option<Auth>) -> Result<()> {
Expand All @@ -105,3 +159,49 @@ pub async fn run(port: usize, auth: Option<Auth>) -> Result<()> {

Ok(())
}

#[cfg(test)]
mod tests {
use super::{HlsFileType, HlsPath};

#[test]
fn test_hls_path_parse() {
// Playlist
let playlist = HlsPath::parse("/live/stream/stream.m3u8").unwrap();
assert_eq!(playlist.app_name, "live");
assert_eq!(playlist.stream_name, "stream");
assert_eq!(playlist.file_name, "stream");
assert!(matches!(playlist.file_type, HlsFileType::Playlist));
assert_eq!(playlist.to_file_path(), "./live/stream/stream.m3u8");
assert_eq!(
playlist.file_type.content_type(),
"application/vnd.apple.mpegurl"
);

// Segment
let segment = HlsPath::parse("/live/stream/123.ts").unwrap();
assert_eq!(segment.app_name, "live");
assert_eq!(segment.stream_name, "stream");
assert_eq!(segment.file_name, "123");
assert!(matches!(segment.file_type, HlsFileType::Segment));
assert_eq!(segment.to_file_path(), "./live/stream/123.ts");
assert_eq!(segment.file_type.content_type(), "video/mp2t");

// Negative
assert!(HlsPath::parse("").is_none());
assert!(HlsPath::parse("/invalid").is_none());
assert!(HlsPath::parse("/too/many/parts/of/path.m3u8").is_none());
assert!(HlsPath::parse("/live/stream/invalid.mp4").is_none());
assert!(HlsPath::parse("/live/stream/../../etc/passwd").is_none());
assert!(HlsPath::parse("/live/stream/...").is_none());
assert!(HlsPath::parse("/live/stream.m3u8").is_none());
assert!(HlsPath::parse("/live/stream.ts").is_none());
assert!(HlsPath::parse("/live/stream/").is_none());
assert!(HlsPath::parse("/live/stream.m3u8").is_none());
assert!(HlsPath::parse("/live/stream.ts").is_none());
assert!(HlsPath::parse("/live/stream/file.").is_none());
assert!(HlsPath::parse("/live/stream/.m3u8").is_none());
assert!(HlsPath::parse("/live/stream/file.M3U8").is_none());
assert!(HlsPath::parse("/live/stream/file.TS").is_none());
}
}

0 comments on commit 1057281

Please sign in to comment.