diff --git a/CHANGES.md b/CHANGES.md
index 6e8ece802..ae3109a46 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
 - Set train/validation on criterion if it's a PyTorch module (#621)
 - Don't pass `y=None` to `NeuralNet.train_split` to enable the direct use of split functions without positional `y` in their signatures. This is useful when working with unsupervised data (#605).
+- `to_numpy` is now able to unpack dicts and lists/tuples (#657, #658)
 
 ### Fixed
 
diff --git a/skorch/tests/test_utils.py b/skorch/tests/test_utils.py
index 1017cac67..44969bee9 100644
--- a/skorch/tests/test_utils.py
+++ b/skorch/tests/test_utils.py
@@ -135,6 +135,69 @@ def test_sparse_tensor_not_accepted_raises(self, to_tensor, device):
         assert exc.value.args[0] == msg
 
 
+class TestToNumpy:
+    @pytest.fixture
+    def to_numpy(self):
+        from skorch.utils import to_numpy
+        return to_numpy
+
+    @pytest.fixture
+    def x_tensor(self):
+        return torch.zeros(3, 4)
+
+    @pytest.fixture
+    def x_tuple(self):
+        return torch.ones(3), torch.zeros(3, 4)
+
+    @pytest.fixture
+    def x_list(self):
+        return [torch.ones(3), torch.zeros(3, 4)]
+
+    @pytest.fixture
+    def x_dict(self):
+        return {'a': torch.ones(3), 'b': (torch.zeros(2), torch.zeros(3))}
+
+    def compare_array_to_tensor(self, x_numpy, x_tensor):
+        assert isinstance(x_tensor, torch.Tensor)
+        assert isinstance(x_numpy, np.ndarray)
+        assert x_numpy.shape == x_tensor.shape
+        for a, b in zip(x_numpy.flatten(), x_tensor.flatten()):
+            assert np.isclose(a, b.item())
+
+    def test_tensor(self, to_numpy, x_tensor):
+        x_numpy = to_numpy(x_tensor)
+        self.compare_array_to_tensor(x_numpy, x_tensor)
+
+    def test_list(self, to_numpy, x_list):
+        x_numpy = to_numpy(x_list)
+        for entry_numpy, entry_torch in zip(x_numpy, x_list):
+            self.compare_array_to_tensor(entry_numpy, entry_torch)
+
+    def test_tuple(self, to_numpy, x_tuple):
+        x_numpy = to_numpy(x_tuple)
+        for entry_numpy, entry_torch in zip(x_numpy, x_tuple):
+            self.compare_array_to_tensor(entry_numpy, entry_torch)
+
+    def test_dict(self, to_numpy, x_dict):
+        x_numpy = to_numpy(x_dict)
+        self.compare_array_to_tensor(x_numpy['a'], x_dict['a'])
+        self.compare_array_to_tensor(x_numpy['b'][0], x_dict['b'][0])
+        self.compare_array_to_tensor(x_numpy['b'][1], x_dict['b'][1])
+
+    @pytest.mark.parametrize('x_invalid', [
+        1,
+        [1, 2, 3],
+        (1, 2, 3),
+        {'a': 1},
+    ])
+    def test_invalid_inputs(self, to_numpy, x_invalid):
+        # Inputs that are invalid for the scope of to_numpy.
+        with pytest.raises(TypeError) as e:
+            to_numpy(x_invalid)
+        expected = "Cannot convert this data type to a numpy array."
+        assert e.value.args[0] == expected
+
+
 class TestToDevice:
     @pytest.fixture
     def to_device(self):
@@ -155,13 +218,17 @@ def x_dict(self):
             'x': torch.zeros(3),
             'y': torch.ones((4, 5))
         }
-    
+
     @pytest.fixture
     def x_pad_seq(self):
         value = torch.zeros((5, 3)).float()
         length = torch.as_tensor([2, 2, 1])
         return pack_padded_sequence(value, length)
 
+    @pytest.fixture
+    def x_list(self):
+        return [torch.zeros(3), torch.ones(2, 4)]
+
     def check_device_type(self, tensor, device_input, prev_device):
         """assert expected device type conditioned on the input argument
         for `to_device`"""
         if None is device_input:
@@ -214,7 +281,7 @@ def test_check_device_tuple_torch_tensor(
         x_tup = to_device(x_tup, device=device_to)
         for xi, prev_d in zip(x_tup, prev_devices):
             self.check_device_type(xi, device_to, prev_d)
-    
+
     @pytest.mark.parametrize('device_from, device_to', [
         ('cpu', 'cpu'),
         ('cpu', 'cuda'),
@@ -244,7 +311,7 @@ def test_check_device_dict_torch_tensor(
         assert x_dict.keys() == original_x_dict.keys()
         for k in x_dict:
             assert np.allclose(x_dict[k], original_x_dict[k])
-    
+
     @pytest.mark.parametrize('device_from, device_to', [
         ('cpu', 'cpu'),
         ('cpu', 'cuda'),
@@ -267,6 +334,36 @@ def test_check_device_packed_padded_sequence(
         x_pad_seq = to_device(x_pad_seq, device=device_to)
         self.check_device_type(x_pad_seq.data, device_to, prev_device)
 
+    @pytest.mark.parametrize('device_from, device_to', [
+        ('cpu', 'cpu'),
+        ('cpu', 'cuda'),
+        ('cuda', 'cpu'),
+        ('cuda', 'cuda'),
+        (None, None),
+    ])
+    def test_nested_data(self, to_device, x_list, device_from, device_to):
+        # Sometimes data is nested because it would need to be padded, so it's
+        # easier to return a list of tensors with different shapes.
+        # to_device should honor this.
+        if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
+            pytest.skip()
+
+        prev_devices = [None for _ in range(len(x_list))]
+        if None in (device_from, device_to):
+            prev_devices = [x.device.type for x in x_list]
+
+        x_list = to_device(x_list, device=device_from)
+        assert isinstance(x_list, list)
+
+        for xi, prev_d in zip(x_list, prev_devices):
+            self.check_device_type(xi, device_from, prev_d)
+
+        x_list = to_device(x_list, device=device_to)
+        assert isinstance(x_list, list)
+
+        for xi, prev_d in zip(x_list, prev_devices):
+            self.check_device_type(xi, device_to, prev_d)
+
 
 class TestDuplicateItems:
     @pytest.fixture
diff --git a/skorch/utils.py b/skorch/utils.py
index 3793a0eb2..bf412760b 100644
--- a/skorch/utils.py
+++ b/skorch/utils.py
@@ -104,6 +104,10 @@ def to_tensor(X, device, accept_sparse=False):
 
 def to_numpy(X):
     """Generic function to convert a pytorch tensor to numpy.
+
+    This function tries to unpack the tensor(s) from supported
+    data structures (e.g., dicts, lists, etc.) but doesn't go
+    beyond.
 
     Returns X when it already is a numpy array.
 
     """
@@ -116,6 +120,9 @@ def to_numpy(X):
     if is_pandas_ndframe(X):
         return X.values
 
+    if isinstance(X, (tuple, list)):
+        return type(X)(to_numpy(x) for x in X)
+
     if not is_torch_data_type(X):
         raise TypeError("Cannot convert this data type to a numpy array.")
 
@@ -154,8 +161,8 @@ def to_device(X, device):
         return {key: to_device(val, device) for key, val in X.items()}
 
     # PackedSequence class inherits from a namedtuple
-    if isinstance(X, tuple) and (type(X) != PackedSequence):
-        return tuple(x.to(device) for x in X)
+    if isinstance(X, (tuple, list)) and (type(X) != PackedSequence):
+        return type(X)(to_device(x, device) for x in X)
 
     return X.to(device)
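
For reference, a minimal usage sketch of the `to_numpy` change (a hypothetical session, not part of the patch; it assumes a skorch build that includes #657/#658, where `to_numpy` also recurses into dicts as the tests above exercise):

    import torch
    from skorch.utils import to_numpy

    batch = {'a': torch.ones(3), 'b': (torch.zeros(2), torch.zeros(3))}
    out = to_numpy(batch)

    # Container types are preserved; tensor leaves become numpy arrays.
    assert isinstance(out['b'], tuple)
    assert out['a'].shape == (3,)

    # Non-tensor leaves remain out of scope and still raise, by design:
    # to_numpy({'a': 1})  ->  TypeError("Cannot convert this data type to a numpy array.")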
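Likewise, a sketch of the `to_device` change under the same assumptions: lists are now recursed into with their container type preserved, while `PackedSequence` instances are still moved as a whole:

    import torch
    from skorch.utils import to_device

    # Nested data, e.g. tensors of different shapes that could not be stacked.
    nested = [torch.zeros(3), torch.ones(2, 4)]
    moved = to_device(nested, device='cpu')

    assert isinstance(moved, list)
    assert all(x.device.type == 'cpu' for x in moved)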