From c2ea5295e32e386c15bc4e9212ef87dd0e2640b3 Mon Sep 17 00:00:00 2001 From: Gnome! Date: Mon, 29 Jan 2024 18:48:07 +0000 Subject: [PATCH] Use Arc to store the token (#2745) `SecretString` is nice to have, but this is more efficient, and `SecretString` is defense in depth we can do without. --- src/gateway/bridge/shard_queuer.rs | 2 +- src/gateway/shard.rs | 10 +++---- src/http/client.rs | 45 +++++++++++++++++++----------- src/http/ratelimiting.rs | 14 ++++------ 4 files changed, 40 insertions(+), 31 deletions(-) diff --git a/src/gateway/bridge/shard_queuer.rs b/src/gateway/bridge/shard_queuer.rs index ae36db7eab5..bebd8487d0d 100644 --- a/src/gateway/bridge/shard_queuer.rs +++ b/src/gateway/bridge/shard_queuer.rs @@ -214,7 +214,7 @@ impl ShardQueuer { async fn start(&mut self, shard_id: ShardId) -> Result<()> { let mut shard = Shard::new( Arc::clone(&self.ws_url), - self.http.token(), + Arc::clone(self.http.token()), ShardInfo::new(shard_id, self.shard_total), self.intents, self.presence.clone(), diff --git a/src/gateway/shard.rs b/src/gateway/shard.rs index 9ecf3de2da8..03ba49880a1 100644 --- a/src/gateway/shard.rs +++ b/src/gateway/shard.rs @@ -72,7 +72,7 @@ pub struct Shard { // This acts as a timeout to determine if the shard has - for some reason - not started within // a decent amount of time. pub started: Instant, - pub token: String, + pub token: Arc, ws_url: Arc, pub intents: GatewayIntents, } @@ -99,7 +99,7 @@ impl Shard { /// # /// # async fn run() -> Result<(), Box> { /// # let http: Arc = unimplemented!(); - /// let token = std::env::var("DISCORD_BOT_TOKEN")?; + /// let token = Arc::from(std::env::var("DISCORD_BOT_TOKEN")?); /// let shard_info = ShardInfo { /// id: ShardId(0), /// total: NonZeroU16::MIN, @@ -107,7 +107,7 @@ impl Shard { /// /// // retrieve the gateway response, which contains the URL to connect to /// let gateway = Arc::from(http.get_gateway().await?.url); - /// let shard = Shard::new(gateway, &token, shard_info, GatewayIntents::all(), None).await?; + /// let shard = Shard::new(gateway, token, shard_info, GatewayIntents::all(), None).await?; /// /// // at this point, you can create a `loop`, and receive events and match /// // their variants @@ -121,7 +121,7 @@ impl Shard { /// TLS error. pub async fn new( ws_url: Arc, - token: &str, + token: Arc, shard_info: ShardInfo, intents: GatewayIntents, presence: Option, @@ -148,7 +148,7 @@ impl Shard { seq, stage, started: Instant::now(), - token: token.to_string(), + token, session_id, shard_info, ws_url, diff --git a/src/http/client.rs b/src/http/client.rs index 9830cfda89a..a6e00ff9d2b 100644 --- a/src/http/client.rs +++ b/src/http/client.rs @@ -11,7 +11,6 @@ use reqwest::header::{HeaderMap as Headers, HeaderValue}; #[cfg(feature = "utils")] use reqwest::Url; use reqwest::{Client, ClientBuilder, Response as ReqwestResponse, StatusCode}; -use secrecy::{ExposeSecret, SecretString}; use serde::de::DeserializeOwned; use serde_json::{from_value, json, to_string, to_vec}; use tracing::{debug, warn}; @@ -55,8 +54,8 @@ pub struct HttpBuilder { client: Option, ratelimiter: Option, ratelimiter_disabled: bool, - token: SecretString, - proxy: Option, + token: Arc, + proxy: Option>, application_id: Option, default_allowed_mentions: Option>, } @@ -69,7 +68,7 @@ impl HttpBuilder { client: None, ratelimiter: None, ratelimiter_disabled: false, - token: SecretString::new(parse_token(token)), + token: parse_token(token), proxy: None, application_id: None, default_allowed_mentions: None, @@ -85,7 +84,7 @@ impl HttpBuilder { /// Sets a token for the bot. If the token is not prefixed "Bot ", this method will /// automatically do so. pub fn token(mut self, token: &str) -> Self { - self.token = SecretString::new(parse_token(token)); + self.token = parse_token(token); self } @@ -124,10 +123,24 @@ impl HttpBuilder { /// proxy's behavior where it will tunnel requests that use TLS via [`HTTP CONNECT`] method /// (e.g. using [`reqwest::Proxy`]). /// + /// # Panics + /// + /// Panics if the proxy URL is larger than u16::MAX characters... what are you doing? + /// /// [`twilight-http-proxy`]: https://github.com/twilight-rs/http-proxy /// [`HTTP CONNECT`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/CONNECT - pub fn proxy(mut self, proxy: impl Into) -> Self { - self.proxy = Some(proxy.into()); + pub fn proxy<'a>(mut self, proxy: impl Into>) -> Self { + let proxy = proxy.into(); + let len = proxy.len(); + + let proxy = match proxy { + Cow::Owned(proxy) => FixedString::from_string_trunc(proxy), + Cow::Borrowed(proxy) => FixedString::from_str_trunc(proxy), + }; + + assert_eq!(len as u32, proxy.len(), "Proxy URL should not be larger than u16::MAX chars"); + self.proxy = Some(proxy); + self } @@ -156,7 +169,7 @@ impl HttpBuilder { let ratelimiter = (!self.ratelimiter_disabled).then(|| { self.ratelimiter - .unwrap_or_else(|| Ratelimiter::new(client.clone(), self.token.expose_secret())) + .unwrap_or_else(|| Ratelimiter::new(client.clone(), Arc::clone(&self.token))) }); Http { @@ -170,13 +183,13 @@ impl HttpBuilder { } } -fn parse_token(token: &str) -> String { +fn parse_token(token: &str) -> Arc { let token = token.trim(); if token.starts_with("Bot ") || token.starts_with("Bearer ") { - token.to_string() + Arc::from(token) } else { - format!("Bot {token}") + Arc::from(format!("Bot {token}")) } } @@ -201,8 +214,8 @@ fn reason_into_header(reason: &str) -> Headers { pub struct Http { pub(crate) client: Client, pub ratelimiter: Option, - pub proxy: Option, - token: SecretString, + pub proxy: Option>, + token: Arc, application_id: AtomicU64, pub default_allowed_mentions: Option>, } @@ -230,8 +243,8 @@ impl Http { self.application_id.store(application_id.get(), Ordering::Relaxed); } - pub fn token(&self) -> &str { - self.token.expose_secret() + pub(crate) fn token(&self) -> &Arc { + &self.token } /// Adds a [`User`] to a [`Guild`] with a valid OAuth2 access token. @@ -4950,7 +4963,7 @@ impl Http { let response = if let Some(ratelimiter) = &self.ratelimiter { ratelimiter.perform(req).await? } else { - let request = req.build(&self.client, self.token(), self.proxy.as_deref())?.build()?; + let request = req.build(&self.client, &self.token, self.proxy.as_deref())?.build()?; self.client.execute(request).await? }; diff --git a/src/http/ratelimiting.rs b/src/http/ratelimiting.rs index 23daed92b3a..62221812407 100644 --- a/src/http/ratelimiting.rs +++ b/src/http/ratelimiting.rs @@ -38,12 +38,12 @@ use std::borrow::Cow; use std::fmt; use std::str::{self, FromStr}; +use std::sync::Arc; use std::time::SystemTime; use dashmap::DashMap; use reqwest::header::HeaderMap; use reqwest::{Client, Response, StatusCode}; -use secrecy::{ExposeSecret, SecretString}; use tokio::sync::Mutex; use tokio::time::{sleep, Duration}; use tracing::debug; @@ -85,7 +85,7 @@ pub struct Ratelimiter { client: Client, global: Mutex<()>, routes: DashMap, - token: SecretString, + token: Arc, absolute_ratelimits: bool, ratelimit_callback: Box, } @@ -108,16 +108,12 @@ impl Ratelimiter { /// /// The bot token must be prefixed with `"Bot "`. The ratelimiter does not prefix it. #[must_use] - pub fn new(client: Client, token: impl Into) -> Self { - Self::new_(client, token.into()) - } - - fn new_(client: Client, token: String) -> Self { + pub fn new(client: Client, token: Arc) -> Self { Self { + token, client, global: Mutex::default(), routes: DashMap::new(), - token: SecretString::new(token), ratelimit_callback: Box::new(|_| {}), absolute_ratelimits: false, } @@ -197,7 +193,7 @@ impl Ratelimiter { sleep(delay_time).await; } - let request = req.clone().build(&self.client, self.token.expose_secret(), None)?; + let request = req.clone().build(&self.client, &self.token, None)?; let response = self.client.execute(request.build()?).await?; // Check if the request got ratelimited by checking for status 429, and if so, sleep