Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gateway transport compression #1508

Merged
merged 5 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Cargo.toml
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ version = "0.4.10"


[dependencies.flate2] [dependencies.flate2]
optional = true optional = true
default-features = false
features = ["zlib"]
version = "1.0.13" version = "1.0.13"


[dependencies.reqwest] [dependencies.reqwest]
Expand Down Expand Up @@ -169,12 +171,13 @@ version = "0.4"


[features] [features]
# Defaults with different backends # Defaults with different backends
default = ["default_no_backend", "rustls_backend"] default = ["default_no_backend", "rustls_backend", "transport_compression"]
default_native_tls = ["default_no_backend", "native_tls_backend"] default_native_tls = ["default_no_backend", "native_tls_backend", "transport_compression"]
default_tokio_0_2 = ["default_no_backend", "rustls_tokio_0_2_backend"] default_tokio_0_2 = ["default_no_backend", "rustls_tokio_0_2_backend", "transport_compression"]
default_native_tls_tokio_0_2 = [ default_native_tls_tokio_0_2 = [
"default_no_backend", "default_no_backend",
"native_tls_tokio_0_2_backend", "native_tls_tokio_0_2_backend",
"transport_compression",
] ]


# Serenity requires a backend, this picks all default features without a backend. # Serenity requires a backend, this picks all default features without a backend.
Expand Down Expand Up @@ -205,6 +208,7 @@ standard_framework = ["framework", "uwl", "command_attr", "static_assertions"]
unstable_discord_api = [] unstable_discord_api = []
utils = ["base64"] utils = ["base64"]
voice = ["client", "model"] voice = ["client", "model"]
transport_compression = ["gateway", "flate2"]


# Enables simd accelerated parsing # Enables simd accelerated parsing
simdjson = ["simd-json"] simdjson = ["simd-json"]
Expand Down
2 changes: 1 addition & 1 deletion examples/e01_basic_ping_bot/Cargo.toml
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ authors = ["my name <my@email.address>"]
edition = "2018" edition = "2018"


