diff --git a/auglib/core/interface.py b/auglib/core/interface.py index c7bb2ad..113765b 100644 --- a/auglib/core/interface.py +++ b/auglib/core/interface.py @@ -377,111 +377,70 @@ def _augment_index( remove_root: str, description: str, ) -> pd.Index: - r"""Augment segments and store augmented files to cache.""" - files = index.get_level_values(0).unique() - augmented_files = _augmented_files( - files, - cache_root, - remove_root, - ) + r"""Augment segments and store augmented files to cache. + + Args: + index: segmented index + cache_root: cache root for augmented files + remove_root: directory that should be removed from + the beginning of the original file path before joining with + ``cache_root`` + description: text to display in progress bar + + Returns: + segmented index of augmented files + + """ + files = index.get_level_values("file") + starts = index.get_level_values("start") + ends = index.get_level_values("end") + augmented_files = _augmented_files(files, cache_root, remove_root) params = [ - ( - ( - file, - out_file, - index[index.get_level_values(0) == file].droplevel( - 0 - ), # start, end values for given file - ), - {}, - ) - for file, out_file in zip(files, augmented_files) + ((file, start, end, out_file), {}) + for file, start, end, out_file in zip(files, starts, ends, augmented_files) ] - - verbose = self.verbose - self.verbose = False # avoid nested progress bar - augmented_indices = audeer.run_tasks( + durations = audeer.run_tasks( self._augment_file_to_cache, params, num_workers=self.num_workers, multiprocessing=self.multiprocessing, - progress_bar=verbose, + progress_bar=self.verbose, task_description=description, ) - self.verbose = verbose - - augmented_index = audformat.utils.union(augmented_indices) - - return augmented_index - - def _augment_file( - self, - file: str, - index: pd.Index, - ) -> typing.Tuple[typing.List, typing.List, typing.List, int]: - r"""Augment file at every segment.""" - signal, sampling_rate = audiofile.read(file, always_2d=True) - signals, starts, ends, augmented_rate = self._augment_signal( - signal, - sampling_rate, - index, + augmented_index = audformat.segmented_index( + augmented_files, + [0] * len(durations), + durations, ) - return signals, starts, ends, augmented_rate + return augmented_index def _augment_file_to_cache( self, file: str, + start: pd.Timedelta, + end: pd.Timedelta, augmented_file: str, - index: pd.Index, # containing (several) start, end values - ) -> pd.Index: + ) -> float: r"""Augment file and store to cache. - Store augmented signals in separate files, - adding a counter at the end of the filename: - - -0 - -1 - ... - - If we have more than 10 files, - the counter will use two digits: + Before augmenting the file, + it is also resampled, + or remixed, + if required. - -00 - -01 - ... + Args: + file: path to incoming audio file + start: start time to read ``file`` + end: end time to read ``file`` + augmented_file: path of augmented file - and so on. + Returns: + duration of augmented file in seconds """ - signals, starts, ends, sampling_rate = self._augment_file(file, index) - - audeer.mkdir(os.path.dirname(augmented_file)) - if len(signals) > 1: - # number of needed digits for file names - digits = len(str(len(signals) - 1)) - root, ext = os.path.splitext(augmented_file) - files = [f"{root}-{str(n).zfill(digits)}{ext}" for n in range(len(signals))] - else: - files = [augmented_file] - for file, signal in zip(files, signals): - audiofile.write(file, signal, sampling_rate) - - # insert augmented file name at first level - augmented_index_with_file = audformat.segmented_index( - files, - starts, - ends, + signal, sampling_rate = audinterface.utils.read_audio( + file, start=start, end=end ) - - return augmented_index_with_file - - def _augment_signal( - self, - signal: np.ndarray, - sampling_rate: int, - index: pd.Index, - ) -> typing.Tuple[typing.List, typing.List, typing.List, int]: - r"""Augment signal at every segment in index.""" signal, sampling_rate = preprocess_signal( signal, sampling_rate=sampling_rate, @@ -490,17 +449,11 @@ def _augment_signal( channels=self.channels, mixdown=self.mixdown, ) - y = self.process_signal_from_index(signal, sampling_rate, index) - # adjust index to always start at 0 - # and end at NaT or duration - signals = list(y.values) - starts = [0] * len(signals) - ends = y.index.get_level_values("end") - ends = [ - end if pd.isna(end) else signal.shape[1] / sampling_rate - for signal, end in zip(signals, ends) - ] - return signals, starts, ends, sampling_rate + augmented_signal = self(signal, sampling_rate) + audeer.mkdir(os.path.dirname(augmented_file)) + audiofile.write(augmented_file, augmented_signal, sampling_rate) + duration = augmented_signal.shape[1] / sampling_rate + return duration @staticmethod def _process_func( @@ -555,8 +508,55 @@ def _augmented_files( files: typing.Sequence[str], cache_root: str, remove_root: str = None, -) -> typing.Sequence[str]: - r"""Return cache file names by joining with the cache directory.""" +) -> typing.List[str]: + r"""Return destination path for augmented files. + + If files contain the same filename several times, + e.g. when augmenting segments, + it will convert them into unique filenames: + + -0 + -1 + ... + + As segments are stored as single files. + + If we have more than 10 files, + the counter will use two digits: + + -00 + -01 + ... + + and so on. + + Args: + files: files to augment + cache_root: cache root of augmented files + remove_root: directory that should be removed from + the beginning of the original file path before joining with + ``cache_root`` + + Returns: + path of augmented files + + """ + # Estimate number of segments/samples for each file + unique_files, counts = np.unique(files, return_counts=True) + counts = {file: count for file, count in zip(unique_files, counts)} + current_count = {file: 0 for file in unique_files} + + augmented_files = [] + for file in files: + if counts[file] > 1: + digits = len(str(counts[file] - 1)) + root, ext = os.path.splitext(file) + augmented_file = f"{root}-{str(current_count[file]).zfill(digits)}{ext}" + else: + augmented_file = file + current_count[file] += 1 + augmented_files.append(augmented_file) + if remove_root is None: def join(path1: str, path2: str) -> str: @@ -566,14 +566,16 @@ def join(path1: str, path2: str) -> str: os.path.splitdrive(path2)[1].lstrip(seps), ) - augmented_files = [join(cache_root, file) for file in files] + augmented_files = [join(cache_root, file) for file in augmented_files] else: remove_root = audeer.path(remove_root) - dirs = [os.path.dirname(file) for file in files] + dirs = [os.path.dirname(file) for file in unique_files] common_root = audeer.common_directory(dirs) if not audeer.common_directory([remove_root, common_root]) == remove_root: raise RuntimeError( f"Cannot remove '{remove_root}' " f"from '{common_root}'." ) - augmented_files = [file.replace(remove_root, cache_root, 1) for file in files] + augmented_files = [ + file.replace(remove_root, cache_root, 1) for file in augmented_files + ] return augmented_files diff --git a/tests/test_interface.py b/tests/test_interface.py index 3f35a6c..796d47d 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -374,6 +374,29 @@ np.ones((1, 1)), ), ), + # Unsorted index + ( + audformat.segmented_index( + ["f1.wav", "f2.wav", "f1.wav", "f2.wav"], + [0.2, 0.0, 0.0, 0.3], + [1.0, 0.3, 0.2, 1.0], + ), + np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype="float32"), + 10, + auglib.transform.Function(lambda x, _: x + 1), + True, + audformat.segmented_index( + ["f1-0.wav", "f2-0.wav", "f1-1.wav", "f2-1.wav"], + [0, 0, 0, 0], + [0.8, 0.3, 0.2, 0.7], + ), + ( + np.array([[3, 4, 5, 6, 7, 8, 9, 10]], dtype="float32"), + np.array([[1, 2, 3]], dtype="float32"), + np.array([[1, 2]], dtype="float32"), + np.array([[4, 5, 6, 7, 8, 9, 10]], dtype="float32"), + ), + ), ], ) def test_augment(