diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs index 55eabf85ffde..02a0a99498ee 100644 --- a/crates/uv-client/src/registry_client.rs +++ b/crates/uv-client/src/registry_client.rs @@ -10,6 +10,7 @@ use http::HeaderMap; use itertools::Either; use reqwest::{Client, Response, StatusCode}; use reqwest_middleware::ClientWithMiddleware; +use tokio::sync::Semaphore; use tracing::{info_span, instrument, trace, warn, Instrument}; use url::Url; @@ -235,6 +236,7 @@ impl RegistryClient { package_name: &PackageName, index: Option<&'index IndexUrl>, capabilities: &IndexCapabilities, + download_concurrency: &Semaphore, ) -> Result)>, Error> { let indexes = if let Some(index) = index { Either::Left(std::iter::once(index)) @@ -253,6 +255,7 @@ impl RegistryClient { // If we're searching for the first index that contains the package, fetch serially. IndexStrategy::FirstIndex => { for index in it { + let _permit = download_concurrency.acquire().await; if let Some(metadata) = self .simple_single_index(package_name, index, capabilities) .await? @@ -265,9 +268,9 @@ impl RegistryClient { // Otherwise, fetch concurrently. IndexStrategy::UnsafeBestMatch | IndexStrategy::UnsafeFirstMatch => { - // TODO(charlie): Respect concurrency limits. results = futures::stream::iter(it) .map(|index| async move { + let _permit = download_concurrency.acquire().await; let metadata = self .simple_single_index(package_name, index, capabilities) .await?; diff --git a/crates/uv-distribution/src/distribution_database.rs b/crates/uv-distribution/src/distribution_database.rs index 9a74f9d2ad5d..14fb52dba5ae 100644 --- a/crates/uv-distribution/src/distribution_database.rs +++ b/crates/uv-distribution/src/distribution_database.rs @@ -992,6 +992,20 @@ impl<'a> ManagedClient<'a> { let _permit = self.control.acquire().await.unwrap(); f(self.unmanaged).await } + + /// Perform a request using a client that internally manages the concurrency limit. + /// + /// The callback is passed the client and a semaphore. It must acquire the semaphore before + /// any request through the client and drop it after. + /// + /// This method serves as an escape hatch for functions that may want to send multiple requests + /// in parallel. + pub async fn manual(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T + where + F: Future, + { + f(self.unmanaged, &self.control).await + } } /// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present. diff --git a/crates/uv-publish/src/lib.rs b/crates/uv-publish/src/lib.rs index 19c40abdadbb..4d65fc11d125 100644 --- a/crates/uv-publish/src/lib.rs +++ b/crates/uv-publish/src/lib.rs @@ -20,6 +20,7 @@ use std::time::{Duration, SystemTime}; use std::{env, fmt, io}; use thiserror::Error; use tokio::io::{AsyncReadExt, BufReader}; +use tokio::sync::Semaphore; use tokio_util::io::ReaderStream; use tracing::{debug, enabled, trace, warn, Level}; use url::Url; @@ -369,6 +370,7 @@ pub async fn upload( username: Option<&str>, password: Option<&str>, check_url_client: Option<&CheckUrlClient<'_>>, + download_concurrency: &Semaphore, reporter: Arc, ) -> Result { let form_metadata = form_metadata(file, filename) @@ -428,7 +430,8 @@ pub async fn upload( PublishSendError::Status(..) | PublishSendError::StatusNoBody(..) ) { if let Some(check_url_client) = &check_url_client { - if check_url(check_url_client, file, filename).await? { + if check_url(check_url_client, file, filename, download_concurrency).await? + { // There was a raced upload of the same file, so even though our upload failed, // the right file now exists in the registry. return Ok(false); @@ -450,6 +453,7 @@ pub async fn check_url( check_url_client: &CheckUrlClient<'_>, file: &Path, filename: &DistFilename, + download_concurrency: &Semaphore, ) -> Result { let CheckUrlClient { index_url, @@ -470,7 +474,12 @@ pub async fn check_url( debug!("Checking for {filename} in the registry"); let response = match registry_client - .simple(filename.name(), Some(index_url), index_capabilities) + .simple( + filename.name(), + Some(index_url), + index_capabilities, + download_concurrency, + ) .await { Ok(response) => response, diff --git a/crates/uv-resolver/src/resolver/provider.rs b/crates/uv-resolver/src/resolver/provider.rs index dc9b38e08ec9..0c076b85862d 100644 --- a/crates/uv-resolver/src/resolver/provider.rs +++ b/crates/uv-resolver/src/resolver/provider.rs @@ -155,7 +155,9 @@ impl ResolverProvider for DefaultResolverProvider<'_, Con let result = self .fetcher .client() - .managed(|client| client.simple(package_name, index, self.capabilities)) + .manual(|client, semaphore| { + client.simple(package_name, index, self.capabilities, semaphore) + }) .await; match result { diff --git a/crates/uv/src/commands/pip/latest.rs b/crates/uv/src/commands/pip/latest.rs index abc028a0076b..12f56362913a 100644 --- a/crates/uv/src/commands/pip/latest.rs +++ b/crates/uv/src/commands/pip/latest.rs @@ -1,3 +1,4 @@ +use tokio::sync::Semaphore; use tracing::debug; use uv_client::{RegistryClient, VersionFiles}; use uv_distribution_filename::DistFilename; @@ -27,10 +28,15 @@ impl LatestClient<'_> { &self, package: &PackageName, index: Option<&IndexUrl>, + download_concurrency: &Semaphore, ) -> anyhow::Result, uv_client::Error> { debug!("Fetching latest version of: `{package}`"); - let archives = match self.client.simple(package, index, self.capabilities).await { + let archives = match self + .client + .simple(package, index, self.capabilities, download_concurrency) + .await + { Ok(archives) => archives, Err(err) => { return match err.into_kind() { diff --git a/crates/uv/src/commands/pip/list.rs b/crates/uv/src/commands/pip/list.rs index 6ec866404dca..d9127ecaeb0a 100644 --- a/crates/uv/src/commands/pip/list.rs +++ b/crates/uv/src/commands/pip/list.rs @@ -8,6 +8,7 @@ use itertools::Itertools; use owo_colors::OwoColorize; use rustc_hash::FxHashMap; use serde::Serialize; +use tokio::sync::Semaphore; use unicode_width::UnicodeWidthStr; use uv_cache::{Cache, Refresh}; @@ -94,6 +95,7 @@ pub(crate) async fn pip_list( .markers(environment.interpreter().markers()) .platform(environment.interpreter().platform()) .build(); + let download_concurrency = Semaphore::new(concurrency.downloads); // Determine the platform tags. let interpreter = environment.interpreter(); @@ -116,7 +118,9 @@ pub(crate) async fn pip_list( // Fetch the latest version for each package. let mut fetches = futures::stream::iter(&results) .map(|dist| async { - let latest = client.find_latest(dist.name(), None).await?; + let latest = client + .find_latest(dist.name(), None, &download_concurrency) + .await?; Ok::<(&PackageName, Option), uv_client::Error>((dist.name(), latest)) }) .buffer_unordered(concurrency.downloads); diff --git a/crates/uv/src/commands/pip/tree.rs b/crates/uv/src/commands/pip/tree.rs index 077dcf6f5720..da0266719b2c 100644 --- a/crates/uv/src/commands/pip/tree.rs +++ b/crates/uv/src/commands/pip/tree.rs @@ -8,6 +8,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::EdgeRef; use petgraph::Direction; use rustc_hash::{FxHashMap, FxHashSet}; +use tokio::sync::Semaphore; use uv_cache::{Cache, Refresh}; use uv_cache_info::Timestamp; @@ -95,6 +96,7 @@ pub(crate) async fn pip_tree( .markers(environment.interpreter().markers()) .platform(environment.interpreter().platform()) .build(); + let download_concurrency = Semaphore::new(concurrency.downloads); // Determine the platform tags. let interpreter = environment.interpreter(); @@ -117,7 +119,10 @@ pub(crate) async fn pip_tree( // Fetch the latest version for each package. let mut fetches = futures::stream::iter(&packages) .map(|(name, ..)| async { - let Some(filename) = client.find_latest(name, None).await? else { + let Some(filename) = client + .find_latest(name, None, &download_concurrency) + .await? + else { return Ok(None); }; Ok::, uv_client::Error>(Some((*name, filename.into_version()))) diff --git a/crates/uv/src/commands/project/tree.rs b/crates/uv/src/commands/project/tree.rs index 83b4234df45b..14d07722e81d 100644 --- a/crates/uv/src/commands/project/tree.rs +++ b/crates/uv/src/commands/project/tree.rs @@ -3,7 +3,7 @@ use std::path::Path; use anstream::print; use anyhow::{Error, Result}; use futures::StreamExt; - +use tokio::sync::Semaphore; use uv_cache::{Cache, Refresh}; use uv_cache_info::Timestamp; use uv_client::{Connectivity, RegistryClientBuilder}; @@ -225,6 +225,7 @@ pub(crate) async fn tree( .keyring(*keyring_provider) .allow_insecure_host(allow_insecure_host.to_vec()) .build(); + let download_concurrency = Semaphore::new(concurrency.downloads); // Initialize the client to fetch the latest version of each package. let client = LatestClient { @@ -239,9 +240,12 @@ pub(crate) async fn tree( let reporter = LatestVersionReporter::from(printer).with_length(packages.len() as u64); // Fetch the latest version for each package. + let download_concurrency = &download_concurrency; let mut fetches = futures::stream::iter(packages) .map(|(package, index)| async move { - let Some(filename) = client.find_latest(package.name(), Some(&index)).await? + let Some(filename) = client + .find_latest(package.name(), Some(&index), download_concurrency) + .await? else { return Ok(None); }; diff --git a/crates/uv/src/commands/publish.rs b/crates/uv/src/commands/publish.rs index cef9c7833a50..898348c6423a 100644 --- a/crates/uv/src/commands/publish.rs +++ b/crates/uv/src/commands/publish.rs @@ -8,6 +8,7 @@ use std::fmt::Write; use std::iter; use std::sync::Arc; use std::time::Duration; +use tokio::sync::Semaphore; use tracing::{debug, info}; use url::Url; use uv_cache::Cache; @@ -69,6 +70,8 @@ pub(crate) async fn publish( let oidc_client = BaseClientBuilder::new() .auth_integration(AuthIntegration::NoAuthMiddleware) .wrap_existing(&upload_client); + // We're only checking a single URL and one at a time, so 1 permit is sufficient + let download_concurrency = Arc::new(Semaphore::new(1)); let (publish_url, username, password) = gather_credentials( publish_url, @@ -110,7 +113,9 @@ pub(crate) async fn publish( for (file, raw_filename, filename) in files { if let Some(check_url_client) = &check_url_client { - if uv_publish::check_url(check_url_client, &file, &filename).await? { + if uv_publish::check_url(check_url_client, &file, &filename, &download_concurrency) + .await? + { writeln!(printer.stderr(), "File {filename} already exists, skipping")?; continue; } @@ -134,6 +139,7 @@ pub(crate) async fn publish( username.as_deref(), password.as_deref(), check_url_client.as_ref(), + &download_concurrency, // Needs to be an `Arc` because the reqwest `Body` static lifetime requirement Arc::new(reporter), )