Skip to content

Commit

Permalink
Preserve order of index in augmentation (#36)
Browse files Browse the repository at this point in the history
* Add test case for unsorted index

* Preserve order of index in augmentation

* Improve code and comments

* Make test also check correct position

* Fix typos

Co-authored-by: audeerington <99745980+audeerington@users.noreply.github.com>

---------

Co-authored-by: Anna Derington <aderington@audeering.com>
Co-authored-by: audeerington <99745980+audeerington@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 8, 2024
1 parent a665a49 commit d0403a7
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 100 deletions.
202 changes: 102 additions & 100 deletions auglib/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,111 +377,70 @@ def _augment_index(
remove_root: str,
description: str,
) -> pd.Index:
r"""Augment segments and store augmented files to cache."""
files = index.get_level_values(0).unique()
augmented_files = _augmented_files(
files,
cache_root,
remove_root,
)
r"""Augment segments and store augmented files to cache.
Args:
index: segmented index
cache_root: cache root for augmented files
remove_root: directory that should be removed from
the beginning of the original file path before joining with
``cache_root``
description: text to display in progress bar
Returns:
segmented index of augmented files
"""
files = index.get_level_values("file")
starts = index.get_level_values("start")
ends = index.get_level_values("end")
augmented_files = _augmented_files(files, cache_root, remove_root)
params = [
(
(
file,
out_file,
index[index.get_level_values(0) == file].droplevel(
0
), # start, end values for given file
),
{},
)
for file, out_file in zip(files, augmented_files)
((file, start, end, out_file), {})
for file, start, end, out_file in zip(files, starts, ends, augmented_files)
]

verbose = self.verbose
self.verbose = False # avoid nested progress bar
augmented_indices = audeer.run_tasks(
durations = audeer.run_tasks(
self._augment_file_to_cache,
params,
num_workers=self.num_workers,
multiprocessing=self.multiprocessing,
progress_bar=verbose,
progress_bar=self.verbose,
task_description=description,
)
self.verbose = verbose

augmented_index = audformat.utils.union(augmented_indices)

return augmented_index

def _augment_file(
self,
file: str,
index: pd.Index,
) -> typing.Tuple[typing.List, typing.List, typing.List, int]:
r"""Augment file at every segment."""
signal, sampling_rate = audiofile.read(file, always_2d=True)
signals, starts, ends, augmented_rate = self._augment_signal(
signal,
sampling_rate,
index,
augmented_index = audformat.segmented_index(
augmented_files,
[0] * len(durations),
durations,
)
return signals, starts, ends, augmented_rate
return augmented_index

def _augment_file_to_cache(
self,
file: str,
start: pd.Timedelta,
end: pd.Timedelta,
augmented_file: str,
index: pd.Index, # containing (several) start, end values
) -> pd.Index:
) -> float:
r"""Augment file and store to cache.
Store augmented signals in separate files,
adding a counter at the end of the filename:
<augmented_file>-0
<augmented_file>-1
...
If we have more than 10 files,
the counter will use two digits:
Before augmentation,
the signal is resampled
or remixed,
if required.
<augmented_file>-00
<augmented_file>-01
...
Args:
file: path to incoming audio file
start: start time to read ``file``
end: end time to read ``file``
augmented_file: path of augmented file
and so on.
Returns:
duration of augmented file in seconds
"""
signals, starts, ends, sampling_rate = self._augment_file(file, index)

audeer.mkdir(os.path.dirname(augmented_file))
if len(signals) > 1:
# number of needed digits for file names
digits = len(str(len(signals) - 1))
root, ext = os.path.splitext(augmented_file)
files = [f"{root}-{str(n).zfill(digits)}{ext}" for n in range(len(signals))]
else:
files = [augmented_file]
for file, signal in zip(files, signals):
audiofile.write(file, signal, sampling_rate)

# insert augmented file name at first level
augmented_index_with_file = audformat.segmented_index(
files,
starts,
ends,
signal, sampling_rate = audinterface.utils.read_audio(
file, start=start, end=end
)

return augmented_index_with_file

def _augment_signal(
self,
signal: np.ndarray,
sampling_rate: int,
index: pd.Index,
) -> typing.Tuple[typing.List, typing.List, typing.List, int]:
r"""Augment signal at every segment in index."""
signal, sampling_rate = preprocess_signal(
signal,
sampling_rate=sampling_rate,
Expand All @@ -490,17 +449,11 @@ def _augment_signal(
channels=self.channels,
mixdown=self.mixdown,
)
y = self.process_signal_from_index(signal, sampling_rate, index)
# adjust index to always start at 0
# and end at NaT or duration
signals = list(y.values)
starts = [0] * len(signals)
ends = y.index.get_level_values("end")
ends = [
end if pd.isna(end) else signal.shape[1] / sampling_rate
for signal, end in zip(signals, ends)
]
return signals, starts, ends, sampling_rate
augmented_signal = self(signal, sampling_rate)
audeer.mkdir(os.path.dirname(augmented_file))
audiofile.write(augmented_file, augmented_signal, sampling_rate)
duration = augmented_signal.shape[1] / sampling_rate
return duration

@staticmethod
def _process_func(
Expand Down Expand Up @@ -555,8 +508,55 @@ def _augmented_files(
files: typing.Sequence[str],
cache_root: str,
remove_root: str = None,
) -> typing.Sequence[str]:
r"""Return cache file names by joining with the cache directory."""
) -> typing.List[str]:
r"""Return destination path for augmented files.
If files contain the same filename several times,
e.g. when augmenting segments,
it will convert them into unique filenames:
<augmented_file>-0
<augmented_file>-1
...
This is needed because every segment is stored in its own file.
If we have more than 10 files,
the counter will use two digits:
<augmented_file>-00
<augmented_file>-01
...
and so on.
Args:
files: files to augment
cache_root: cache root of augmented files
remove_root: directory that should be removed from
the beginning of the original file path before joining with
``cache_root``
Returns:
path of augmented files
"""
# Count the number of segments/samples for each file
unique_files, counts = np.unique(files, return_counts=True)
counts = {file: count for file, count in zip(unique_files, counts)}
current_count = {file: 0 for file in unique_files}

augmented_files = []
for file in files:
if counts[file] > 1:
digits = len(str(counts[file] - 1))
root, ext = os.path.splitext(file)
augmented_file = f"{root}-{str(current_count[file]).zfill(digits)}{ext}"
else:
augmented_file = file
current_count[file] += 1
augmented_files.append(augmented_file)

if remove_root is None:

def join(path1: str, path2: str) -> str:
Expand All @@ -566,14 +566,16 @@ def join(path1: str, path2: str) -> str:
os.path.splitdrive(path2)[1].lstrip(seps),
)

augmented_files = [join(cache_root, file) for file in files]
augmented_files = [join(cache_root, file) for file in augmented_files]
else:
remove_root = audeer.path(remove_root)
dirs = [os.path.dirname(file) for file in files]
dirs = [os.path.dirname(file) for file in unique_files]
common_root = audeer.common_directory(dirs)
if not audeer.common_directory([remove_root, common_root]) == remove_root:
raise RuntimeError(
f"Cannot remove '{remove_root}' " f"from '{common_root}'."
)
augmented_files = [file.replace(remove_root, cache_root, 1) for file in files]
augmented_files = [
file.replace(remove_root, cache_root, 1) for file in augmented_files
]
return augmented_files
23 changes: 23 additions & 0 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,29 @@
np.ones((1, 1)),
),
),
# Unsorted index
(
audformat.segmented_index(
["f1.wav", "f2.wav", "f1.wav", "f2.wav"],
[0.2, 0.0, 0.0, 0.3],
[1.0, 0.3, 0.2, 1.0],
),
np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype="float32"),
10,
auglib.transform.Function(lambda x, _: x + 1),
True,
audformat.segmented_index(
["f1-0.wav", "f2-0.wav", "f1-1.wav", "f2-1.wav"],
[0, 0, 0, 0],
[0.8, 0.3, 0.2, 0.7],
),
(
np.array([[3, 4, 5, 6, 7, 8, 9, 10]], dtype="float32"),
np.array([[1, 2, 3]], dtype="float32"),
np.array([[1, 2]], dtype="float32"),
np.array([[4, 5, 6, 7, 8, 9, 10]], dtype="float32"),
),
),
],
)
def test_augment(
Expand Down

0 comments on commit d0403a7

Please sign in to comment.