fix(fcpe): resample

This commit is contained in:
magic-akari 2023-07-23 20:42:10 +08:00
parent 7c7496536f
commit efed9f5f02
No known key found for this signature in database
GPG Key ID: EC005B1159285BDD

View File

@@ -170,7 +170,7 @@ class FCPEInfer:
model.load_state_dict(ckpt['model'])
model.eval()
self.model = model
self.wav2mel = Wav2Mel(self.args)
self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device)
@torch.no_grad()
def __call__(self, audio, sr, threshold=0.05):
@@ -182,13 +182,15 @@ class FCPEInfer:
class Wav2Mel:
def __init__(self, args, device=None):
def __init__(self, args, device=None, dtype=torch.float32):
# self.args = args
self.sampling_rate = args.mel.sampling_rate
self.hop_size = args.mel.hop_size
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.dtype = dtype
self.stft = STFT(
args.mel.sampling_rate,
args.mel.num_mels,
@@ -205,14 +207,15 @@ class Wav2Mel:
return mel
def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
audio = audio.to(self.dtype).to(self.device)
# resample
if sample_rate == self.sampling_rate:
audio_res = audio
else:
key_str = str(sample_rate)
if key_str not in self.resample_kernel:
self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate,
lowpass_filter_width=128).to(self.device)
self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128)
self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
audio_res = self.resample_kernel[key_str](audio)
# extract