Skip to content

Commit

Permalink
Merge pull request #134 from Ceron257/add-authorization-header-support
Browse files Browse the repository at this point in the history
Add authorization header support
  • Loading branch information
harlanc authored Jun 9, 2024
2 parents 56802e4 + 01e1640 commit a31d775
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 32 deletions.
68 changes: 50 additions & 18 deletions library/common/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,49 @@ pub enum AuthAlgorithm {
Md5,
}

pub enum SecretCarrier {
Query(String),
Bearer(String),
}

pub fn get_secret(carrier: &SecretCarrier) -> Result<String, AuthError> {
match carrier {
SecretCarrier::Query(query) => {
let mut query_pairs = IndexMap::new();
let pars_array: Vec<&str> = query.split('&').collect();
for ele in pars_array {
let (k, v) = scanf!(ele, '=', String, String);
if k.is_none() || v.is_none() {
continue;
}
query_pairs.insert(k.unwrap(), v.unwrap());
}

query_pairs.get("token").map_or(
Err(AuthError {
value: AuthErrorValue::NoTokenFound,
}),
|t| Ok(t.to_string()),
)
}
SecretCarrier::Bearer(header) => {
let invalid_format = Err(AuthError {
value: AuthErrorValue::InvalidTokenFormat,
});
let (prefix, token) = scanf!(header, " ", String, String);
if prefix.is_none() || token.is_none() {
invalid_format
} else {
if prefix.unwrap() != "Bearer" {
invalid_format
} else {
Ok(token.unwrap())
}
}
}
}
}

