Skip to content

Commit ce4b445

Browse files
committed
yews.transform under cover with 100% coverage.
1 parent 2cf6108 commit ce4b445

File tree

4 files changed

+98
-44
lines changed

4 files changed

+98
-44
lines changed

tests/test_transforms.py

+37
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,19 @@
55
import torch
66
import numpy as np
77

8+
9+
class DummpyBaseTransform(transforms.BaseTransform):
10+
11+
def __init__(self, a=0, b=1):
12+
self.a = a
13+
self.b = b
14+
15+
def __call__(self, data):
16+
return data
17+
18+
819
class TestIsNumpyWaveform:
20+
921
def test_single_channel_waveform_vector(self):
1022
wav = np.empty(10)
1123
assert F._is_numpy_waveform(wav)
@@ -26,7 +38,9 @@ def test_invalid_waveform_wrong_type(self):
2638
wav = torch.tensor(10)
2739
assert not F._is_numpy_waveform(wav)
2840

41+
2942
class TestToTensor:
43+
3044
def test_type_exception(self):
3145
wav = torch.tensor(10)
3246
with pytest.raises(TypeError):
@@ -42,7 +56,19 @@ def test_multi_channel_waveform(self):
4256
tensor = torch.zeros(3, 10,dtype=torch.float)
4357
assert torch.allclose(F._to_tensor(wav), tensor)
4458

59+
class TestBaseTransform:
60+
61+
def test_raise_call_notimplementederror(self):
62+
with pytest.raises(NotImplementedError):
63+
t = transforms.BaseTransform()
64+
t(0)
65+
66+
def test_repr(self):
67+
t = transforms.BaseTransform()
68+
assert type(t.__repr__()) is str
69+
4570
class TestMandatoryMethods:
71+
4672
def test_call_method(self):
4773
assert all([hasattr(getattr(transforms, t), '__call__') for t in
4874
transforms.transforms.__all__])
@@ -51,6 +77,13 @@ def test_repr_method(self):
5177
assert all([hasattr(getattr(transforms, t), '__repr__') for t in
5278
transforms.transforms.__all__])
5379

80+
81+
class TestComposeTransform:
82+
83+
def test_repr(self):
84+
t = transforms.Compose([DummpyBaseTransform()])
85+
assert type(t.__repr__()) is str
86+
5487
class TestTransformCorrectness:
5588
def test_compose(self):
5689
wav = np.array([1, 3])
@@ -80,3 +113,7 @@ def test_cut_waveform_value(self):
80113
assert np.allclose(transforms.CutWaveform(100, 1900)(wav),
81114
wav[:, 100:1900])
82115

116+
def test_soft_clip(self):
117+
wav = np.array([-1, -0.5, 0, 0.5, 1])
118+
assert np.allclose(transforms.SoftClip()(wav),
119+
np.array([0.26894142, 0.37754067, 0.5, 0.62245933, 0.73105858]))

yews/transforms/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .base import BaseTransform, Compose
12
from .transforms import *

yews/transforms/base.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
class BaseTransform(object):
2+
"""An abstract class representing a Transform.
3+
4+
All other transform should subclass it. All subclasses should override
5+
``__call__`` which performs the transform.
6+
7+
Args:
8+
root (object): Source of the dataset.
9+
sample_transform (callable, optional): A function/transform that takes
10+
a sample and returns a transformed version.
11+
target_transform (callable, optional): A function/transform that takes
12+
a target and transform it.
13+
14+
Attributes:
15+
samples (dataset-like object): Dataset-like object for samples.
16+
targets (dataset-like object): Dataset-like object for targets.
17+
18+
"""
19+
20+
def __call__(self, data):
21+
raise NotImplementedError
22+
23+
def __repr__(self):
24+
head = self.__class__.__name__
25+
content = [f"{key} = {val}" for key, val in self.__dict__.items()]
26+
body = ", ".join(content)
27+
return f"{head}({body})"
28+
29+
30+
class Compose(BaseTransform):
31+
"""Composes several transforms together.
32+
Args:
33+
transforms (list of ``Transform`` objects): list of transforms to compose.
34+
Example:
35+
>>> transforms.Compose([
36+
>>> transforms.CenterCrop(10),
37+
>>> transforms.ToTensor(),
38+
>>> ])
39+
"""
40+
41+
def __init__(self, transforms):
42+
self.transforms = transforms
43+
44+
def __call__(self, wav):
45+
for t in self.transforms:
46+
wav = t(wav)
47+
return wav
48+
49+
def __repr__(self):
50+
format_string = self.__class__.__name__ + '('
51+
for t in self.transforms:
52+
format_string += '\n'
53+
format_string += ' {0}'.format(t)
54+
format_string += '\n)'
55+
return format_string

yews/transforms/transforms.py

+5-44
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,14 @@
1+
from .base import BaseTransform
12
from . import functional as F
23

34
__all__ = [
4-
"Compose",
55
"ToTensor",
66
"ZeroMean",
77
"SoftClip",
88
"CutWaveform",
99
]
1010

11-
class Compose(object):
12-
"""Composes several transforms together.
13-
Args:
14-
transforms (list of ``Transform`` objects): list of transforms to compose.
15-
Example:
16-
>>> transforms.Compose([
17-
>>> transforms.CenterCrop(10),
18-
>>> transforms.ToTensor(),
19-
>>> ])
20-
"""
21-
22-
def __init__(self, transforms):
23-
self.transforms = transforms
24-
25-
def __call__(self, wav):
26-
for t in self.transforms:
27-
wav = t(wav)
28-
return wav
29-
30-
def __repr__(self):
31-
format_string = self.__class__.__name__ + '('
32-
for t in self.transforms:
33-
format_string += '\n'
34-
format_string += ' {0}'.format(t)
35-
format_string += '\n)'
36-
return format_string
37-
38-
39-
class ToTensor(object):
11+
class ToTensor(BaseTransform):
4012
"""Convert a ``numpy.ndarray`` to tensor.
4113
4214
Converts a numpy.ndarray (C x S) to a torch.FloatTensor of shape (C x S).
@@ -51,11 +23,8 @@ def __call__(self, wav):
5123
"""
5224
return F._to_tensor(wav)
5325

54-
def __repr__(self):
55-
return self.__class__.__name__ + '()'
56-
5726

58-
class SoftClip(object):
27+
class SoftClip(BaseTransform):
5928
"""Soft clip input to compress large amplitude signals
6029
6130
"""
@@ -66,11 +35,8 @@ def __init__(self, scale=1):
6635
def __call__(self, wav):
6736
return F.expit(wav * self.scale)
6837

69-
def __repr__(self):
70-
return self.__class__.__name__ + f'(scale = {self.scale})'
7138

72-
73-
class ZeroMean(object):
39+
class ZeroMean(BaseTransform):
7440
"""Remove mean from each waveforms
7541
7642
"""
@@ -80,11 +46,8 @@ def __call__(self, wav):
8046
wav -= wav.mean(axis=0)
8147
return wav.T
8248

83-
def __repr__(self):
84-
return self.__class__.__name__ + '()'
85-
8649

87-
class CutWaveform(object):
50+
class CutWaveform(BaseTransform):
8851
"""Cut a portion of waveform.
8952
9053
"""
@@ -96,5 +59,3 @@ def __init__(self, samplestart, sampleend):
9659
def __call__(self, wav):
9760
return wav[:, self.start:self.end]
9861

99-
def __repr__(self):
100-
return self.__call__.__name__ + f'(start = {self.start}, end = {self.end})'

0 commit comments

Comments
 (0)