Update(fcpe): local decoder

This commit is contained in:
ylzz1997 2023-07-25 23:34:56 +08:00
parent 44ba5dfa55
commit c9d81428fe

View File

@ -84,13 +84,17 @@ class FCPE(nn.Module):
self.dense_out = weight_norm(
nn.Linear(n_chans, self.n_out))
def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False):
def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder = "local_argmax"):
"""
input:
B x n_frames x n_unit
return:
dict of B x n_frames x feat
"""
if cdecoder == "argmax":
self.cdecoder = self.cents_decoder
elif cdecoder == "local_argmax":
self.cdecoder = self.cents_local_decoder
if self.use_input_conv:
x = self.stack(mel.transpose(1, 2)).transpose(1, 2)
else:
@ -108,7 +112,7 @@ class FCPE(nn.Module):
loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
x = loss_all
if infer:
x = self.cents_decoder(x)
x = self.cdecoder(x)
x = self.cent_to_f0(x)
if not return_hz_f0:
x = (1 + x / 700).log()
@ -128,6 +132,25 @@ class FCPE(nn.Module):
else:
return rtn
def cents_local_decoder(self, y, mask=True):
    """Decode bin activations to cents with a local weighted average.

    For every frame the bin with the highest activation is located and the
    cent values of the bins in a symmetric window around it (peak - 4 ..
    peak + 4) are averaged, weighted by their activations.

    Args:
        y: 3-D tensor ``[B, N, n_bins]`` of per-bin activations. The last
            dimension is indexed against ``self.cent_table`` and bounded by
            ``self.n_out``. Presumably non-negative scores -- confirm
            against the model head.
        mask: when true, frames whose peak activation is
            ``<= self.threshold`` are multiplied by ``-inf``.

    Returns:
        Tensor ``[B, N, 1]`` of cents, or the tuple ``(cents, confident)``
        when ``self.confidence`` is truthy, where ``confident`` is the peak
        activation per frame (``[B, N, 1]``).
    """
    B, N, _ = y.size()
    ci = self.cent_table[None, None, :].expand(B, N, -1)
    confident, max_index = torch.max(y, dim=-1, keepdim=True)

    # Symmetric 9-bin window centred on the peak. A window of
    # ``arange(0, 8) - 4`` would span peak-4 .. peak+3 and pull the
    # weighted average toward the lower bins.
    offsets = torch.arange(-4, 5, device=max_index.device)
    window_index = max_index + offsets  # [B, N, 9]

    # Positions that fall outside the table are clamped so gather stays
    # valid, but their weight is zeroed; otherwise the edge bin would be
    # counted several times when the peak sits near either end.
    in_range = (window_index >= 0) & (window_index < self.n_out)
    window_index = window_index.clamp(0, self.n_out - 1)

    ci_l = torch.gather(ci, -1, window_index)
    y_l = torch.gather(y, -1, window_index) * in_range.to(y.dtype)
    rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)  # cents: [B,N,1]

    if mask:
        # Low-confidence frames are pushed to -inf by multiplication.
        confident_mask = torch.ones_like(confident)
        confident_mask[confident <= self.threshold] = float("-INF")
        rtn = rtn * confident_mask

    if self.confidence:
        return rtn, confident
    return rtn
def cent_to_f0(self, cent):
    """Map a cent value to ``10 * 2 ** (cent / 1200)``.

    Every 1200 cents doubles the result; 0 cents maps to 10.
    """
    octaves = cent / 1200.
    return 10. * 2 ** octaves
@ -165,7 +188,6 @@ class FCPEInfer:
f0_min=self.args.model.f0_min,
confidence=self.args.model.confidence,
)
ckpt = torch.load(model_path, map_location=torch.device(self.device))
model.to(self.device).to(self.dtype)
model.load_state_dict(ckpt['model'])
model.eval()