Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(http): set request authority and scheme for h1 #1200

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client/src/h1/proto/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl<const HEADER_LIMIT: usize> DerefMut for Context<'_, '_, HEADER_LIMIT> {
}

impl<'c, 'd, const HEADER_LIMIT: usize> Context<'c, 'd, HEADER_LIMIT> {
pub(crate) fn new(date: &'c DateTimeHandle<'d>) -> Self {
Self(context::Context::new(date))
pub(crate) fn new(date: &'c DateTimeHandle<'d>, is_tls: bool) -> Self {
Self(context::Context::new(date, is_tls))
}
}
6 changes: 5 additions & 1 deletion client/src/h1/proto/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ where
}
}

let is_tls = req
.uri()
.scheme()
.is_some_and(|scheme| scheme == "https" || scheme == "wss");
// TODO: make const generic params configurable.
let mut ctx = Context::<128>::new(&date);
let mut ctx = Context::<128>::new(&date, is_tls);

// encode request head and return transfer encoding for request body
let encoder = ctx.encode_head(&mut buf, req)?;
Expand Down
2 changes: 1 addition & 1 deletion http/benches/h1_decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl DateTime for DT {
fn decode(c: &mut Criterion) {
let dt = DT::dummy_date_time();

let mut ctx = Context::<_, 8>::new(&dt);
let mut ctx = Context::<_, 8>::new(&dt, false);

let req = b"\
GET /HFQR/xitca-web HTTP/1.1\r\n\
Expand Down
7 changes: 5 additions & 2 deletions http/src/h1/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pub(crate) async fn run<
config: HttpServiceConfig<HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>,
service: &'a S,
date: &'a D,
is_tls: bool,
) -> Result<(), Error<S::Error, BE>>
where
S: Service<ExtRequest<ReqB>, Response = Response<ResB>>,
Expand All @@ -77,7 +78,7 @@ where
EitherBuf::Right(WriteBuf::<WRITE_BUF_LIMIT>::default())
};

Dispatcher::new(io, addr, timer, config, service, date, write_buf)
Dispatcher::new(io, addr, timer, config, service, date, write_buf, is_tls)
.run()
.await
}
Expand Down Expand Up @@ -166,6 +167,7 @@ where
W: H1BufWrite,
D: DateTime,
{
#[allow(clippy::too_many_arguments)]
fn new<const WRITE_BUF_LIMIT: usize>(
io: &'a mut St,
addr: SocketAddr,
Expand All @@ -174,11 +176,12 @@ where
service: &'a S,
date: &'a D,
write_buf: W,
is_tls: bool,
) -> Self {
Self {
io: BufferedIo::new(io, write_buf),
timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout),
ctx: Context::with_addr(addr, date),
ctx: Context::with_addr(addr, date, is_tls),
service,
_phantom: PhantomData,
}
Expand Down
3 changes: 2 additions & 1 deletion http/src/h1/dispatcher_uring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ where
config: HttpServiceConfig<H_LIMIT, R_LIMIT, W_LIMIT>,
service: &'a S,
date: &'a D,
is_tls: bool,
) -> Self {
Self {
io: Rc::new(io),
timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout),
ctx: Context::<_, H_LIMIT>::with_addr(addr, date),
ctx: Context::<_, H_LIMIT>::with_addr(addr, date, is_tls),
service,
read_buf: BufOwned::new(),
write_buf: BufOwned::new(),
Expand Down
8 changes: 5 additions & 3 deletions http/src/h1/proto/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct Context<'a, D, const HEADER_LIMIT: usize> {
// http extensions reused by next request.
exts: Extensions,
date: &'a D,
pub(crate) is_tls: bool,
}

