Skip to content

Commit

Permalink
Use Arc<str> to store the token (#2745)
Browse files Browse the repository at this point in the history
`SecretString` is nice to have, but this is more efficient, and
`SecretString` is defense in depth we can do without.
  • Loading branch information
GnomedDev committed Mar 10, 2024
1 parent dc9e482 commit 561d629
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/gateway/bridge/shard_queuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ impl ShardQueuer {
async fn start(&mut self, shard_id: ShardId) -> Result<()> {
let mut shard = Shard::new(
Arc::clone(&self.ws_url),
self.http.token(),
Arc::clone(self.http.token()),
ShardInfo::new(shard_id, self.shard_total),
self.intents,
self.presence.clone(),
Expand Down
10 changes: 5 additions & 5 deletions src/gateway/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct Shard {
// This acts as a timeout to determine if the shard has - for some reason - not started within
// a decent amount of time.
pub started: Instant,
pub token: String,
pub token: Arc<str>,
ws_url: Arc<str>,
pub intents: GatewayIntents,
}
Expand All @@ -98,15 +98,15 @@ impl Shard {
/// #
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// # let http: Arc<Http> = unimplemented!();
/// let token = std::env::var("DISCORD_BOT_TOKEN")?;
/// let token = Arc::from(std::env::var("DISCORD_BOT_TOKEN")?);
/// let shard_info = ShardInfo {
/// id: ShardId(0),
/// total: NonZeroU16::MIN,
/// };
///
/// // retrieve the gateway response, which contains the URL to connect to
/// let gateway = Arc::from(http.get_gateway().await?.url);
/// let shard = Shard::new(gateway, &token, shard_info, GatewayIntents::all(), None).await?;
/// let shard = Shard::new(gateway, token, shard_info, GatewayIntents::all(), None).await?;
///
/// // at this point, you can create a `loop`, and receive events and match
/// // their variants
Expand All @@ -120,7 +120,7 @@ impl Shard {
/// TLS error.
pub async fn new(
ws_url: Arc<str>,
token: &str,
token: Arc<str>,
shard_info: ShardInfo,
intents: GatewayIntents,
presence: Option<PresenceData>,
Expand All @@ -147,7 +147,7 @@ impl Shard {
seq,
stage,
started: Instant::now(),
token: token.to_string(),
token,
session_id,
shard_info,
ws_url,
Expand Down
45 changes: 29 additions & 16 deletions src/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use reqwest::header::{HeaderMap as Headers, HeaderValue};
#[cfg(feature = "utils")]
use reqwest::Url;
use reqwest::{Client, ClientBuilder, Response as ReqwestResponse, StatusCode};
use secrecy::{ExposeSecret, SecretString};
use serde::de::DeserializeOwned;
use serde_json::{from_value, json, to_string, to_vec};
use tracing::{debug, trace};
Expand Down Expand Up @@ -55,8 +54,8 @@ pub struct HttpBuilder {
client: Option<Client>,
ratelimiter: Option<Ratelimiter>,
ratelimiter_disabled: bool,
token: SecretString,
proxy: Option<String>,
token: Arc<str>,
proxy: Option<FixedString<u16>>,
application_id: Option<ApplicationId>,
}

Expand All @@ -68,7 +67,7 @@ impl HttpBuilder {
client: None,
ratelimiter: None,
ratelimiter_disabled: false,
token: SecretString::new(parse_token(token)),
token: parse_token(token),
proxy: None,
application_id: None,
}
Expand All @@ -83,7 +82,7 @@ impl HttpBuilder {
/// Sets a token for the bot. If the token is not prefixed "Bot ", this method will
/// automatically do so.
pub fn token(mut self, token: &str) -> Self {
self.token = SecretString::new(parse_token(token));
self.token = parse_token(token);
self
}

Expand Down Expand Up @@ -122,10 +121,24 @@ impl HttpBuilder {
/// proxy's behavior where it will tunnel requests that use TLS via [`HTTP CONNECT`] method
/// (e.g. using [`reqwest::Proxy`]).
///
/// # Panics
///
/// Panics if the proxy URL is larger than u16::MAX characters... what are you doing?
///
/// [`twilight-http-proxy`]: https://github.com/twilight-rs/http-proxy
/// [`HTTP CONNECT`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/CONNECT
pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
self.proxy = Some(proxy.into());
pub fn proxy<'a>(mut self, proxy: impl Into<Cow<'a, str>>) -> Self {
let proxy = proxy.into();
let len = proxy.len();

let proxy = match proxy {
Cow::Owned(proxy) => FixedString::from_string_trunc(proxy),
Cow::Borrowed(proxy) => FixedString::from_str_trunc(proxy),
};

assert_eq!(len as u32, proxy.len(), "Proxy URL should not be larger than u16::MAX chars");
self.proxy = Some(proxy);

self
}

Expand All @@ -142,7 +155,7 @@ impl HttpBuilder {

let ratelimiter = (!self.ratelimiter_disabled).then(|| {
self.ratelimiter
.unwrap_or_else(|| Ratelimiter::new(client.clone(), self.token.expose_secret()))
.unwrap_or_else(|| Ratelimiter::new(client.clone(), Arc::clone(&self.token)))
});

Http {
Expand All @@ -155,13 +168,13 @@ impl HttpBuilder {
}
}

fn parse_token(token: &str) -> String {
fn parse_token(token: &str) -> Arc<str> {
let token = token.trim();

if token.starts_with("Bot ") || token.starts_with("Bearer ") {
token.to_string()
Arc::from(token)
} else {
format!("Bot {token}")
Arc::from(format!("Bot {token}"))
}
}

Expand All @@ -186,8 +199,8 @@ fn reason_into_header(reason: &str) -> Headers {
pub struct Http {
pub(crate) client: Client,
pub ratelimiter: Option<Ratelimiter>,
pub proxy: Option<String>,
token: SecretString,
pub proxy: Option<FixedString<u16>>,
token: Arc<str>,
application_id: AtomicU64,
}

Expand All @@ -214,8 +227,8 @@ impl Http {
self.application_id.store(application_id.get(), Ordering::Relaxed);
}

pub fn token(&self) -> &str {
self.token.expose_secret()
pub(crate) fn token(&self) -> &Arc<str> {
&self.token
}

/// Adds a [`User`] to a [`Guild`] with a valid OAuth2 access token.
Expand Down Expand Up @@ -4642,7 +4655,7 @@ impl Http {
let response = if let Some(ratelimiter) = &self.ratelimiter {
ratelimiter.perform(req).await?
} else {
let request = req.build(&self.client, self.token(), self.proxy.as_deref())?.build()?;
let request = req.build(&self.client, &self.token, self.proxy.as_deref())?.build()?;
self.client.execute(request).await?
};

Expand Down
14 changes: 5 additions & 9 deletions src/http/ratelimiting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
use std::borrow::Cow;
use std::fmt;
use std::str::{self, FromStr};
use std::sync::Arc;
use std::time::SystemTime;

use dashmap::DashMap;
use reqwest::header::HeaderMap;
use reqwest::{Client, Response, StatusCode};
use secrecy::{ExposeSecret, SecretString};
use tokio::sync::Mutex;
use tokio::time::{sleep, Duration};
use tracing::debug;
Expand Down Expand Up @@ -85,7 +85,7 @@ pub struct Ratelimiter {
client: Client,
global: Mutex<()>,
routes: DashMap<RatelimitingBucket, Ratelimit>,
token: SecretString,
token: Arc<str>,
absolute_ratelimits: bool,
ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
}
Expand All @@ -108,16 +108,12 @@ impl Ratelimiter {
///
/// The bot token must be prefixed with `"Bot "`. The ratelimiter does not prefix it.
#[must_use]
pub fn new(client: Client, token: impl Into<String>) -> Self {
Self::_new(client, token.into())
}

fn _new(client: Client, token: String) -> Self {
pub fn new(client: Client, token: Arc<str>) -> Self {
Self {
token,
client,
global: Mutex::default(),
routes: DashMap::new(),
token: SecretString::new(token),
ratelimit_callback: Box::new(|_| {}),
absolute_ratelimits: false,
}
Expand Down Expand Up @@ -197,7 +193,7 @@ impl Ratelimiter {
sleep(delay_time).await;
}

let request = req.clone().build(&self.client, self.token.expose_secret(), None)?;
let request = req.clone().build(&self.client, &self.token, None)?;
let response = self.client.execute(request.build()?).await?;

// Check if the request got ratelimited by checking for status 429, and if so, sleep
Expand Down

0 comments on commit 561d629

Please sign in to comment.