Debug FCPE

This commit is contained in:
ylzz1997 2023-07-23 00:10:24 +08:00
parent 26329ff059
commit ec6f7f5ade
9 changed files with 21 additions and 23 deletions

View File

@ -203,9 +203,10 @@ class Svc(object):
def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter ,f0_predictor,cr_threshold=0.05): def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter ,f0_predictor,cr_threshold=0.05):
f0_predictor_object = utils.get_f0_predictor(f0_predictor,hop_length=self.hop_size,sampling_rate=self.target_sample,device=self.dev,threshold=cr_threshold) if not hasattr(self,"f0_predictor_object") or self.f0_predictor_object is None or f0_predictor != self.f0_predictor_object.name:
self.f0_predictor_object = utils.get_f0_predictor(f0_predictor,hop_length=self.hop_size,sampling_rate=self.target_sample,device=self.dev,threshold=cr_threshold)
f0, uv = f0_predictor_object.compute_f0_uv(wav) f0, uv = self.f0_predictor_object.compute_f0_uv(wav)
if f0_filter and sum(f0) == 0: if f0_filter and sum(f0) == 0:
raise F0FilterException("No voice detected") raise F0FilterException("No voice detected")
f0 = torch.FloatTensor(f0).to(self.dev) f0 = torch.FloatTensor(f0).to(self.dev)

View File

@ -13,6 +13,7 @@ class CrepeF0Predictor(F0Predictor):
self.device = device self.device = device
self.threshold = threshold self.threshold = threshold
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.name = "crepe"
def compute_f0(self,wav,p_len=None): def compute_f0(self,wav,p_len=None):
x = torch.FloatTensor(wav).to(self.device) x = torch.FloatTensor(wav).to(self.device)

View File

@ -10,6 +10,7 @@ class DioF0Predictor(F0Predictor):
self.f0_min = f0_min self.f0_min = f0_min
self.f0_max = f0_max self.f0_max = f0_max
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.name = "dio"
def interpolate_f0(self,f0): def interpolate_f0(self,f0):
''' '''

View File

@ -23,6 +23,7 @@ class FCPEF0Predictor(F0Predictor):
self.threshold = threshold self.threshold = threshold
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.dtype = dtype self.dtype = dtype
self.name = "fcpe"
def repeat_expand( def repeat_expand(
self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
@ -89,7 +90,7 @@ class FCPEF0Predictor(F0Predictor):
p_len = x.shape[0] // self.hop_length p_len = x.shape[0] // self.hop_length
else: else:
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold) f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0]
if torch.all(f0 == 0): if torch.all(f0 == 0):
rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
return rtn, rtn return rtn, rtn
@ -101,7 +102,7 @@ class FCPEF0Predictor(F0Predictor):
p_len = x.shape[0] // self.hop_length p_len = x.shape[0] // self.hop_length
else: else:
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold) f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0]
if torch.all(f0 == 0): if torch.all(f0 == 0):
rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
return rtn, rtn return rtn, rtn

View File

@ -10,6 +10,7 @@ class HarvestF0Predictor(F0Predictor):
self.f0_min = f0_min self.f0_min = f0_min
self.f0_max = f0_max self.f0_max = f0_max
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.name = "harvest"
def interpolate_f0(self,f0): def interpolate_f0(self,f0):
''' '''

View File

@ -10,7 +10,7 @@ class PMF0Predictor(F0Predictor):
self.f0_min = f0_min self.f0_min = f0_min
self.f0_max = f0_max self.f0_max = f0_max
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.name = "pm"
def interpolate_f0(self,f0): def interpolate_f0(self,f0):
''' '''

View File

@ -22,6 +22,7 @@ class RMVPEF0Predictor(F0Predictor):
self.threshold = threshold self.threshold = threshold
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.dtype = dtype self.dtype = dtype
self.name = "rmvpe"
def repeat_expand( def repeat_expand(
self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"

View File

@ -1,10 +1,7 @@
import os
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import yaml
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
from torchaudio.transforms import Resample from torchaudio.transforms import Resample
@ -146,10 +143,11 @@ class FCPE(nn.Module):
class FCPEInfer: class FCPEInfer:
def __init__(self, model_path, device=None, dtype=torch.float32): def __init__(self, model_path, device=None, dtype=torch.float32):
config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') if device is None:
with open(config_file, "r") as config: device = 'cuda' if torch.cuda.is_available() else 'cpu'
args = yaml.safe_load(config) self.device = device
self.args = DotDict(args) ckpt = torch.load(model_path, map_location=torch.device(self.device))
self.args = DotDict(ckpt["config"])
self.dtype = dtype self.dtype = dtype
model = FCPE( model = FCPE(
input_channel=self.args.model.input_channel, input_channel=self.args.model.input_channel,
@ -167,25 +165,19 @@ class FCPEInfer:
f0_min=self.args.model.f0_min, f0_min=self.args.model.f0_min,
confidence=self.args.model.confidence, confidence=self.args.model.confidence,
) )
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
ckpt = torch.load(model_path, map_location=torch.device(self.device)) ckpt = torch.load(model_path, map_location=torch.device(self.device))
model.to(self.device).to(self.dtype) model.to(self.device).to(self.dtype)
model.load_state_dict(ckpt['model']) model.load_state_dict(ckpt['model'])
model.eval() model.eval()
self.model = model self.model = model
self.wav2mel = Wav2Mel(self.args) self.wav2mel = Wav2Mel(self.args)
self.args = args
@torch.no_grad() @torch.no_grad()
def __call__(self, audio, sr, threshold=0.05): def __call__(self, audio, sr, threshold=0.05):
self.model.threshold = threshold self.model.threshold = threshold
audio = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) audio = audio[None,:]
mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype) mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype)
mel_f0 = self.model(mel=mel, infer=True, return_hz_f0=True) f0 = self.model(mel=mel, infer=True, return_hz_f0=True)
# f0 = (mel_f0.exp() - 1) * 700
f0 = mel_f0
return f0 return f0

View File

@ -102,8 +102,8 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"])
elif f0_predictor == "fcpe": elif f0_predictor == "fcpe":
from modules.F0Predictor.FCPEF0Predictor import FCEF0Predictor from modules.F0Predictor.FCPEF0Predictor import FCPEF0Predictor
f0_predictor_object = FCEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) f0_predictor_object = FCPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"])
else: else:
raise Exception("Unknown f0 predictor") raise Exception("Unknown f0 predictor")
return f0_predictor_object return f0_predictor_object