mirror of
https://github.com/svc-develop-team/so-vits-svc.git
synced 2025-01-08 11:57:43 +08:00
Update(fcpe): local decoder
This commit is contained in:
parent
44ba5dfa55
commit
c9d81428fe
@ -84,13 +84,17 @@ class FCPE(nn.Module):
|
||||
self.dense_out = weight_norm(
|
||||
nn.Linear(n_chans, self.n_out))
|
||||
|
||||
def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False):
|
||||
def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder = "local_argmax"):
|
||||
"""
|
||||
input:
|
||||
B x n_frames x n_unit
|
||||
return:
|
||||
dict of B x n_frames x feat
|
||||
"""
|
||||
if cdecoder == "argmax":
|
||||
self.cdecoder = self.cents_decoder
|
||||
elif cdecoder == "local_argmax":
|
||||
self.cdecoder = self.cents_local_decoder
|
||||
if self.use_input_conv:
|
||||
x = self.stack(mel.transpose(1, 2)).transpose(1, 2)
|
||||
else:
|
||||
@ -108,7 +112,7 @@ class FCPE(nn.Module):
|
||||
loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
|
||||
x = loss_all
|
||||
if infer:
|
||||
x = self.cents_decoder(x)
|
||||
x = self.cdecoder(x)
|
||||
x = self.cent_to_f0(x)
|
||||
if not return_hz_f0:
|
||||
x = (1 + x / 700).log()
|
||||
@ -127,6 +131,25 @@ class FCPE(nn.Module):
|
||||
return rtn, confident
|
||||
else:
|
||||
return rtn
|
||||
|
||||
def cents_local_decoder(self, y, mask=True, radius=4):
    """Decode per-frame bin activations to cents using a local weighted mean.

    For every frame the strongest bin is located, and the cent values of the
    bins within ``radius`` positions on either side of it are averaged,
    weighted by their activations.

    Args:
        y: rank-3 tensor unpacked as ``B, N, _``; the last axis is indexed
            against ``self.cent_table`` (presumably ``self.n_out`` bins --
            confirm against the caller).
        mask: when true, frames whose peak activation is
            ``<= self.threshold`` are set to ``-inf``.
        radius: number of bins taken on each side of the peak; the window
            holds ``2 * radius + 1`` bins.

    Returns:
        ``rtn`` of shape ``[B, N, 1]`` (cents), or ``(rtn, confident)`` when
        ``self.confidence`` is truthy, where ``confident`` is the per-frame
        peak activation, also ``[B, N, 1]``.

    Raises:
        ValueError: if ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")

    B, N, _ = y.size()
    ci = self.cent_table[None, None, :].expand(B, N, -1)
    confident, max_index = torch.max(y, dim=-1, keepdim=True)

    # Window centred on the peak: the same number of bins on both sides, so
    # the weighted mean is not pulled towards one side.
    offsets = torch.arange(-radius, radius + 1, device=max_index.device)
    local_index = max_index + offsets  # [B, N, 2 * radius + 1]

    # Slots that fall outside the table get zero weight. The indices are
    # clamped only so that gather() stays in bounds; without the zero weight
    # the edge bin would be counted once per out-of-range slot.
    in_range = (local_index >= 0) & (local_index < self.n_out)
    local_index = local_index.clamp(0, self.n_out - 1)

    ci_l = torch.gather(ci, -1, local_index)
    y_l = torch.gather(y, -1, local_index) * in_range.to(y.dtype)
    rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)  # cents: [B,N,1]

    if mask:
        # Write -inf directly: multiplying by -inf would give NaN for a zero
        # cent value and +inf for a negative one.
        rtn = rtn.masked_fill(confident <= self.threshold, float("-inf"))

    if self.confidence:
        return rtn, confident
    return rtn
|
||||
|
||||
def cent_to_f0(self, cent):
    """Convert a pitch value in cents to a frequency.

    Computes ``10 * 2 ** (cent / 1200)``: zero cents maps to 10 and every
    1200 cents doubles the result. The unit is presumably Hz, going by the
    ``return_hz_f0`` flag in ``forward`` -- confirm against the caller.
    """
    octaves = cent / 1200.
    return 10. * 2 ** octaves
|
||||
@ -165,7 +188,6 @@ class FCPEInfer:
|
||||
f0_min=self.args.model.f0_min,
|
||||
confidence=self.args.model.confidence,
|
||||
)
|
||||
ckpt = torch.load(model_path, map_location=torch.device(self.device))
|
||||
model.to(self.device).to(self.dtype)
|
||||
model.load_state_dict(ckpt['model'])
|
||||
model.eval()
|
||||
|
Loading…
Reference in New Issue
Block a user