-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path denoising.py
70 lines (52 loc) · 2.31 KB
/
denoising.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from denoiser.dsp import convert_audio
from denoiser import pretrained
import torchaudio
import torch
import logging
logging.basicConfig(level=logging.INFO, format='[DSR_MODULE]%(asctime)s %(levelname)s %(message)s')
def _check_parallel_device_list():
if not torch.cuda.is_available():
return ["cpu"]
device_list = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
return device_list
def load_audio(
path: str = None,
verbose: bool = False
):
if path is None: raise ValueError(f"path argument is required. Excepted: str, but got {path}")
if verbose: logging.info(f"Loading audio from {path}...")
audio, sample_rate = torchaudio.load(path, format="wav")
if verbose: logging.info("Done!")
return audio, sample_rate
def denoising(
audio: torch.Tensor = None,
sample_rate: int = None,
device: torch.device or str = None,
inference_parallel: bool = False,
verbose: bool = False
):
if device is None: raise ValueError(f"device argument is required. Excepted: 'cuda' or 'cpu', but got {device}")
if sample_rate is None: raise ValueError(f"sample_rate argument is required. Excepted: int, but got {sample_rate}")
if audio is None: raise ValueError(f"audio argument is required. Excepted: torch.Tensor, but got {audio}")
if verbose: logging.info("Loading model...")
model = pretrained.dns64(pretrained=True).to(device)
model_sample_rate = model.sample_rate
model_chin = model.chin
if inference_parallel:
device_list = _check_parallel_device_list()
if verbose: logging.info(f"Parallel inference... Device list: {device_list}")
model = torch.nn.DataParallel(model, device_ids=device_list)
if audio.ndim == 1: audio = audio.unsqueeze(0)
if verbose: logging.info("Converting audio...")
wav = convert_audio(wav=audio,
from_samplerate=sample_rate,
to_samplerate=model_sample_rate,
channels=model_chin
).to(device)
if verbose: logging.info("Inference...")
with torch.no_grad():
output = model(wav)
if verbose: logging.info("Converting output...")
output = output.squeeze(0).cpu().numpy()
if verbose: logging.info("Done!")
return output, model_sample_rate