#[derive(Debug, Clone, PartialEq)]
pub enum AuthType {
Pull,
Expand Down Expand Up @@ -50,7 +93,7 @@ impl Auth {
pub fn authenticate(
&self,
stream_name: &String,
query: &Option<String>,
secret: &Option<SecretCarrier>,
is_pull: bool,
) -> Result<(), AuthError> {
if self.auth_type == AuthType::Both
Expand All @@ -61,24 +104,13 @@ impl Auth {
let mut err: AuthErrorValue = AuthErrorValue::NoTokenFound;

/*Here we should do auth and it must be successful. */
if let Some(query_val) = query {
let mut query_pairs = IndexMap::new();
let pars_array: Vec<&str> = query_val.split('&').collect();
for ele in pars_array {
let (k, v) = scanf!(ele, '=', String, String);
if k.is_none() || v.is_none() {
continue;
}
query_pairs.insert(k.unwrap(), v.unwrap());
}

if let Some(token) = query_pairs.get("token") {
if self.check(stream_name, token, is_pull) {
return Ok(());
}
auth_err_reason = format!("token is not correct: {}", token);
err = AuthErrorValue::TokenIsNotCorrect;
if let Some(secret_value) = secret {
let token = get_secret(secret_value)?;
if self.check(stream_name, token.as_str(), is_pull) {
return Ok(());
}
auth_err_reason = format!("token is not correct: {}", token);
err = AuthErrorValue::TokenIsNotCorrect;
}

log::error!(
Expand Down
2 changes: 2 additions & 0 deletions library/common/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub enum AuthErrorValue {
TokenIsNotCorrect,
#[fail(display = "no token found.")]
NoTokenFound,
#[fail(display = "invalid token format.")]
InvalidTokenFormat
}

impl fmt::Display for AuthError {
Expand Down
8 changes: 6 additions & 2 deletions protocol/hls/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use {
http::StatusCode,
response::Response,
},
commonlib::auth::Auth,
commonlib::auth::{Auth, SecretCarrier},
std::net::SocketAddr,
tokio::{fs::File, net::TcpListener},
tokio_util::codec::{BytesCodec, FramedRead},
Expand Down Expand Up @@ -36,7 +36,11 @@ async fn handle_connection(State(auth): State<Option<Auth>>, req: Request<Body>)

if let Some(auth_val) = auth {
if auth_val
.authenticate(&stream_name, &query_string, true)
.authenticate(
&stream_name,
&query_string.map(|q| SecretCarrier::Query(q)),
true,
)
.is_err()
{
return Response::builder()
Expand Down
8 changes: 6 additions & 2 deletions protocol/httpflv/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use {
http::StatusCode,
response::Response,
},
commonlib::auth::Auth,
commonlib::auth::{Auth, SecretCarrier},
futures::channel::mpsc::unbounded,
std::net::SocketAddr,
streamhub::define::StreamHubEventSender,
Expand Down Expand Up @@ -37,7 +37,11 @@ async fn handle_connection(

if let Some(auth_val) = auth {
if auth_val
.authenticate(&stream_name, &query_string, true)
.authenticate(
&stream_name,
&query_string.map(|q| SecretCarrier::Query(q)),
true,
)
.is_err()
{
return Response::builder()
Expand Down
20 changes: 18 additions & 2 deletions protocol/rtmp/src/session/server_session.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use commonlib::auth::SecretCarrier;

use crate::chunk::{errors::UnpackErrorValue, packetizer::ChunkPacketizer};

use {
Expand Down Expand Up @@ -626,7 +628,14 @@ impl ServerSession {
(self.stream_name, self.query) =
RtmpUrlParser::parse_stream_name_with_query(&raw_stream_name);
if let Some(auth) = &self.auth {
auth.authenticate(&self.stream_name, &self.query, true)?
auth.authenticate(
&self.stream_name,
&self
.query
.as_ref()
.map(|q| SecretCarrier::Query(q.to_string())),
true,
)?
}

let query = if let Some(query_val) = &self.query {
Expand Down Expand Up @@ -699,7 +708,14 @@ impl ServerSession {
}
}
if let Some(auth) = &self.auth {
auth.authenticate(&self.stream_name, &self.query, false)?
auth.authenticate(
&self.stream_name,
&self
.query
.as_ref()
.map(|q| SecretCarrier::Query(q.to_string())),
false,
)?
}

/*Now it can update the request url*/
Expand Down
21 changes: 19 additions & 2 deletions protocol/rtsp/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::global_trait::Unmarshal;
use crate::rtp::define::ANNEXB_NALU_START_CODE;
use crate::rtp::utils::Marshal as RtpMarshal;

use commonlib::auth::SecretCarrier;
use commonlib::http::HttpRequest as RtspRequest;
use commonlib::http::HttpResponse as RtspResponse;
use commonlib::http::Marshal as RtspMarshal;
Expand Down Expand Up @@ -314,7 +315,15 @@ impl RtspServerSession {
async fn handle_announce(&mut self, rtsp_request: &RtspRequest) -> Result<(), SessionError> {
if let Some(auth) = &self.auth {
let stream_name = rtsp_request.uri.path.clone();
auth.authenticate(&stream_name, &rtsp_request.uri.query, false)?;
auth.authenticate(
&stream_name,
&rtsp_request
.uri
.query
.as_ref()
.map(|q| SecretCarrier::Query(q.to_string())),
false,
)?;
}

if let Some(request_body) = &rtsp_request.body {
Expand Down Expand Up @@ -465,7 +474,15 @@ impl RtspServerSession {
async fn handle_play(&mut self, rtsp_request: &RtspRequest) -> Result<(), SessionError> {
if let Some(auth) = &self.auth {
let stream_name = rtsp_request.uri.path.clone();
auth.authenticate(&stream_name, &rtsp_request.uri.query, true)?;
auth.authenticate(
&stream_name,
&rtsp_request
.uri
.query
.as_ref()
.map(|q| SecretCarrier::Query(q.to_string())),
true,
)?;
}

for track in self.tracks.values_mut() {
Expand Down
9 changes: 6 additions & 3 deletions protocol/webrtc/src/clients/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,11 @@ <h1>WHEP Example</h1>
<label for="app-name">App Name:</label>
<input type="text" id="app-name" name="app-name" value="live">
<label for="stream-name">Stream Name:</label>
<input type="text" id="stream-name" name="stream-name" value="test">
<input type="text" id="stream-name" name="stream-name" value="test"><br>
<label for="token">Token:</label>
<input type="text" id="token" name="token" value="123">
<label for="use-header">Use Authorization header:</label>
<input type="checkbox" id="use-header" name="use-header">
<br><br>
<button id="start-whep-btn">Start WHEP</button>
</div>
Expand All @@ -106,6 +108,7 @@ <h1>WHEP Example</h1>
const appName = document.getElementById("app-name").value;
const streamName = document.getElementById("stream-name").value;
const token = document.getElementById("token").value;
const useHeader = document.getElementById("use-header").checked;

//Create peerconnection
const pc = window.pc = new RTCPeerConnection();
Expand All @@ -126,11 +129,11 @@ <h1>WHEP Example</h1>
//Create whep client
const whep = new WHEPClient();

const url = location.origin + "/whep?app=" + appName + "&stream=" + streamName + "&token=" + token;
const url = location.origin + "/whep?app=" + appName + "&stream=" + streamName + (!useHeader ? "&token=" + token : "");
//const token = ""

//Start viewing
whep.view(pc, url, token);
whep.view(pc, url, useHeader ? token : null);

});
</script>
Expand Down
17 changes: 14 additions & 3 deletions protocol/webrtc/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use std::io::Read;
use std::{collections::HashMap, fs::File, sync::Arc};
use tokio::net::TcpStream;

use commonlib::define::http_method_name;
use commonlib::http::{parse_content_length, HttpRequest, HttpResponse};
use commonlib::{auth::SecretCarrier, define::http_method_name};

use commonlib::http::Marshal as HttpMarshal;
use commonlib::http::Unmarshal as HttpUnmarshal;
Expand Down Expand Up @@ -184,17 +184,28 @@ impl WebRTCServerSession {
);
let offer = RTCSessionDescription::offer(sdp_data.clone())?;

let bearer_carrier = http_request
.get_header(&"Authorization".to_string())
.map(|header| SecretCarrier::Bearer(header.to_string()));
let query_carrier = http_request
.uri
.query
.as_ref()
.map(|q| SecretCarrier::Query(q.to_string()));

let token_carrier = bearer_carrier.or(query_carrier);

match t.to_lowercase().as_str() {
"whip" => {
if let Some(auth) = &self.auth {
auth.authenticate(&stream_name, &http_request.uri.query, false)?;
auth.authenticate(&stream_name, &token_carrier, false)?;
}
self.publish_whip(app_name, stream_name, path, offer)
.await?;
}
"whep" => {
if let Some(auth) = &self.auth {
auth.authenticate(&stream_name, &http_request.uri.query, true)?;
auth.authenticate(&stream_name, &token_carrier, true)?;
}
self.subscribe_whep(app_name, stream_name, path, offer)
.await?;
Expand Down

0 comments on commit a31d775

Please sign in to comment.