Respect concurrency limits in parallel index fetch (#11182)
With the parallel simple index fetching, we would acquire only one
download concurrency token, meaning that in the worst case we could make
up to the number of indexes times more requests than the user-requested limit.
We fix this by passing the semaphore down to the simple API method.
konstin authored Feb 3, 2025
1 parent c54dbcb commit 56684e4
Showing 9 changed files with 63 additions and 10 deletions.
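
For readers less familiar with the pattern the commit message describes, here is a minimal, self-contained sketch of it, not taken from the uv codebase: `fetch_index`, the example URLs, and the limit of 4 are illustrative, and `tokio` (with the `macros`, `rt-multi-thread`, and `sync` features) and `futures` are assumed as dependencies. Each task in the buffered stream acquires its own permit from a shared `tokio::sync::Semaphore`, so the number of in-flight requests is capped by the semaphore size rather than only by the stream's buffer.

```rust
use futures::StreamExt;
use tokio::sync::Semaphore;

/// Illustrative stand-in for a single Simple API request.
async fn fetch_index(index: &str) -> String {
    format!("response for {index}")
}

#[tokio::main]
async fn main() {
    // The user-requested download limit (uv passes `concurrency.downloads` here).
    let download_concurrency = &Semaphore::new(4);
    let indexes = ["https://a.example/simple", "https://b.example/simple"];

    let results: Vec<String> = futures::stream::iter(indexes)
        .map(|index| async move {
            // Each request holds its own permit for its full duration, so at
            // most 4 requests are in flight no matter how many indexes are
            // queried in parallel.
            let _permit = download_concurrency.acquire().await.unwrap();
            fetch_index(index).await
        })
        .buffer_unordered(16)
        .collect()
        .await;

    println!("{results:?}");
}
```

In uv's case the unbounded dimension was the number of indexes queried per package: the caller held a single permit around the whole `simple()` call while the per-index requests ran in parallel underneath it. Threading the same semaphore down, as the diffs below do, makes those per-index requests count against the user's limit as well.
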
5 changes: 4 additions & 1 deletion crates/uv-client/src/registry_client.rs
@@ -10,6 +10,7 @@ use http::HeaderMap;
use itertools::Either;
use reqwest::{Client, Response, StatusCode};
use reqwest_middleware::ClientWithMiddleware;
use tokio::sync::Semaphore;
use tracing::{info_span, instrument, trace, warn, Instrument};
use url::Url;

@@ -235,6 +236,7 @@ impl RegistryClient {
package_name: &PackageName,
index: Option<&'index IndexUrl>,
capabilities: &IndexCapabilities,
download_concurrency: &Semaphore,
) -> Result<Vec<(&'index IndexUrl, OwnedArchive<SimpleMetadata>)>, Error> {
let indexes = if let Some(index) = index {
Either::Left(std::iter::once(index))
@@ -253,6 +255,7 @@
// If we're searching for the first index that contains the package, fetch serially.
IndexStrategy::FirstIndex => {
for index in it {
let _permit = download_concurrency.acquire().await;
if let Some(metadata) = self
.simple_single_index(package_name, index, capabilities)
.await?
@@ -265,9 +268,9 @@

// Otherwise, fetch concurrently.
IndexStrategy::UnsafeBestMatch | IndexStrategy::UnsafeFirstMatch => {
// TODO(charlie): Respect concurrency limits.
results = futures::stream::iter(it)
.map(|index| async move {
let _permit = download_concurrency.acquire().await;
let metadata = self
.simple_single_index(package_name, index, capabilities)
.await?;
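
A side note on the permit handling in the hunk above; this is general `tokio::sync::Semaphore` behavior rather than anything uv-specific. Binding the result of `acquire()` to a named `_permit` keeps the permit alive until the end of the enclosing scope, so the whole `simple_single_index` request stays counted against the limit, whereas a bare `let _ = ...` would drop it immediately. A tiny sketch:

```rust
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let semaphore = Semaphore::new(1);

    {
        // Named binding: the permit is held until the end of this block.
        let _permit = semaphore.acquire().await.unwrap();
        assert_eq!(semaphore.available_permits(), 0);
    } // `_permit` is dropped here and the permit is released.

    assert_eq!(semaphore.available_permits(), 1);
}
```
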
14 changes: 14 additions & 0 deletions crates/uv-distribution/src/distribution_database.rs
@@ -992,6 +992,20 @@ impl<'a> ManagedClient<'a> {
let _permit = self.control.acquire().await.unwrap();
f(self.unmanaged).await
}

/// Perform a request using a client that internally manages the concurrency limit.
///
/// The callback is passed the client and a semaphore. It must acquire the semaphore before
/// any request through the client and drop it after.
///
/// This method serves as an escape hatch for functions that may want to send multiple requests
/// in parallel.
pub async fn manual<F, T>(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T
where
F: Future<Output = T>,
{
f(self.unmanaged, &self.control).await
}
}

/// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present.
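
To make the `managed`/`manual` split above concrete, here is a hedged, standalone sketch using stand-in types: `FakeClient` and its `request` method are invented for illustration and are not uv APIs, while the two wrapper methods mirror the signatures added in this hunk. `managed` wraps a single callback in a single permit; `manual` hands the callback the semaphore so it can take one permit per request it fans out.

```rust
use std::future::Future;
use tokio::sync::Semaphore;

/// Illustrative stand-in for `RegistryClient`.
struct FakeClient;

impl FakeClient {
    async fn request(&self, url: &str) -> String {
        format!("fetched {url}")
    }
}

struct ManagedClient {
    unmanaged: FakeClient,
    control: Semaphore,
}

impl ManagedClient {
    /// One permit wraps the whole callback, so at most `control`'s permit
    /// count of callbacks run at once.
    async fn managed<'a, F, T>(&'a self, f: impl FnOnce(&'a FakeClient) -> F) -> T
    where
        F: Future<Output = T>,
    {
        let _permit = self.control.acquire().await.unwrap();
        f(&self.unmanaged).await
    }

    /// Escape hatch: the callback receives the semaphore and is responsible
    /// for acquiring one permit around each request it sends.
    async fn manual<'a, F, T>(
        &'a self,
        f: impl FnOnce(&'a FakeClient, &'a Semaphore) -> F,
    ) -> T
    where
        F: Future<Output = T>,
    {
        f(&self.unmanaged, &self.control).await
    }
}

#[tokio::main]
async fn main() {
    let client = ManagedClient {
        unmanaged: FakeClient,
        control: Semaphore::new(2),
    };

    // Single request: the wrapper acquires the permit.
    let one = client.managed(|c| c.request("https://a.example/simple")).await;

    // Parallel requests: the callback scopes a permit around each of them.
    let urls = ["https://a.example/simple", "https://b.example/simple"];
    let many = client
        .manual(|c, semaphore| async move {
            futures::future::join_all(urls.into_iter().map(|url| async move {
                let _permit = semaphore.acquire().await.unwrap();
                c.request(url).await
            }))
            .await
        })
        .await;

    println!("{one}\n{many:?}");
}
```

The `provider.rs` hunk further down shows the real call site: the resolver switches from `managed` to `manual` and forwards the semaphore into `client.simple(...)`.
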
13 changes: 11 additions & 2 deletions crates/uv-publish/src/lib.rs
@@ -20,6 +20,7 @@ use std::time::{Duration, SystemTime};
use std::{env, fmt, io};
use thiserror::Error;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::sync::Semaphore;
use tokio_util::io::ReaderStream;
use tracing::{debug, enabled, trace, warn, Level};
use url::Url;
@@ -369,6 +370,7 @@ pub async fn upload(
username: Option<&str>,
password: Option<&str>,
check_url_client: Option<&CheckUrlClient<'_>>,
download_concurrency: &Semaphore,
reporter: Arc<impl Reporter>,
) -> Result<bool, PublishError> {
let form_metadata = form_metadata(file, filename)
@@ -428,7 +430,8 @@
PublishSendError::Status(..) | PublishSendError::StatusNoBody(..)
) {
if let Some(check_url_client) = &check_url_client {
if check_url(check_url_client, file, filename).await? {
if check_url(check_url_client, file, filename, download_concurrency).await?
{
// There was a raced upload of the same file, so even though our upload failed,
// the right file now exists in the registry.
return Ok(false);
@@ -450,6 +453,7 @@ pub async fn check_url(
check_url_client: &CheckUrlClient<'_>,
file: &Path,
filename: &DistFilename,
download_concurrency: &Semaphore,
) -> Result<bool, PublishError> {
let CheckUrlClient {
index_url,
@@ -470,7 +474,12 @@

debug!("Checking for {filename} in the registry");
let response = match registry_client
.simple(filename.name(), Some(index_url), index_capabilities)
.simple(
filename.name(),
Some(index_url),
index_capabilities,
download_concurrency,
)
.await
{
Ok(response) => response,
4 changes: 3 additions & 1 deletion crates/uv-resolver/src/resolver/provider.rs
@@ -155,7 +155,9 @@ impl<Context: BuildContext> ResolverProvider for DefaultResolverProvider<'_, Con
let result = self
.fetcher
.client()
.managed(|client| client.simple(package_name, index, self.capabilities))
.manual(|client, semaphore| {
client.simple(package_name, index, self.capabilities, semaphore)
})
.await;

match result {
8 changes: 7 additions & 1 deletion crates/uv/src/commands/pip/latest.rs
@@ -1,3 +1,4 @@
use tokio::sync::Semaphore;
use tracing::debug;
use uv_client::{RegistryClient, VersionFiles};
use uv_distribution_filename::DistFilename;
@@ -27,10 +28,15 @@ impl LatestClient<'_> {
&self,
package: &PackageName,
index: Option<&IndexUrl>,
download_concurrency: &Semaphore,
) -> anyhow::Result<Option<DistFilename>, uv_client::Error> {
debug!("Fetching latest version of: `{package}`");

let archives = match self.client.simple(package, index, self.capabilities).await {
let archives = match self
.client
.simple(package, index, self.capabilities, download_concurrency)
.await
{
Ok(archives) => archives,
Err(err) => {
return match err.into_kind() {
6 changes: 5 additions & 1 deletion crates/uv/src/commands/pip/list.rs
@@ -8,6 +8,7 @@ use itertools::Itertools;
use owo_colors::OwoColorize;
use rustc_hash::FxHashMap;
use serde::Serialize;
use tokio::sync::Semaphore;
use unicode_width::UnicodeWidthStr;

use uv_cache::{Cache, Refresh};
@@ -94,6 +95,7 @@ pub(crate) async fn pip_list(
.markers(environment.interpreter().markers())
.platform(environment.interpreter().platform())
.build();
let download_concurrency = Semaphore::new(concurrency.downloads);

// Determine the platform tags.
let interpreter = environment.interpreter();
@@ -116,7 +118,9 @@
// Fetch the latest version for each package.
let mut fetches = futures::stream::iter(&results)
.map(|dist| async {
let latest = client.find_latest(dist.name(), None).await?;
let latest = client
.find_latest(dist.name(), None, &download_concurrency)
.await?;
Ok::<(&PackageName, Option<DistFilename>), uv_client::Error>((dist.name(), latest))
})
.buffer_unordered(concurrency.downloads);
7 changes: 6 additions & 1 deletion crates/uv/src/commands/pip/tree.rs
@@ -8,6 +8,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::EdgeRef;
use petgraph::Direction;
use rustc_hash::{FxHashMap, FxHashSet};
use tokio::sync::Semaphore;

use uv_cache::{Cache, Refresh};
use uv_cache_info::Timestamp;
@@ -95,6 +96,7 @@ pub(crate) async fn pip_tree(
.markers(environment.interpreter().markers())
.platform(environment.interpreter().platform())
.build();
let download_concurrency = Semaphore::new(concurrency.downloads);

// Determine the platform tags.
let interpreter = environment.interpreter();
@@ -117,7 +119,10 @@
// Fetch the latest version for each package.
let mut fetches = futures::stream::iter(&packages)
.map(|(name, ..)| async {
let Some(filename) = client.find_latest(name, None).await? else {
let Some(filename) = client
.find_latest(name, None, &download_concurrency)
.await?
else {
return Ok(None);
};
Ok::<Option<_>, uv_client::Error>(Some((*name, filename.into_version())))
8 changes: 6 additions & 2 deletions crates/uv/src/commands/project/tree.rs
@@ -3,7 +3,7 @@ use std::path::Path;
use anstream::print;
use anyhow::{Error, Result};
use futures::StreamExt;

use tokio::sync::Semaphore;
use uv_cache::{Cache, Refresh};
use uv_cache_info::Timestamp;
use uv_client::{Connectivity, RegistryClientBuilder};
@@ -225,6 +225,7 @@ pub(crate) async fn tree(
.keyring(*keyring_provider)
.allow_insecure_host(allow_insecure_host.to_vec())
.build();
let download_concurrency = Semaphore::new(concurrency.downloads);

// Initialize the client to fetch the latest version of each package.
let client = LatestClient {
Expand All @@ -239,9 +240,12 @@ pub(crate) async fn tree(
let reporter = LatestVersionReporter::from(printer).with_length(packages.len() as u64);

// Fetch the latest version for each package.
let download_concurrency = &download_concurrency;
let mut fetches = futures::stream::iter(packages)
.map(|(package, index)| async move {
let Some(filename) = client.find_latest(package.name(), Some(&index)).await?
let Some(filename) = client
.find_latest(package.name(), Some(&index), download_concurrency)
.await?
else {
return Ok(None);
};
8 changes: 7 additions & 1 deletion crates/uv/src/commands/publish.rs
@@ -8,6 +8,7 @@ use std::fmt::Write;
use std::iter;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::{debug, info};
use url::Url;
use uv_cache::Cache;
@@ -69,6 +70,8 @@ pub(crate) async fn publish(
let oidc_client = BaseClientBuilder::new()
.auth_integration(AuthIntegration::NoAuthMiddleware)
.wrap_existing(&upload_client);
// We're only checking a single URL and one at a time, so 1 permit is sufficient
let download_concurrency = Arc::new(Semaphore::new(1));

let (publish_url, username, password) = gather_credentials(
publish_url,
@@ -110,7 +113,9 @@

for (file, raw_filename, filename) in files {
if let Some(check_url_client) = &check_url_client {
if uv_publish::check_url(check_url_client, &file, &filename).await? {
if uv_publish::check_url(check_url_client, &file, &filename, &download_concurrency)
.await?
{
writeln!(printer.stderr(), "File {filename} already exists, skipping")?;
continue;
}
@@ -134,6 +139,7 @@
username.as_deref(),
password.as_deref(),
check_url_client.as_ref(),
&download_concurrency,
// Needs to be an `Arc` because the reqwest `Body` static lifetime requirement
Arc::new(reporter),
)
