From 426ddd084131b2afe9d713935d4cb14c477dea25 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Wed, 8 Jan 2025 10:30:49 -0800 Subject: [PATCH] fix: special characters in GCS urls (#3651) resolves #3649 --- src/daft-io/src/google_cloud.rs | 44 ++++++++----------- .../io/test_url_download_public_gcs.py | 24 ++++++++++ 2 files changed, 43 insertions(+), 25 deletions(-) create mode 100644 tests/integration/io/test_url_download_public_gcs.py diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs index 41ca0261d0..60dbec938c 100644 --- a/src/daft-io/src/google_cloud.rs +++ b/src/daft-io/src/google_cloud.rs @@ -12,6 +12,7 @@ use google_cloud_storage::{ }, }; use google_cloud_token::{TokenSource, TokenSourceProvider}; +use regex::Regex; use snafu::{IntoError, ResultExt, Snafu}; use tokio::sync::Semaphore; @@ -37,11 +38,6 @@ enum Error { #[snafu(display("Unable to read data from {}: {}", path, source))] UnableToReadBytes { path: String, source: GError }, - #[snafu(display("Unable to parse URL: \"{}\"", path))] - InvalidUrl { - path: String, - source: url::ParseError, - }, #[snafu(display("Unable to load Credentials: {}", source))] UnableToLoadCredentials { source: google_cloud_storage::client::google_cloud_auth::error::Error, @@ -62,8 +58,8 @@ enum Error { impl From for super::Error { fn from(error: Error) -> Self { use Error::{ - InvalidUrl, NotAFile, NotFound, UnableToCreateClient, UnableToGrabSemaphore, - UnableToListObjects, UnableToLoadCredentials, UnableToOpenFile, UnableToReadBytes, + NotAFile, NotFound, UnableToCreateClient, UnableToGrabSemaphore, UnableToListObjects, + UnableToLoadCredentials, UnableToOpenFile, UnableToReadBytes, }; match error { UnableToReadBytes { path, source } @@ -128,7 +124,6 @@ impl From for super::Error { path: path.into(), source: error.into(), }, - InvalidUrl { path, source } => Self::InvalidUrl { path, source }, UnableToLoadCredentials { source } => Self::UnableToLoadCredentials { store: super::SourceType::GCS, source: source.into(), @@ -154,17 +149,19 @@ struct GCSClientWrapper { connection_pool_sema: Arc, } -fn parse_uri(uri: &url::Url) -> super::Result<(&str, &str)> { - let bucket = match uri.host_str() { - Some(s) => Ok(s), - None => Err(Error::InvalidUrl { - path: uri.to_string(), - source: url::ParseError::EmptyHost, - }), - }?; - let key = uri.path(); - let key = key.strip_prefix(GCS_DELIMITER).unwrap_or(key); - Ok((bucket, key)) +fn parse_raw_uri(uri: &str) -> super::Result<(&str, &str)> { + // We use regex here instead of the more robust url crate because we do not want to handle character escaping + // which is done by `google_cloud_storage::client::Client` already + let re = Regex::new(r"^gs://([^/]+)(?:/(.*))?$").unwrap(); + + if let Some(cap) = re.captures(uri) { + let bucket = cap.get(1).unwrap().as_str(); + let key = cap.get(2).map_or("", |key| key.as_str()); + + Ok((bucket, key)) + } else { + Err(Error::NotAFile { path: uri.into() }.into()) + } } impl GCSClientWrapper { @@ -174,8 +171,7 @@ impl GCSClientWrapper { range: Option>, io_stats: Option, ) -> super::Result { - let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; - let (bucket, key) = parse_uri(&uri)?; + let (bucket, key) = parse_raw_uri(uri)?; if key.is_empty() { return Err(Error::NotAFile { path: uri.into() }.into()); } @@ -226,8 +222,7 @@ impl GCSClientWrapper { } async fn get_size(&self, uri: &str, io_stats: Option) -> super::Result { - let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; - let (bucket, key) = parse_uri(&uri)?; + let (bucket, key) = parse_raw_uri(uri)?; if key.is_empty() { return Err(Error::NotAFile { path: uri.into() }.into()); } @@ -315,8 +310,7 @@ impl GCSClientWrapper { page_size: Option, io_stats: Option, ) -> super::Result { - let uri = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; - let (bucket, key) = parse_uri(&uri)?; + let (bucket, key) = parse_raw_uri(path)?; let _permit = self .connection_pool_sema diff --git a/tests/integration/io/test_url_download_public_gcs.py b/tests/integration/io/test_url_download_public_gcs.py new file mode 100644 index 0000000000..1e7dfbed91 --- /dev/null +++ b/tests/integration/io/test_url_download_public_gcs.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import pytest + +import daft + + +@pytest.mark.integration() +def test_url_download_gcs_public_special_characters(small_images_s3_paths): + df = daft.from_glob_path("gs://daft-public-data-gs/test_naming/**") + df = df.with_column("data", df["path"].url.download()) + + assert df.to_pydict() == { + "path": [ + "gs://daft-public-data-gs/test_naming/test. .txt", + "gs://daft-public-data-gs/test_naming/test.%.txt", + "gs://daft-public-data-gs/test_naming/test.-.txt", + "gs://daft-public-data-gs/test_naming/test.=.txt", + "gs://daft-public-data-gs/test_naming/test.?.txt", + ], + "size": [5, 5, 5, 5, 5], + "num_rows": [None, None, None, None, None], + "data": [b"test\n", b"test\n", b"test\n", b"test\n", b"test\n"], + }