// A set of state for current request that are used after request's ownership is passed
Expand Down Expand Up @@ -49,21 +50,22 @@ impl<'a, D, const HEADER_LIMIT: usize> Context<'a, D, HEADER_LIMIT> {
///
/// [DateTime]: crate::date::DateTime
#[inline]
pub fn new(date: &'a D) -> Self {
Self::with_addr(crate::unspecified_socket_addr(), date)
pub fn new(date: &'a D, is_tls: bool) -> Self {
Self::with_addr(crate::unspecified_socket_addr(), date, is_tls)
}

/// Context is constructed with [SocketAddr] and reference of certain type that impl [DateTime] trait.
///
/// [DateTime]: crate::date::DateTime
#[inline]
pub fn with_addr(addr: SocketAddr, date: &'a D) -> Self {
pub fn with_addr(addr: SocketAddr, date: &'a D, is_tls: bool) -> Self {
Self {
addr,
state: ContextState::new(),
header: None,
exts: Extensions::new(),
date,
is_tls,
}
}

Expand Down
55 changes: 52 additions & 3 deletions http/src/h1/proto/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::mem::MaybeUninit;

use http::uri::{Authority, Scheme};
use httparse::Status;

use crate::{
Expand Down Expand Up @@ -71,7 +72,7 @@ impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {
// split the headers from buffer.
let slice = buf.split_to(len).freeze();

let uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?;
let mut uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?.into_parts();

// pop a cached headermap or construct a new one.
let mut headers = self.take_headers();
Expand All @@ -87,6 +88,25 @@ impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {

let extensions = self.take_extensions();

// Try to set authority from host header if not present in request path
if uri.authority.is_none() {
// @TODO if it's a tls connection we could set the sni server name as authority instead
if let Some(host) = headers.get(http::header::HOST) {
uri.authority = Some(Authority::try_from(host.as_bytes())?);
}
}

// If authority is set, this will set the correct scheme depending on the tls acceptor used in the service.
if uri.authority.is_some() && uri.scheme.is_none() {
uri.scheme = if self.is_tls {
Some(Scheme::HTTPS)
} else {
Some(Scheme::HTTP)
};
}

let uri = Uri::from_parts(uri)?;

*req.method_mut() = method;
*req.version_mut() = version;
*req.uri_mut() = uri;
Expand Down Expand Up @@ -173,7 +193,7 @@ mod test {

#[test]
fn connection_multiple_value() {
let mut ctx = Context::<_, 4>::new(&());
let mut ctx = Context::<_, 4>::new(&(), false);

let head = b"\
GET / HTTP/1.1\r\n\
Expand Down Expand Up @@ -211,7 +231,7 @@ mod test {

#[test]
fn transfer_encoding() {
let mut ctx = Context::<_, 4>::new(&());
let mut ctx = Context::<_, 4>::new(&(), false);

let head = b"\
GET / HTTP/1.1\r\n\
Expand Down Expand Up @@ -311,4 +331,33 @@ mod test {
"transfer coding is not decoded to chunked"
);
}

#[test]
fn test_host_with_scheme() {
let mut ctx = Context::<_, 4>::new(&(), true);

let head = b"\
GET / HTTP/1.1\r\n\
Host: example.com\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);

let (req, _) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();

assert_eq!(req.uri().scheme(), Some(&Scheme::HTTPS));
assert_eq!(req.uri().authority(), Some(&Authority::from_static("example.com")));
assert_eq!(req.headers().get(http::header::HOST).unwrap(), "example.com");

let head = b"\
GET / HTTP/1.1\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);

let (req, _) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();

assert_eq!(req.uri().scheme(), None);
assert_eq!(req.uri().authority(), None);
}
}
4 changes: 2 additions & 2 deletions http/src/h1/proto/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ mod test {

#[test]
fn append_header() {
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler);
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false);

let mut res = Response::new(BoxBody::new(Once::new(Bytes::new())));

Expand Down Expand Up @@ -287,7 +287,7 @@ mod test {

#[test]
fn multi_set_cookie() {
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler);
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false);

let mut res = Response::new(BoxBody::new(Once::new(Bytes::new())));

Expand Down
6 changes: 6 additions & 0 deletions http/src/h1/proto/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ impl From<http::uri::InvalidUri> for ProtoError {
}
}

impl From<http::uri::InvalidUriParts> for ProtoError {
fn from(_: http::uri::InvalidUriParts) -> Self {
Self::Uri
}
}

impl From<http::status::InvalidStatusCode> for ProtoError {
fn from(_: http::status::InvalidStatusCode) -> Self {
Self::Status
Expand Down
35 changes: 26 additions & 9 deletions http/src/h1/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
error::{HttpServiceError, TimeoutError},
http::{Request, RequestExt, Response},
service::HttpService,
tls::IsTls,
util::timer::Timeout,
};

Expand All @@ -21,7 +22,7 @@ impl<St, S, B, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, co
Service<(St, SocketAddr)> for H1Service<St, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
A: Service<St>,
A: Service<St> + IsTls,
St: AsyncIo,
A::Response: AsyncIo,
B: Stream<Item = Result<Bytes, BE>>,
Expand All @@ -41,9 +42,17 @@ where
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??;

super::dispatcher::run(&mut io, addr, timer, self.config, &self.service, self.date.get())
.await
.map_err(Into::into)
super::dispatcher::run(
&mut io,
addr,
timer,
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.await
.map_err(Into::into)
}
}

Expand Down Expand Up @@ -94,7 +103,7 @@ impl<S, B, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const
Service<(TcpStream, SocketAddr)> for H1UringService<S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
A: Service<TcpStream>,
A: Service<TcpStream> + IsTls,
A::Response: AsyncBufRead + AsyncBufWrite + 'static,
B: Stream<Item = Result<Bytes, BE>>,
HttpServiceError<S::Error, BE>: From<A::Error>,
Expand All @@ -113,10 +122,18 @@ where
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??;

super::dispatcher_uring::Dispatcher::new(io, addr, timer, self.config, &self.service, self.date.get())
.run()
.await
.map_err(Into::into)
super::dispatcher_uring::Dispatcher::new(
io,
addr,
timer,
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.run()
.await
.map_err(Into::into)
}
}

Expand Down
5 changes: 4 additions & 1 deletion http/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
date::{DateTime, DateTimeService},
error::{HttpServiceError, TimeoutError},
http::{Request, RequestExt, Response},
tls::IsTls,
util::timer::{KeepAlive, Timeout},
version::AsVersion,
};
Expand Down Expand Up @@ -73,7 +74,7 @@ impl<S, ResB, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, con
for HttpService<ServerStream, S, RequestBody, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<ResB>>,
A: Service<TcpStream>,
A: Service<TcpStream> + IsTls,
A::Response: AsyncIo + AsVersion,
HttpServiceError<S::Error, BE>: From<A::Error>,
S::Error: fmt::Debug,
Expand Down Expand Up @@ -120,6 +121,7 @@ where
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.await
.map_err(From::from),
Expand Down Expand Up @@ -168,6 +170,7 @@ where
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.await
.map_err(From::from)
Expand Down
13 changes: 13 additions & 0 deletions http/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ pub use error::TlsError;

use xitca_service::Service;

/// A trait to check if an acceptor will create a Tls stream.
pub trait IsTls {
fn is_tls(&self) -> bool {
true
}
}

/// A NoOp Tls Acceptor pass through input Stream type.
#[derive(Copy, Clone)]
pub struct NoOpTlsAcceptorBuilder;
Expand All @@ -42,3 +49,9 @@ impl<St> Service<St> for NoOpTlsAcceptorService {
Ok(io)
}
}

impl IsTls for NoOpTlsAcceptorService {
fn is_tls(&self) -> bool {
false
}
}
4 changes: 3 additions & 1 deletion http/src/tls/native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use xitca_service::Service;

use crate::{http::Version, version::AsVersion};

use super::error::TlsError;
use super::{error::TlsError, IsTls};

/// A wrapper type for [TlsStream](native_tls::TlsStream).
///
Expand Down Expand Up @@ -93,6 +93,8 @@ impl<St: AsyncIo> Service<St> for TlsAcceptorService {
}
}

impl IsTls for TlsAcceptorService {}

impl<S: AsyncIo> AsyncIo for TlsStream<S> {
#[inline]
fn ready(&mut self, interest: Interest) -> impl Future<Output = io::Result<Ready>> + Send {
Expand Down
Loading
Loading