Skip to content

Commit 18d9d57

Browse files
authored
Expose pd.DataFrame.to_csv and json.dump keyword arguments (#421)
1 parent 40fe3a7 commit 18d9d57

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

intake_esm/cat.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,15 @@ def from_dict(cls, data: typing.Dict) -> 'ESMCatalogModel':
115115
cat._df = df
116116
return cat
117117

118-
def save(self, name: str, *, directory: str = None, catalog_type: str = 'dict') -> None:
118+
def save(
119+
self,
120+
name: str,
121+
*,
122+
directory: str = None,
123+
catalog_type: str = 'dict',
124+
to_csv_kwargs: dict = None,
125+
json_dump_kwargs: dict = None,
126+
) -> None:
119127
"""
120128
Save the catalog to a file.
121129
@@ -128,6 +136,10 @@ def save(self, name: str, *, directory: str = None, catalog_type: str = 'dict')
128136
catalog_type: str
129137
The type of catalog to save. Whether to save the catalog table as a dictionary
130138
in the JSON file or as a separate CSV file. Valid options are 'dict' and 'file'.
139+
to_csv_kwargs : dict, optional
140+
Additional keyword arguments passed through to the :py:meth:`~pandas.DataFrame.to_csv` method.
141+
json_dump_kwargs : dict, optional
142+
Additional keyword arguments passed through to the :py:func:`~json.dump` function.
131143
132144
Notes
133145
-----
@@ -140,7 +152,7 @@ def save(self, name: str, *, directory: str = None, catalog_type: str = 'dict')
140152
raise ValueError(
141153
f'catalog_type must be either "dict" or "file". Received catalog_type={catalog_type}'
142154
)
143-
csv_file_name = pathlib.Path(f'{name}.csv.gz')
155+
csv_file_name = pathlib.Path(f'{name}.csv')
144156
json_file_name = pathlib.Path(f'{name}.json')
145157
if directory:
146158
directory = pathlib.Path(directory)
@@ -154,13 +166,20 @@ def save(self, name: str, *, directory: str = None, catalog_type: str = 'dict')
154166
data['id'] = name
155167

156168
if catalog_type == 'file':
169+
csv_kwargs = {'index': False}
170+
csv_kwargs.update(to_csv_kwargs or {})
171+
compression = csv_kwargs.get('compression')
172+
extensions = {'gzip': '.gz', 'bz2': '.bz2', 'zip': '.zip', 'xz': '.xz', None: ''}
173+
csv_file_name = f'{csv_file_name}{extensions[compression]}'
157174
data['catalog_file'] = str(csv_file_name)
158-
self.df.to_csv(csv_file_name, compression='gzip', index=False)
175+
self.df.to_csv(csv_file_name, **csv_kwargs)
159176
else:
160177
data['catalog_dict'] = self.df.to_dict(orient='records')
161178

162179
with open(json_file_name, 'w') as outfile:
163-
json.dump(data, outfile, indent=2)
180+
json_kwargs = {'indent': 2}
181+
json_kwargs.update(json_dump_kwargs or {})
182+
json.dump(data, outfile, **json_kwargs)
164183

165184
print(f'Successfully wrote ESM collection json file to: {json_file_name}')
166185

intake_esm/core.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ def serialize(
365365
name: pydantic.StrictStr,
366366
directory: typing.Union[pydantic.DirectoryPath, pydantic.StrictStr] = None,
367367
catalog_type: str = 'dict',
368+
to_csv_kwargs: typing.Dict[typing.Any, typing.Any] = None,
369+
json_dump_kwargs: typing.Dict[typing.Any, typing.Any] = None,
368370
) -> None:
369371
"""Serialize collection/catalog to corresponding json and csv files.
370372
@@ -376,6 +378,10 @@ def serialize(
376378
The path to the local directory. If None, use the current directory
377379
catalog_type: str, default 'dict'
378380
Whether to save the catalog table as a dictionary in the JSON file or as a separate CSV file.
381+
to_csv_kwargs : dict, optional
382+
Additional keyword arguments passed through to the :py:meth:`~pandas.DataFrame.to_csv` method.
383+
json_dump_kwargs : dict, optional
384+
Additional keyword arguments passed through to the :py:func:`~json.dump` function.
379385
380386
Notes
381387
-----
@@ -395,7 +401,13 @@ def serialize(
395401
>>> col_subset.serialize(name="cmip6_bcc_esm1", catalog_type="file")
396402
"""
397403

398-
self.esmcat.save(name, directory=directory, catalog_type=catalog_type)
404+
self.esmcat.save(
405+
name,
406+
directory=directory,
407+
catalog_type=catalog_type,
408+
to_csv_kwargs=to_csv_kwargs,
409+
json_dump_kwargs=json_dump_kwargs,
410+
)
399411

400412
def nunique(self) -> pd.Series:
401413
"""Count distinct observations across dataframe columns

tests/test_core.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,24 @@ def test_catalog_getitem_error():
147147
cat['foo']
148148

149149

150-
@pytest.mark.parametrize('catalog_type', ['file', 'dict'])
151-
def test_catalog_serialize(tmp_path, catalog_type):
150+
@pytest.mark.parametrize(
151+
'catalog_type, to_csv_kwargs, json_dump_kwargs',
152+
[('file', {'compression': 'bz2'}, {}), ('file', {'compression': 'gzip'}, {}), ('dict', {}, {})],
153+
)
154+
def test_catalog_serialize(tmp_path, catalog_type, to_csv_kwargs, json_dump_kwargs):
152155
cat = intake.open_esm_datastore(cdf_col_sample_cmip6)
153156
local_store = tmp_path
154157
cat_subset = cat.search(
155158
source_id='MRI-ESM2-0',
156159
)
157160
name = 'CMIP6-MRI-ESM2-0'
158-
cat_subset.serialize(name=name, directory=local_store, catalog_type=catalog_type)
161+
cat_subset.serialize(
162+
name=name,
163+
directory=local_store,
164+
catalog_type=catalog_type,
165+
to_csv_kwargs=to_csv_kwargs,
166+
json_dump_kwargs=json_dump_kwargs,
167+
)
159168
cat = intake.open_esm_datastore(f'{local_store}/{name}.json')
160169
pd.testing.assert_frame_equal(
161170
cat_subset.df.reset_index(drop=True), cat.df.reset_index(drop=True)

0 commit comments

Comments
 (0)