Skip to content

Commit

Permalink
progress? sinks and streams are weird
Browse files Browse the repository at this point in the history
  • Loading branch information
jmwample committed May 28, 2024
1 parent d6761bd commit 2ccf0e1
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 35 deletions.
4 changes: 2 additions & 2 deletions crates/obfs4/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl Client {
/// handshake timeout and then close the connection.
pub async fn wrap<'a, T>(self, mut stream: T) -> Result<Obfs4Stream>
where
T: AsyncRead + AsyncWrite + Unpin + 'a,
T: AsyncRead + AsyncWrite + Unpin + Send + 'a,
{
let session = sessions::new_client_session(self.station_pubkey, self.iat_mode);

Expand All @@ -158,7 +158,7 @@ impl Client {
mut stream_fut: Pin<ptrs::FutureResult<T, E>>,
) -> Result<Obfs4Stream>
where
T: AsyncRead + AsyncWrite + Unpin + 'a,
T: AsyncRead + AsyncWrite + Unpin + Send + 'a,
E: std::error::Error + Send + Sync + 'static,
{
let stream = stream_fut.await.map_err(|e| Error::Other(Box::new(e)))?;
Expand Down
8 changes: 5 additions & 3 deletions crates/obfs4/src/common/delay/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ impl<Item, Si: Sink<Item>> Sink<Item> for DelayedSink<Si, Item> {

let delay = (*s.delay_fn)();

s.sleep
.as_mut()
.reset(Instant::now() + delay);
if delay.is_zero() {
s.sleep
.as_mut()
.reset(Instant::now() + delay);
}
Ok(())
}

Expand Down
76 changes: 49 additions & 27 deletions crates/obfs4/src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use crate::{
Error, Result,
};

use bytes::{Buf, BytesMut, Bytes};
use futures::{sink::Sink, stream::Stream};
use bytes::{Buf, BytesMut};
use futures::{sink::Sink, stream::{Stream, StreamExt}};
use pin_project::pin_project;
use ptrs::trace;
use sha2::{Digest, Sha256};
Expand All @@ -35,6 +35,8 @@ pub enum IAT {
Paranoid,
}

pub trait Transport<B,E,I>: Sink<B, Error=E> + Stream<Item=I> + Unpin + Send {}

#[derive(Debug, Clone)]
pub(crate) enum MaybeTimeout {
Default_,
Expand Down Expand Up @@ -91,7 +93,7 @@ pub struct Obfs4Stream {
}

impl Obfs4Stream {
pub(crate) fn from_o4(o4: O4Stream) -> Self {
pub(crate) fn from_o4(o4: O4Stream<>) -> Self {
Obfs4Stream {
// s: Arc::new(Mutex::new(o4)),
s: o4,
Expand All @@ -100,10 +102,13 @@ impl Obfs4Stream {
}

#[pin_project]
pub(crate) struct O4Stream {
pub(crate) struct O4Stream{
#[pin]
// pub stream: Framed<T, framing::Obfs4Codec>,
pub stream: Box<dyn Sink<BytesMut, Error=()> + Send + Unpin>,
// pub stream: Box<dyn Transport<BytesMut, IoError, Messages>>,
pub stream: Box<dyn Stream<Item=Messages> + Send + Unpin>,
#[pin]
pub sink: Box<dyn Sink<BytesMut, Error=IoError> + Send + Unpin>,

pub length_dist: probdist::WeightedDist,
pub iat_dist: probdist::WeightedDist,
Expand All @@ -116,18 +121,21 @@ impl O4Stream {
// inner: &'a mut dyn Stream<'a>,
inner: T,
codec: framing::Obfs4Codec,
session: Session,
mut session: Session,
) -> O4Stream
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + Send,
{
let stream: Box<dyn Sink<BytesMut, Error=()>+Send+Unpin> = match session.get_iat_mode() {
IAT::Off => Box::new(Framed::new(inner, codec)),
IAT::Enabled | IAT::Paranoid => {
let f = Framed::new(inner, codec);
Box::new(delay::DelayedSink::new(f, session.iat_duration_sampler()))
}
let delay_fn = match session.get_iat_mode() {
IAT::Off => || Duration::ZERO,
IAT::Enabled | IAT::Paranoid => session.iat_duration_sampler(),
};
let (sink, stream) = Framed::new(inner, codec).split();
let sink = delay::DelayedSink::new(sink, delay_fn);

let sink: Box<dyn Sink<BytesMut, Error = IoError> + Send + Unpin> = Box::new(sink);
let stream: Box<dyn Stream<Item = Messages> + Send + Unpin> = Box::new(stream);

let len_seed = session.len_seed();

let mut hasher = Sha256::new();
Expand All @@ -150,6 +158,7 @@ impl O4Stream {
);

Self {
sink,
stream,
session,
length_dist,
Expand Down Expand Up @@ -179,8 +188,9 @@ impl AsyncWrite for O4Stream {
let mut this = self.as_mut().project();

// determine if the stream is ready to send an event?
if futures::Sink::<&[u8]>::poll_ready(this.stream.as_mut(), cx) == Poll::Pending {
return Poll::Pending;
match futures::Sink::<BytesMut>::poll_ready(this.sink.as_mut(), cx) {
Poll::Pending => return Poll::Pending,
_ => {}
}

// while we have bytes in the buffer write MAX_MESSAGE_PAYLOAD_LENGTH
Expand All @@ -202,24 +212,25 @@ impl AsyncWrite for O4Stream {
out_buf.clear();

// determine if the stream is ready to send more data. if not back off
if futures::Sink::<&[u8]>::poll_ready(this.stream.as_mut(), cx) == Poll::Pending {
return Poll::Ready(Ok(len_sent));
match futures::Sink::<BytesMut>::poll_ready(this.sink.as_mut(), cx) {
Poll::Pending => return Poll::Ready(Ok(len_sent)),
_ => {}
}
}

let payload = framing::Messages::Payload(buf[len_sent..].to_vec());

let mut out_buf = BytesMut::new();
payload.marshall(&mut out_buf)?;
this.stream.as_mut().start_send(out_buf)?;
this.sink.as_mut().start_send(out_buf)?;

Poll::Ready(Ok(msg_len))
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<StdResult<(), IoError>> {
trace!("{} flushing", self.session.id());
let mut this = self.project();
match futures::Sink::<&[u8]>::poll_flush(this.stream.as_mut(), cx) {
match futures::Sink::<BytesMut>::poll_flush(this.sink.as_mut(), cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
Poll::Pending => Poll::Pending,
Expand All @@ -229,7 +240,7 @@ impl AsyncWrite for O4Stream {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<StdResult<(), IoError>> {
trace!("{} shutting down", self.session.id());
let mut this = self.project();
match futures::Sink::<&[u8]>::poll_close(this.stream.as_mut(), cx) {
match futures::Sink::<BytesMut>::poll_close(this.sink.as_mut(), cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
Poll::Pending => Poll::Pending,
Expand Down Expand Up @@ -260,10 +271,7 @@ impl AsyncRead for O4Stream {
return Poll::Ready(Ok(()));
}

match res.unwrap() {
Ok(m) => m,
Err(e) => Err(e)?,
}
res.unwrap()
}
}
};
Expand Down Expand Up @@ -316,13 +324,13 @@ impl AsyncRead for Obfs4Stream {
}
}

impl Sink<Messages> for O4Stream {
type Error = ();
impl Sink<BytesMut> for O4Stream {
type Error = IoError;
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>> {
todo!();
}

fn start_send(self: Pin<&mut Self>, _item: Messages) -> StdResult<(), Self::Error> {
fn start_send(self: Pin<&mut Self>, _item: BytesMut) -> StdResult<(), Self::Error> {
todo!();
}

Expand All @@ -335,6 +343,20 @@ impl Sink<Messages> for O4Stream {
}
}

impl Stream for O4Stream {
type Item = Messages;

// Required method
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>
) -> Poll<Option<Self::Item>> {
todo!();
}
}

impl Transport<BytesMut, IoError, Messages> for O4Stream {}


// TODO Apply pad_burst logic and IAT policy to Message assembly (probably as part of AsyncRead / AsyncWrite impl)
/// Attempts to pad a burst of data so that the last [`Message`] is of the length
Expand Down
2 changes: 1 addition & 1 deletion crates/obfs4/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ impl Server {

pub async fn wrap<T>(self, stream: T) -> Result<Obfs4Stream>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + Send,
{
let session = self.new_server_session()?;
let deadline = self.handshake_timeout.map(|d| Instant::now() + d);
Expand Down
4 changes: 2 additions & 2 deletions crates/obfs4/src/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl ClientSession<Initialized> {
deadline: Option<Instant>,
) -> Result<Obfs4Stream>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + Send,
{
// set up for handshake
let mut session = self.transition(ClientHandshaking {});
Expand Down Expand Up @@ -415,7 +415,7 @@ impl ServerSession<Initialized> {
deadline: Option<Instant>,
) -> Result<Obfs4Stream>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + Send,
{
// set up for handshake
let mut session = self.transition(ServerHandshaking {});
Expand Down

0 comments on commit 2ccf0e1

Please sign in to comment.