[dependencies] [dependencies]
serenity = { path = "../../", default-features = false, features = ["client", "gateway", "rustls_backend", "model"] } serenity = { path = "../../", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "transport_compression"] }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
7 changes: 4 additions & 3 deletions src/client/bridge/gateway/shard_runner.rs
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ impl ShardRunner {
let _ = self let _ = self
.shard .shard
.client .client
.stream
.close(Some(CloseFrame { .close(Some(CloseFrame {
code: close_code.into(), code: close_code.into(),
reason: Cow::from(""), reason: Cow::from(""),
Expand All @@ -306,7 +307,7 @@ impl ShardRunner {
// In return, we wait for either a Close Frame response, or an error, after which this WS is deemed // In return, we wait for either a Close Frame response, or an error, after which this WS is deemed
// disconnected from Discord. // disconnected from Discord.
loop { loop {
match self.shard.client.next().await { match self.shard.client.stream.next().await {
Some(Ok(tungstenite::Message::Close(_))) => break, Some(Ok(tungstenite::Message::Close(_))) => break,
Some(Err(_)) => { Some(Err(_)) => {
warn!( warn!(
Expand Down Expand Up @@ -409,10 +410,10 @@ impl ShardRunner {
code: code.into(), code: code.into(),
reason: Cow::from(reason), reason: Cow::from(reason),
}; };
self.shard.client.close(Some(close)).await.is_ok() self.shard.client.stream.close(Some(close)).await.is_ok()
}, },
ShardClientMessage::Runner(ShardRunnerMessage::Message(msg)) => { ShardClientMessage::Runner(ShardRunnerMessage::Message(msg)) => {
self.shard.client.send(msg).await.is_ok() self.shard.client.stream.send(msg).await.is_ok()
}, },
ShardClientMessage::Runner(ShardRunnerMessage::SetActivity(activity)) => { ShardClientMessage::Runner(ShardRunnerMessage::SetActivity(activity)) => {
// To avoid a clone of `activity`, we do a little bit of // To avoid a clone of `activity`, we do a little bit of
Expand Down
14 changes: 14 additions & 0 deletions src/error.rs
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::{


#[cfg(feature = "gateway")] #[cfg(feature = "gateway")]
use async_tungstenite::tungstenite::error::Error as TungsteniteError; use async_tungstenite::tungstenite::error::Error as TungsteniteError;
#[cfg(feature = "transport_compression")]
use flate2::DecompressError;
#[cfg(feature = "http")] #[cfg(feature = "http")]
use reqwest::{header::InvalidHeaderValue, Error as ReqwestError}; use reqwest::{header::InvalidHeaderValue, Error as ReqwestError};
use serde_json::Error as JsonError; use serde_json::Error as JsonError;
Expand Down Expand Up @@ -93,6 +95,9 @@ pub enum Error {
/// [client]: crate::client /// [client]: crate::client
#[cfg(feature = "client")] #[cfg(feature = "client")]
Client(ClientError), Client(ClientError),
/// Error when decompressing a payload.
#[cfg(feature = "transport_compression")]
Flate2(DecompressError),
/// A [collector] error. /// A [collector] error.
/// ///
/// [collector]: crate::collector /// [collector]: crate::collector
Expand Down Expand Up @@ -175,6 +180,13 @@ impl From<RustlsError> for Error {
} }
} }


#[cfg(feature = "transport_compression")]
impl From<DecompressError> for Error {
fn from(e: DecompressError) -> Error {
Error::Flate2(e)
}
}

#[cfg(feature = "gateway")] #[cfg(feature = "gateway")]
impl From<TungsteniteError> for Error { impl From<TungsteniteError> for Error {
fn from(e: TungsteniteError) -> Error { fn from(e: TungsteniteError) -> Error {
Expand Down Expand Up @@ -227,6 +239,8 @@ impl Display for Error {
Error::Http(inner) => fmt::Display::fmt(&inner, f), Error::Http(inner) => fmt::Display::fmt(&inner, f),
#[cfg(all(feature = "gateway", not(feature = "native_tls_backend_marker")))] #[cfg(all(feature = "gateway", not(feature = "native_tls_backend_marker")))]
Error::Rustls(inner) => fmt::Display::fmt(&inner, f), Error::Rustls(inner) => fmt::Display::fmt(&inner, f),
#[cfg(feature = "transport_compression")]
Error::Flate2(inner) => fmt::Display::fmt(&inner, f),
#[cfg(feature = "gateway")] #[cfg(feature = "gateway")]
Error::Tungstenite(inner) => fmt::Display::fmt(&inner, f), Error::Tungstenite(inner) => fmt::Display::fmt(&inner, f),
} }
Expand Down
8 changes: 7 additions & 1 deletion src/gateway/mod.rs
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -59,14 +59,20 @@ pub use self::{
}; };
#[cfg(feature = "client")] #[cfg(feature = "client")]
use crate::client::bridge::gateway::ShardClientMessage; use crate::client::bridge::gateway::ShardClientMessage;
#[cfg(feature = "transport_compression")]
use crate::internal::Inflater;
use crate::json::Value; use crate::json::Value;
use crate::model::{gateway::Activity, user::OnlineStatus}; use crate::model::{gateway::Activity, user::OnlineStatus};


pub type CurrentPresence = (Option<Activity>, OnlineStatus); pub type CurrentPresence = (Option<Activity>, OnlineStatus);


use async_tungstenite::{tokio::ConnectStream, WebSocketStream}; use async_tungstenite::{tokio::ConnectStream, WebSocketStream};


pub type WsStream = WebSocketStream<ConnectStream>; pub struct WsClient {
#[cfg(feature = "transport_compression")]
pub(crate) inflater: Inflater,
pub(crate) stream: WebSocketStream<ConnectStream>,
}


/// Indicates the current connection stage of a [`Shard`]. /// Indicates the current connection stage of a [`Shard`].
/// ///
Expand Down
46 changes: 32 additions & 14 deletions src/gateway/shard.rs
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use async_tungstenite::tungstenite::{
error::Error as TungsteniteError, error::Error as TungsteniteError,
protocol::frame::CloseFrame, protocol::frame::CloseFrame,
}; };
use async_tungstenite::{tokio::ConnectStream, WebSocketStream};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::{debug, error, info, instrument, trace, warn}; use tracing::{debug, error, info, instrument, trace, warn};
use url::Url; use url::Url;
Expand All @@ -18,7 +19,7 @@ use super::{
ReconnectType, ReconnectType,
ShardAction, ShardAction,
WebSocketGatewayClientExt, WebSocketGatewayClientExt,
WsStream, WsClient,
}; };
use crate::client::bridge::gateway::{ChunkGuildFilter, GatewayIntents}; use crate::client::bridge::gateway::{ChunkGuildFilter, GatewayIntents};
use crate::constants::{self, close_codes}; use crate::constants::{self, close_codes};
Expand All @@ -27,6 +28,8 @@ use crate::internal::prelude::*;
use crate::internal::ws_impl::create_native_tls_client; use crate::internal::ws_impl::create_native_tls_client;
#[cfg(all(feature = "rustls_backend_marker", not(feature = "native_tls_backend_marker")))] #[cfg(all(feature = "rustls_backend_marker", not(feature = "native_tls_backend_marker")))]
use crate::internal::ws_impl::create_rustls_client; use crate::internal::ws_impl::create_rustls_client;
#[cfg(feature = "transport_compression")]
use crate::internal::Inflater;
use crate::model::{ use crate::model::{
event::{Event, GatewayEvent}, event::{Event, GatewayEvent},
gateway::Activity, gateway::Activity,
Expand Down Expand Up @@ -67,7 +70,7 @@ use crate::model::{
/// [docs]: https://discord.com/developers/docs/topics/gateway#sharding /// [docs]: https://discord.com/developers/docs/topics/gateway#sharding
/// [module docs]: crate::gateway#sharding /// [module docs]: crate::gateway#sharding
pub struct Shard { pub struct Shard {
pub client: WsStream, pub client: WsClient,
current_presence: CurrentPresence, current_presence: CurrentPresence,
/// A tuple of: /// A tuple of:
/// ///
Expand Down Expand Up @@ -140,7 +143,12 @@ impl Shard {
intents: GatewayIntents, intents: GatewayIntents,
) -> Result<Shard> { ) -> Result<Shard> {
let url = ws_url.lock().await.clone(); let url = ws_url.lock().await.clone();
let client = connect(&url).await?; let stream = connect(&url).await?;
let client = WsClient {
#[cfg(feature = "transport_compression")]
inflater: Inflater::new(),
stream,
};


let current_presence = (None, OnlineStatus::Online); let current_presence = (None, OnlineStatus::Online);
let heartbeat_instants = (None, None); let heartbeat_instants = (None, None);
Expand Down Expand Up @@ -735,7 +743,7 @@ impl Shard {
/// This will set the stage of the shard before and after instantiation of /// This will set the stage of the shard before and after instantiation of
/// the client. /// the client.
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn initialize(&mut self) -> Result<WsStream> { pub async fn initialize(&mut self) -> Result<()> {
debug!("[Shard {:?}] Initializing.", self.shard_info); debug!("[Shard {:?}] Initializing.", self.shard_info);


// We need to do two, sort of three things here: // We need to do two, sort of three things here:
Expand All @@ -749,10 +757,14 @@ impl Shard {
self.stage = ConnectionStage::Connecting; self.stage = ConnectionStage::Connecting;
self.started = Instant::now(); self.started = Instant::now();
let url = &self.ws_url.lock().await.clone(); let url = &self.ws_url.lock().await.clone();
let client = connect(url).await?; // Reset inflater
#[cfg(feature = "transport_compression")]
self.client.inflater.reset();
// Make new websocket stream
self.client.stream = connect(url).await?;
self.stage = ConnectionStage::Handshake; self.stage = ConnectionStage::Handshake;


Ok(client) Ok(())
} }


#[instrument(skip(self))] #[instrument(skip(self))]
Expand All @@ -769,7 +781,7 @@ impl Shard {
pub async fn resume(&mut self) -> Result<()> { pub async fn resume(&mut self) -> Result<()> {
debug!("[Shard {:?}] Attempting to resume", self.shard_info); debug!("[Shard {:?}] Attempting to resume", self.shard_info);


self.client = self.initialize().await?; self.initialize().await?;
self.stage = ConnectionStage::Resuming; self.stage = ConnectionStage::Resuming;


match self.session_id.as_ref() { match self.session_id.as_ref() {
Expand All @@ -785,7 +797,7 @@ impl Shard {
info!("[Shard {:?}] Attempting to reconnect", self.shard_info()); info!("[Shard {:?}] Attempting to reconnect", self.shard_info());


self.reset().await; self.reset().await;
self.client = self.initialize().await?; self.initialize().await?;


Ok(()) Ok(())
} }
Expand All @@ -797,23 +809,29 @@ impl Shard {
} }


#[cfg(all(feature = "rustls_backend_marker", not(feature = "native_tls_backend_marker")))] #[cfg(all(feature = "rustls_backend_marker", not(feature = "native_tls_backend_marker")))]
async fn connect(base_url: &str) -> Result<WsStream> { async fn connect(base_url: &str) -> Result<WebSocketStream<ConnectStream>> {
let url = build_gateway_url(base_url)?; let url = build_gateway_url(base_url)?;


Ok(create_rustls_client(url).await?) Ok(create_rustls_client(url).await?)
} }


#[cfg(feature = "native_tls_backend_marker")] #[cfg(feature = "native_tls_backend_marker")]
async fn connect(base_url: &str) -> Result<WsStream> { async fn connect(base_url: &str) -> Result<WebSocketStream<ConnectStream>> {
let url = build_gateway_url(base_url)?; let url = build_gateway_url(base_url)?;


Ok(create_native_tls_client(url).await?) Ok(create_native_tls_client(url).await?)
} }


fn build_gateway_url(base: &str) -> Result<Url> { fn build_gateway_url(base: &str) -> Result<Url> {
Url::parse(&format!("{}?v={}", base, constants::GATEWAY_VERSION)).map_err(|why| { #[cfg(feature = "transport_compression")]
warn!("Error building gateway URL with base `{}`: {:?}", base, why); const COMPRESSION: &str = "?compress=zlib-stream";
#[cfg(not(feature = "transport_compression"))]
const COMPRESSION: &str = "";


Error::Gateway(GatewayError::BuildingUrl) Url::parse(&format!("{}?v={}&encoding=json{}", base, constants::GATEWAY_VERSION, COMPRESSION))
}) .map_err(|why| {
warn!("Error building gateway URL with base `{}`: {:?}", base, why);

Error::Gateway(GatewayError::BuildingUrl)
})
} }
4 changes: 2 additions & 2 deletions src/gateway/ws_client_ext.rs
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::{debug, trace};


use crate::client::bridge::gateway::{ChunkGuildFilter, GatewayIntents}; use crate::client::bridge::gateway::{ChunkGuildFilter, GatewayIntents};
use crate::constants::{self, OpCode}; use crate::constants::{self, OpCode};
use crate::gateway::{CurrentPresence, WsStream}; use crate::gateway::{CurrentPresence, WsClient};
use crate::internal::prelude::*; use crate::internal::prelude::*;
use crate::internal::ws_impl::SenderExt; use crate::internal::ws_impl::SenderExt;
use crate::json::json; use crate::json::json;
Expand Down Expand Up @@ -49,7 +49,7 @@ pub trait WebSocketGatewayClientExt {
} }


#[async_trait] #[async_trait]
impl WebSocketGatewayClientExt for WsStream { impl WebSocketGatewayClientExt for WsClient {
#[instrument(skip(self))] #[instrument(skip(self))]
async fn send_chunk_guild( async fn send_chunk_guild(
&mut self, &mut self,
Expand Down
109 changes: 109 additions & 0 deletions src/internal/inflater.rs
Original file line number Original file line Diff line number Diff line change
@@ -0,0 +1,109 @@
use std::convert::TryInto;

use flate2::{Decompress, DecompressError, FlushDecompress};
use tracing::trace;

const ZLIB_SUFFIX: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
const INTERNAL_BUFFER_SIZE: usize = 32 * 1024;

pub struct Inflater {
decompress: Decompress,
compressed: Vec<u8>,
internal_buffer: Vec<u8>,
buffer: Vec<u8>,
countdown_to_resize: u8,
}

impl Inflater {
pub fn new() -> Self {
Self {
decompress: Decompress::new(true),
compressed: Vec::new(),
internal_buffer: Vec::with_capacity(INTERNAL_BUFFER_SIZE),
buffer: Vec::with_capacity(32 * 1024),
countdown_to_resize: u8::max_value(),
}
}

pub fn extend(&mut self, slice: &[u8]) {
self.compressed.extend_from_slice(slice);
}

pub fn msg(&mut self) -> Result<Option<&[u8]>, DecompressError> {
let length = self.compressed.len();
if length >= 4 && self.compressed[(length - 4)..] == ZLIB_SUFFIX {
// There is a payload to be decompressed.
let before = self.decompress.total_in();
let mut offset = 0;
loop {
self.internal_buffer.clear();

self.decompress.decompress_vec(
&self.compressed[offset..],
&mut self.internal_buffer,
FlushDecompress::Sync,
)?;

offset = (self.decompress.total_in() - before).try_into().unwrap_or(0);
self.buffer.extend_from_slice(&self.internal_buffer[..]);
if self.internal_buffer.len() < self.internal_buffer.capacity()
|| offset > self.compressed.len()
{
break;
}
}

trace!("in:out: {}:{}", self.compressed.len(), self.buffer.len());
self.compressed.clear();

#[allow(clippy::cast_precision_loss)]
{
// To get around the u64 → f64 precision loss lint
// it does really not matter that it happens here
trace!(
"Data saved: {}KiB ({:.2}%)",
((self.decompress.total_out() - self.decompress.total_in()) / 1024),
((self.decompress.total_in() as f64) / (self.decompress.total_out() as f64)
* 100.0)
);
}
trace!("Capacity: {}", self.buffer.capacity());
Ok(Some(&self.buffer))
} else {
// Received a partial payload.
Ok(None)
}
}

// Clear the buffer, and shrink it if it has more space
// enough to grow the length more than 4 times.
pub fn clear(&mut self) {
self.countdown_to_resize -= 1;

// Only shrink capacity if it is less than 4
// times the size, this is to prevent too
// frequent shrinking.
let cap = self.buffer.capacity();
if self.countdown_to_resize == 0 && self.buffer.len() < cap * 4 {
// When shrink_to goes stable use that on the following line.
// https://github.com/rust-lang/rust/issues/56431
self.compressed.shrink_to_fit();
self.buffer.shrink_to_fit();
trace!("compressed: {}", self.compressed.capacity());
trace!("buffer: {}", self.buffer.capacity());
self.countdown_to_resize = u8::max_value();
}
self.compressed.clear();
self.internal_buffer.clear();
self.buffer.clear();
}

// Reset the inflater
pub fn reset(&mut self) {
self.decompress.reset(true);
self.compressed.clear();
self.internal_buffer.clear();
self.buffer.clear();
self.countdown_to_resize = u8::MAX;
}
}
Loading