Skip to content

Commit

Permalink
fix: special characters in GCS urls (#3651)
Browse files Browse the repository at this point in the history
resolves  #3649
  • Loading branch information
kevinzwang authored Jan 8, 2025
1 parent 2de6787 commit 426ddd0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
44 changes: 19 additions & 25 deletions src/daft-io/src/google_cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use google_cloud_storage::{
},
};
use google_cloud_token::{TokenSource, TokenSourceProvider};
use regex::Regex;
use snafu::{IntoError, ResultExt, Snafu};
use tokio::sync::Semaphore;

Expand All @@ -37,11 +38,6 @@ enum Error {
#[snafu(display("Unable to read data from {}: {}", path, source))]
UnableToReadBytes { path: String, source: GError },

#[snafu(display("Unable to parse URL: \"{}\"", path))]
InvalidUrl {
path: String,
source: url::ParseError,
},
#[snafu(display("Unable to load Credentials: {}", source))]
UnableToLoadCredentials {
source: google_cloud_storage::client::google_cloud_auth::error::Error,
Expand All @@ -62,8 +58,8 @@ enum Error {
impl From<Error> for super::Error {
fn from(error: Error) -> Self {
use Error::{
InvalidUrl, NotAFile, NotFound, UnableToCreateClient, UnableToGrabSemaphore,
UnableToListObjects, UnableToLoadCredentials, UnableToOpenFile, UnableToReadBytes,
NotAFile, NotFound, UnableToCreateClient, UnableToGrabSemaphore, UnableToListObjects,
UnableToLoadCredentials, UnableToOpenFile, UnableToReadBytes,
};
match error {
UnableToReadBytes { path, source }
Expand Down Expand Up @@ -128,7 +124,6 @@ impl From<Error> for super::Error {
path: path.into(),
source: error.into(),
},
InvalidUrl { path, source } => Self::InvalidUrl { path, source },
UnableToLoadCredentials { source } => Self::UnableToLoadCredentials {
store: super::SourceType::GCS,
source: source.into(),
Expand All @@ -154,17 +149,19 @@ struct GCSClientWrapper {
connection_pool_sema: Arc<Semaphore>,
}

fn parse_uri(uri: &url::Url) -> super::Result<(&str, &str)> {
let bucket = match uri.host_str() {
Some(s) => Ok(s),
None => Err(Error::InvalidUrl {
path: uri.to_string(),
source: url::ParseError::EmptyHost,
}),
}?;
let key = uri.path();
let key = key.strip_prefix(GCS_DELIMITER).unwrap_or(key);
Ok((bucket, key))
fn parse_raw_uri(uri: &str) -> super::Result<(&str, &str)> {
// We use regex here instead of the more robust url crate because we do not want to handle character escaping
// which is done by `google_cloud_storage::client::Client` already
let re = Regex::new(r"^gs://([^/]+)(?:/(.*))?$").unwrap();

if let Some(cap) = re.captures(uri) {
let bucket = cap.get(1).unwrap().as_str();
let key = cap.get(2).map_or("", |key| key.as_str());

Ok((bucket, key))
} else {
Err(Error::NotAFile { path: uri.into() }.into())
}
}

impl GCSClientWrapper {
Expand All @@ -174,8 +171,7 @@ impl GCSClientWrapper {
range: Option<Range<usize>>,
io_stats: Option<IOStatsRef>,
) -> super::Result<GetResult> {
let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;
let (bucket, key) = parse_uri(&uri)?;
let (bucket, key) = parse_raw_uri(uri)?;
if key.is_empty() {
return Err(Error::NotAFile { path: uri.into() }.into());
}
Expand Down Expand Up @@ -226,8 +222,7 @@ impl GCSClientWrapper {
}

async fn get_size(&self, uri: &str, io_stats: Option<IOStatsRef>) -> super::Result<usize> {
let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;
let (bucket, key) = parse_uri(&uri)?;
let (bucket, key) = parse_raw_uri(uri)?;
if key.is_empty() {
return Err(Error::NotAFile { path: uri.into() }.into());
}
Expand Down Expand Up @@ -315,8 +310,7 @@ impl GCSClientWrapper {
page_size: Option<i32>,
io_stats: Option<IOStatsRef>,
) -> super::Result<LSResult> {
let uri = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?;
let (bucket, key) = parse_uri(&uri)?;
let (bucket, key) = parse_raw_uri(path)?;

let _permit = self
.connection_pool_sema
Expand Down
24 changes: 24 additions & 0 deletions tests/integration/io/test_url_download_public_gcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import pytest

import daft


@pytest.mark.integration()
def test_url_download_gcs_public_special_characters(small_images_s3_paths):
df = daft.from_glob_path("gs://daft-public-data-gs/test_naming/**")
df = df.with_column("data", df["path"].url.download())

assert df.to_pydict() == {
"path": [
"gs://daft-public-data-gs/test_naming/test. .txt",
"gs://daft-public-data-gs/test_naming/test.%.txt",
"gs://daft-public-data-gs/test_naming/test.-.txt",
"gs://daft-public-data-gs/test_naming/test.=.txt",
"gs://daft-public-data-gs/test_naming/test.?.txt",
],
"size": [5, 5, 5, 5, 5],
"num_rows": [None, None, None, None, None],
"data": [b"test\n", b"test\n", b"test\n", b"test\n", b"test\n"],
}

0 comments on commit 426ddd0

Please sign in to comment.