Skip to content

Commit 0790338

Browse files
DanilBaibakfacebook-github-bot
authored andcommitted
Forward fix / Update dill_available API for torchdata (#1222)
Summary: Pull Request resolved: #1222 Changes from the PyTorch repo (D53082622) broke torchdata. I updated the dill_available API for torchdata to keep everything in sync. Reviewed By: atalman, ejguan Differential Revision: D53086369 fbshipit-source-id: 7344c4cd3205a38689722330721257b5a01bd32f
1 parent d727f63 commit 0790338

File tree

4 files changed

+17
-14
lines changed

4 files changed

+17
-14
lines changed

test/test_serialization.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
import torchdata.datapipes.iter as iterdp
1818
import torchdata.datapipes.map as mapdp
1919
from _utils._common_utils_for_test import create_temp_dir, create_temp_files
20-
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE
20+
from torch.utils._import_utils import dill_available
2121
from torchdata.datapipes.iter import IterableWrapper
2222
from torchdata.datapipes.map import SequenceWrapper
2323

24-
if DILL_AVAILABLE:
24+
if dill_available():
2525
import dill
2626

2727
dill.extend(use_dill=False)
@@ -87,7 +87,7 @@ def _filter_by_module_availability(datapipes):
8787
filter_set.update([iterdp.IoPathFileLister, iterdp.IoPathFileOpener, iterdp.IoPathSaver])
8888
if rarfile is None:
8989
filter_set.update([iterdp.RarArchiveLoader])
90-
if torcharrow is None or not DILL_AVAILABLE:
90+
if torcharrow is None or not dill_available():
9191
filter_set.update([iterdp.DataFrameMaker, iterdp.ParquetDataFrameLoader])
9292
return [dp for dp in datapipes if dp[0] not in filter_set]
9393

@@ -374,7 +374,7 @@ def test_serializable_with_dill(self) -> None:
374374
# Skipping value comparison for these DataPipes
375375
dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator}
376376
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
377-
if DILL_AVAILABLE:
377+
if dill_available():
378378
try:
379379
if dpipe in dp_skip_comparison: # Make sure they are picklable/loadable (no value comparison)
380380
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]

torchdata/datapipes/iter/util/cacheholder.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121
except ImportError:
2222
portalocker = None
2323

24-
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE
24+
from torch.utils._import_utils import dill_available
25+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
2526

2627
from torch.utils.data.graph import traverse_dps
2728
from torchdata.datapipes import functional_datapipe
2829
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
2930

30-
if DILL_AVAILABLE:
31+
if dill_available():
3132
import dill
3233

3334
dill.extend(use_dill=False)

torchdata/datapipes/iter/util/converter.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
from typing import Callable, Dict, Optional
1010

11+
from torch.utils._import_utils import dill_available
12+
1113
from torch.utils.data import IterDataPipe, MapDataPipe
12-
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE
14+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
1315

14-
if DILL_AVAILABLE:
16+
if dill_available():
1517
import dill
1618

1719
dill.extend(use_dill=False)
@@ -108,7 +110,7 @@ def __len__(self):
108110
return len(self._map) # type: ignore[arg-type]
109111

110112
def __getstate__(self):
111-
if DILL_AVAILABLE:
113+
if dill_available():
112114
dill_key_value_fn = dill.dumps(self.key_value_fn)
113115
else:
114116
dill_key_value_fn = self.key_value_fn
@@ -120,7 +122,7 @@ def __getstate__(self):
120122

121123
def __setstate__(self, state):
122124
(self.datapipe, dill_key_value_fn, self._map) = state
123-
if DILL_AVAILABLE:
125+
if dill_available():
124126
self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment]
125127
else:
126128
self.key_value_fn = dill_key_value_fn # type: ignore[assignment]

torchdata/datapipes/iter/util/dataframemaker.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import partial
88
from typing import List, Optional, TypeVar
99

10-
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE
10+
from torch.utils._import_utils import dill_available
1111

1212
from torchdata.datapipes import functional_datapipe
1313
from torchdata.datapipes.iter import IterDataPipe
@@ -19,7 +19,7 @@
1919
torcharrow = None
2020
parquet = None
2121

22-
if DILL_AVAILABLE:
22+
if dill_available():
2323
import dill
2424

2525
dill.extend(use_dill=False)
@@ -150,7 +150,7 @@ def __iter__(self):
150150
yield torcharrow.from_arrow(row_group, dtype=self.dtype)
151151

152152
def __getstate__(self):
153-
if DILL_AVAILABLE:
153+
if dill_available():
154154
dill_dtype = dill.dumps(self.dtype)
155155
else:
156156
dill_dtype = self.dtype
@@ -161,7 +161,7 @@ def __getstate__(self):
161161

162162
def __setstate__(self, state):
163163
(self.source_dp, dill_dtype, self.columns, self.device, self.use_threads) = state
164-
if DILL_AVAILABLE:
164+
if dill_available():
165165
self.dtype = dill.loads(dill_dtype) # type: ignore[assignment]
166166
else:
167167
self.dtype = dill_dtype # type: ignore[assignment]

0 commit comments

Comments
 (0)