mirror of
https://github.com/svc-develop-team/so-vits-svc.git
synced 2025-01-09 04:27:31 +08:00
Updata rmvpe F0 predictor
This commit is contained in:
parent
c376983d40
commit
e0fc0e8328
10
README.md
10
README.md
@ -145,6 +145,8 @@ While the pretrained model typically does not pose copyright concerns, it is ess
|
||||
|
||||
#### **Optional(Select as Required)**
|
||||
|
||||
##### NSF-HIFIGAN
|
||||
|
||||
If you are using the `NSF-HIFIGAN enhancer` or `shallow diffusion`, you will need to download the pre-trained NSF-HIFIGAN model.
|
||||
|
||||
- Pre-trained NSF-HIFIGAN Vocoder: [nsf_hifigan_20221211.zip](https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip)
|
||||
@ -158,6 +160,13 @@ unzip -od pretrain/nsf_hifigan pretrain/nsf_hifigan_20221211.zip
|
||||
# URL: https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1
|
||||
```
|
||||
|
||||
##### RMVPE
|
||||
|
||||
If you are using the `rmvpe` F0 Predictor, you will need to download the pre-trained RMVPE model.
|
||||
|
||||
- download model at [rmvpe.pt](https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt)
|
||||
- Place it under the `pretrain` directory
|
||||
|
||||
## 📊 Dataset Preparation
|
||||
|
||||
Simply place the dataset in the `dataset_raw` directory with the following file structure:
|
||||
@ -285,6 +294,7 @@ crepe
|
||||
dio
|
||||
pm
|
||||
harvest
|
||||
rmvpe
|
||||
```
|
||||
|
||||
If the training set is too noisy,it is recommended to use `crepe` to handle f0
|
||||
|
@ -145,6 +145,8 @@ wget -P pretrain/ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/mai
|
||||
|
||||
#### **可选项(根据情况选择)**
|
||||
|
||||
##### NSF-HIFIGAN
|
||||
|
||||
如果使用`NSF-HIFIGAN 增强器`或`浅层扩散`的话,需要下载预训练的 NSF-HIFIGAN 模型,如果不需要可以不下载
|
||||
|
||||
+ 预训练的 NSF-HIFIGAN 声码器 :[nsf_hifigan_20221211.zip](https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip)
|
||||
@ -158,6 +160,14 @@ unzip -od pretrain/nsf_hifigan pretrain/nsf_hifigan_20221211.zip
|
||||
# 地址:https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1
|
||||
```
|
||||
|
||||
##### RMVPE
|
||||
|
||||
如果使用`rmvpe`F0预测器的话,需要下载预训练的 RMVPE 模型
|
||||
|
||||
+ 下载模型 [rmvpe.pt](https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt)
|
||||
+ 放在`pretrain`目录下
|
||||
|
||||
|
||||
## 📊 数据集准备
|
||||
|
||||
仅需要以以下文件结构将数据集放入 dataset_raw 目录即可
|
||||
@ -287,6 +297,7 @@ crepe
|
||||
dio
|
||||
pm
|
||||
harvest
|
||||
rmvpe
|
||||
```
|
||||
|
||||
如果训练集过于嘈杂,请使用 crepe 处理 f0
|
||||
|
106
modules/F0Predictor/RMVPEF0Predictor.py
Normal file
106
modules/F0Predictor/RMVPEF0Predictor.py
Normal file
@ -0,0 +1,106 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules.F0Predictor.F0Predictor import F0Predictor
|
||||
|
||||
from .rmvpe import RMVPE
|
||||
|
||||
|
||||
class RMVPEF0Predictor(F0Predictor):
|
||||
def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05):
|
||||
self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device)
|
||||
self.hop_length = hop_length
|
||||
self.f0_min = f0_min
|
||||
self.f0_max = f0_max
|
||||
if device is None:
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
else:
|
||||
self.device = device
|
||||
self.threshold = threshold
|
||||
self.sampling_rate = sampling_rate
|
||||
self.dtype = dtype
|
||||
|
||||
def repeat_expand(
|
||||
self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
|
||||
):
|
||||
ndim = content.ndim
|
||||
|
||||
if content.ndim == 1:
|
||||
content = content[None, None]
|
||||
elif content.ndim == 2:
|
||||
content = content[None]
|
||||
|
||||
assert content.ndim == 3
|
||||
|
||||
is_np = isinstance(content, np.ndarray)
|
||||
if is_np:
|
||||
content = torch.from_numpy(content)
|
||||
|
||||
results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
|
||||
|
||||
if is_np:
|
||||
results = results.numpy()
|
||||
|
||||
if ndim == 1:
|
||||
return results[0, 0]
|
||||
elif ndim == 2:
|
||||
return results[0]
|
||||
|
||||
def post_process(self, x, sampling_rate, f0, pad_to):
|
||||
if isinstance(f0, np.ndarray):
|
||||
f0 = torch.from_numpy(f0).float().to(x.device)
|
||||
|
||||
if pad_to is None:
|
||||
return f0
|
||||
|
||||
f0 = self.repeat_expand(f0, pad_to)
|
||||
|
||||
vuv_vector = torch.zeros_like(f0)
|
||||
vuv_vector[f0 > 0.0] = 1.0
|
||||
vuv_vector[f0 <= 0.0] = 0.0
|
||||
|
||||
# 去掉0频率, 并线性插值
|
||||
nzindex = torch.nonzero(f0).squeeze()
|
||||
f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
|
||||
time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
|
||||
time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
|
||||
|
||||
vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
|
||||
|
||||
if f0.shape[0] <= 0:
|
||||
return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy()
|
||||
if f0.shape[0] == 1:
|
||||
return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy()
|
||||
|
||||
# 大概可以用 torch 重写?
|
||||
f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
|
||||
#vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
|
||||
|
||||
return f0,vuv_vector.cpu().numpy()
|
||||
|
||||
def compute_f0(self,wav,p_len=None):
|
||||
x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
|
||||
if p_len is None:
|
||||
p_len = x.shape[0]//self.hop_length
|
||||
else:
|
||||
assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
||||
f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
|
||||
if torch.all(f0 == 0):
|
||||
rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
|
||||
return rtn,rtn
|
||||
return self.post_process(x,self.sampling_rate,f0,p_len)[0]
|
||||
|
||||
def compute_f0_uv(self,wav,p_len=None):
|
||||
x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
|
||||
if p_len is None:
|
||||
p_len = x.shape[0]//self.hop_length
|
||||
else:
|
||||
assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
||||
f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
|
||||
if torch.all(f0 == 0):
|
||||
rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
|
||||
return rtn,rtn
|
||||
return self.post_process(x,self.sampling_rate,f0,p_len)
|
10
modules/F0Predictor/rmvpe/__init__.py
Normal file
10
modules/F0Predictor/rmvpe/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
from .constants import * # noqa: F403
|
||||
from .inference import RMVPE # noqa: F401
|
||||
from .model import E2E, E2E0 # noqa: F401
|
||||
from .spec import MelSpectrogram # noqa: F401
|
||||
from .utils import ( # noqa: F401
|
||||
cycle,
|
||||
summary,
|
||||
to_local_average_cents,
|
||||
to_viterbi_cents,
|
||||
)
|
9
modules/F0Predictor/rmvpe/constants.py
Normal file
9
modules/F0Predictor/rmvpe/constants.py
Normal file
@ -0,0 +1,9 @@
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
N_CLASS = 360
|
||||
|
||||
N_MELS = 128
|
||||
MEL_FMIN = 30
|
||||
MEL_FMAX = SAMPLE_RATE // 2
|
||||
WINDOW_LENGTH = 1024
|
||||
CONST = 1997.3794084376191
|
190
modules/F0Predictor/rmvpe/deepunet.py
Normal file
190
modules/F0Predictor/rmvpe/deepunet.py
Normal file
@ -0,0 +1,190 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .constants import N_MELS
|
||||
|
||||
|
||||
class ConvBlockRes(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, momentum=0.01):
|
||||
super(ConvBlockRes, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(1, 1),
|
||||
padding=(1, 1),
|
||||
bias=False),
|
||||
nn.BatchNorm2d(out_channels, momentum=momentum),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.Conv2d(in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(1, 1),
|
||||
padding=(1, 1),
|
||||
bias=False),
|
||||
nn.BatchNorm2d(out_channels, momentum=momentum),
|
||||
nn.ReLU(),
|
||||
)
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
|
||||
self.is_shortcut = True
|
||||
else:
|
||||
self.is_shortcut = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.is_shortcut:
|
||||
return self.conv(x) + self.shortcut(x)
|
||||
else:
|
||||
return self.conv(x) + x
|
||||
|
||||
|
||||
class ResEncoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
|
||||
super(ResEncoderBlock, self).__init__()
|
||||
self.n_blocks = n_blocks
|
||||
self.conv = nn.ModuleList()
|
||||
self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
|
||||
for i in range(n_blocks - 1):
|
||||
self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
|
||||
self.kernel_size = kernel_size
|
||||
if self.kernel_size is not None:
|
||||
self.pool = nn.AvgPool2d(kernel_size=kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.n_blocks):
|
||||
x = self.conv[i](x)
|
||||
if self.kernel_size is not None:
|
||||
return x, self.pool(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class ResDecoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
|
||||
super(ResDecoderBlock, self).__init__()
|
||||
out_padding = (0, 1) if stride == (1, 2) else (1, 1)
|
||||
self.n_blocks = n_blocks
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.ConvTranspose2d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=stride,
|
||||
padding=(1, 1),
|
||||
output_padding=out_padding,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(out_channels, momentum=momentum),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.conv2 = nn.ModuleList()
|
||||
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
|
||||
for i in range(n_blocks-1):
|
||||
self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
|
||||
|
||||
def forward(self, x, concat_tensor):
|
||||
x = self.conv1(x)
|
||||
x = torch.cat((x, concat_tensor), dim=1)
|
||||
for i in range(self.n_blocks):
|
||||
x = self.conv2[i](x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
|
||||
super(Encoder, self).__init__()
|
||||
self.n_encoders = n_encoders
|
||||
self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
|
||||
self.layers = nn.ModuleList()
|
||||
self.latent_channels = []
|
||||
for i in range(self.n_encoders):
|
||||
self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
|
||||
self.latent_channels.append([out_channels, in_size])
|
||||
in_channels = out_channels
|
||||
out_channels *= 2
|
||||
in_size //= 2
|
||||
self.out_size = in_size
|
||||
self.out_channel = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
concat_tensors = []
|
||||
x = self.bn(x)
|
||||
for i in range(self.n_encoders):
|
||||
_, x = self.layers[i](x)
|
||||
concat_tensors.append(_)
|
||||
return x, concat_tensors
|
||||
|
||||
|
||||
class Intermediate(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
|
||||
super(Intermediate, self).__init__()
|
||||
self.n_inters = n_inters
|
||||
self.layers = nn.ModuleList()
|
||||
self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
|
||||
for i in range(self.n_inters-1):
|
||||
self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.n_inters):
|
||||
x = self.layers[i](x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
|
||||
super(Decoder, self).__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
self.n_decoders = n_decoders
|
||||
for i in range(self.n_decoders):
|
||||
out_channels = in_channels // 2
|
||||
self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
|
||||
in_channels = out_channels
|
||||
|
||||
def forward(self, x, concat_tensors):
|
||||
for i in range(self.n_decoders):
|
||||
x = self.layers[i](x, concat_tensors[-1-i])
|
||||
return x
|
||||
|
||||
|
||||
class TimbreFilter(nn.Module):
|
||||
def __init__(self, latent_rep_channels):
|
||||
super(TimbreFilter, self).__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
for latent_rep in latent_rep_channels:
|
||||
self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0]))
|
||||
|
||||
def forward(self, x_tensors):
|
||||
out_tensors = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
out_tensors.append(layer(x_tensors[i]))
|
||||
return out_tensors
|
||||
|
||||
|
||||
class DeepUnet(nn.Module):
|
||||
def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
|
||||
super(DeepUnet, self).__init__()
|
||||
self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
|
||||
self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
|
||||
self.tf = TimbreFilter(self.encoder.latent_channels)
|
||||
self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
|
||||
|
||||
def forward(self, x):
|
||||
x, concat_tensors = self.encoder(x)
|
||||
x = self.intermediate(x)
|
||||
concat_tensors = self.tf(concat_tensors)
|
||||
x = self.decoder(x, concat_tensors)
|
||||
return x
|
||||
|
||||
|
||||
class DeepUnet0(nn.Module):
|
||||
def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
|
||||
super(DeepUnet0, self).__init__()
|
||||
self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
|
||||
self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
|
||||
self.tf = TimbreFilter(self.encoder.latent_channels)
|
||||
self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
|
||||
|
||||
def forward(self, x):
|
||||
x, concat_tensors = self.encoder(x)
|
||||
x = self.intermediate(x)
|
||||
x = self.decoder(x, concat_tensors)
|
||||
return x
|
57
modules/F0Predictor/rmvpe/inference.py
Normal file
57
modules/F0Predictor/rmvpe/inference.py
Normal file
@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchaudio.transforms import Resample
|
||||
|
||||
from .constants import * # noqa: F403
|
||||
from .model import E2E0
|
||||
from .spec import MelSpectrogram
|
||||
from .utils import to_local_average_cents, to_viterbi_cents
|
||||
|
||||
|
||||
class RMVPE:
|
||||
def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=160):
|
||||
self.resample_kernel = {}
|
||||
if device is None:
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
else:
|
||||
self.device = device
|
||||
model = E2E0(4, 1, (2, 2))
|
||||
ckpt = torch.load(model_path)
|
||||
model.load_state_dict(ckpt['model'])
|
||||
model = model.to(dtype).to(self.device)
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.dtype = dtype
|
||||
self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405
|
||||
self.resample_kernel = {}
|
||||
|
||||
def mel2hidden(self, mel):
|
||||
with torch.no_grad():
|
||||
n_frames = mel.shape[-1]
|
||||
mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect')
|
||||
hidden = self.model(mel)
|
||||
return hidden[:, :n_frames]
|
||||
|
||||
def decode(self, hidden, thred=0.03, use_viterbi=False):
|
||||
if use_viterbi:
|
||||
cents_pred = to_viterbi_cents(hidden, thred=thred)
|
||||
else:
|
||||
cents_pred = to_local_average_cents(hidden, thred=thred)
|
||||
f0 = torch.Tensor([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]).to(self.device)
|
||||
return f0
|
||||
|
||||
def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=False):
|
||||
audio = audio.unsqueeze(0).to(self.dtype).to(self.device)
|
||||
if sample_rate == 16000:
|
||||
audio_res = audio
|
||||
else:
|
||||
key_str = str(sample_rate)
|
||||
if key_str not in self.resample_kernel:
|
||||
self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
|
||||
self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
|
||||
audio_res = self.resample_kernel[key_str](audio)
|
||||
mel_extractor = self.mel_extractor.to(self.device)
|
||||
mel = mel_extractor(audio_res, center=True).to(self.dtype)
|
||||
hidden = self.mel2hidden(mel)
|
||||
f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi)
|
||||
return f0
|
67
modules/F0Predictor/rmvpe/model.py
Normal file
67
modules/F0Predictor/rmvpe/model.py
Normal file
@ -0,0 +1,67 @@
|
||||
from torch import nn
|
||||
|
||||
from .constants import * # noqa: F403
|
||||
from .deepunet import DeepUnet, DeepUnet0
|
||||
from .seq import BiGRU
|
||||
from .spec import MelSpectrogram
|
||||
|
||||
|
||||
class E2E(nn.Module):
|
||||
def __init__(self, hop_length, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
|
||||
en_out_channels=16):
|
||||
super(E2E, self).__init__()
|
||||
self.mel = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405
|
||||
self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
|
||||
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
|
||||
if n_gru:
|
||||
self.fc = nn.Sequential(
|
||||
BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405
|
||||
nn.Linear(512, N_CLASS), # noqa: F405
|
||||
nn.Dropout(0.25),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
else:
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405
|
||||
nn.Dropout(0.25),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
mel = self.mel(x.reshape(-1, x.shape[-1])).transpose(-1, -2).unsqueeze(1)
|
||||
x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
|
||||
# x = self.fc(x)
|
||||
hidden_vec = 0
|
||||
if len(self.fc) == 4:
|
||||
for i in range(len(self.fc)):
|
||||
x = self.fc[i](x)
|
||||
if i == 0:
|
||||
hidden_vec = x
|
||||
return hidden_vec, x
|
||||
|
||||
|
||||
class E2E0(nn.Module):
|
||||
def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
|
||||
en_out_channels=16):
|
||||
super(E2E0, self).__init__()
|
||||
self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
|
||||
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
|
||||
if n_gru:
|
||||
self.fc = nn.Sequential(
|
||||
BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405
|
||||
nn.Linear(512, N_CLASS), # noqa: F405
|
||||
nn.Dropout(0.25),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
else:
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405
|
||||
nn.Dropout(0.25),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, mel):
|
||||
mel = mel.transpose(-1, -2).unsqueeze(1)
|
||||
x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
|
||||
x = self.fc(x)
|
||||
return x
|
20
modules/F0Predictor/rmvpe/seq.py
Normal file
20
modules/F0Predictor/rmvpe/seq.py
Normal file
@ -0,0 +1,20 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class BiGRU(nn.Module):
|
||||
def __init__(self, input_features, hidden_features, num_layers):
|
||||
super(BiGRU, self).__init__()
|
||||
self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.gru(x)[0]
|
||||
|
||||
|
||||
class BiLSTM(nn.Module):
|
||||
def __init__(self, input_features, hidden_features, num_layers):
|
||||
super(BiLSTM, self).__init__()
|
||||
self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lstm(x)[0]
|
||||
|
67
modules/F0Predictor/rmvpe/spec.py
Normal file
67
modules/F0Predictor/rmvpe/spec.py
Normal file
@ -0,0 +1,67 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from librosa.filters import mel
|
||||
|
||||
|
||||
class MelSpectrogram(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mel_channels,
|
||||
sampling_rate,
|
||||
win_length,
|
||||
hop_length,
|
||||
n_fft=None,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clamp = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
n_fft = win_length if n_fft is None else n_fft
|
||||
self.hann_window = {}
|
||||
mel_basis = mel(
|
||||
sr=sampling_rate,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mel_channels,
|
||||
fmin=mel_fmin,
|
||||
fmax=mel_fmax,
|
||||
htk=True)
|
||||
mel_basis = torch.from_numpy(mel_basis).float()
|
||||
self.register_buffer("mel_basis", mel_basis)
|
||||
self.n_fft = win_length if n_fft is None else n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.sampling_rate = sampling_rate
|
||||
self.n_mel_channels = n_mel_channels
|
||||
self.clamp = clamp
|
||||
|
||||
def forward(self, audio, keyshift=0, speed=1, center=True):
|
||||
factor = 2 ** (keyshift / 12)
|
||||
n_fft_new = int(np.round(self.n_fft * factor))
|
||||
win_length_new = int(np.round(self.win_length * factor))
|
||||
hop_length_new = int(np.round(self.hop_length * speed))
|
||||
|
||||
keyshift_key = str(keyshift)+'_'+str(audio.device)
|
||||
if keyshift_key not in self.hann_window:
|
||||
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
|
||||
|
||||
fft = torch.stft(
|
||||
audio,
|
||||
n_fft=n_fft_new,
|
||||
hop_length=hop_length_new,
|
||||
win_length=win_length_new,
|
||||
window=self.hann_window[keyshift_key],
|
||||
center=center,
|
||||
return_complex=True)
|
||||
magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
|
||||
|
||||
if keyshift != 0:
|
||||
size = self.n_fft // 2 + 1
|
||||
resize = magnitude.size(1)
|
||||
if resize < size:
|
||||
magnitude = F.pad(magnitude, (0, 0, 0, size-resize))
|
||||
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
|
||||
|
||||
mel_output = torch.matmul(self.mel_basis, magnitude)
|
||||
log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
|
||||
return log_mel_spec
|
107
modules/F0Predictor/rmvpe/utils.py
Normal file
107
modules/F0Predictor/rmvpe/utils.py
Normal file
@ -0,0 +1,107 @@
|
||||
import sys
|
||||
from functools import reduce
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.modules.module import _addindent
|
||||
|
||||
from .constants import * # noqa: F403
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
while True:
|
||||
for item in iterable:
|
||||
yield item
|
||||
|
||||
|
||||
def summary(model, file=sys.stdout):
|
||||
def repr(model):
|
||||
# We treat the extra repr like the sub-module, one item per line
|
||||
extra_lines = []
|
||||
extra_repr = model.extra_repr()
|
||||
# empty string will be split into list ['']
|
||||
if extra_repr:
|
||||
extra_lines = extra_repr.split('\n')
|
||||
child_lines = []
|
||||
total_params = 0
|
||||
for key, module in model._modules.items():
|
||||
mod_str, num_params = repr(module)
|
||||
mod_str = _addindent(mod_str, 2)
|
||||
child_lines.append('(' + key + '): ' + mod_str)
|
||||
total_params += num_params
|
||||
lines = extra_lines + child_lines
|
||||
|
||||
for name, p in model._parameters.items():
|
||||
if hasattr(p, 'shape'):
|
||||
total_params += reduce(lambda x, y: x * y, p.shape)
|
||||
|
||||
main_str = model._get_name() + '('
|
||||
if lines:
|
||||
# simple one-liner info, which most builtin Modules will use
|
||||
if len(extra_lines) == 1 and not child_lines:
|
||||
main_str += extra_lines[0]
|
||||
else:
|
||||
main_str += '\n ' + '\n '.join(lines) + '\n'
|
||||
|
||||
main_str += ')'
|
||||
if file is sys.stdout:
|
||||
main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
|
||||
else:
|
||||
main_str += ', {:,} params'.format(total_params)
|
||||
return main_str, total_params
|
||||
|
||||
string, count = repr(model)
|
||||
if file is not None:
|
||||
if isinstance(file, str):
|
||||
file = open(file, 'w')
|
||||
print(string, file=file)
|
||||
file.flush()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def to_local_average_cents(salience, center=None, thred=0.05):
|
||||
"""
|
||||
find the weighted average cents near the argmax bin
|
||||
"""
|
||||
|
||||
if not hasattr(to_local_average_cents, 'cents_mapping'):
|
||||
# the bin number-to-cents mapping
|
||||
to_local_average_cents.cents_mapping = (
|
||||
20 * torch.arange(N_CLASS) + CONST).to(salience.device) # noqa: F405
|
||||
|
||||
if salience.ndim == 1:
|
||||
if center is None:
|
||||
center = int(torch.argmax(salience))
|
||||
start = max(0, center - 4)
|
||||
end = min(len(salience), center + 5)
|
||||
salience = salience[start:end]
|
||||
product_sum = torch.sum(
|
||||
salience * to_local_average_cents.cents_mapping[start:end])
|
||||
weight_sum = torch.sum(salience)
|
||||
return product_sum / weight_sum if torch.max(salience) > thred else 0
|
||||
if salience.ndim == 2:
|
||||
return torch.Tensor([to_local_average_cents(salience[i, :], None, thred) for i in
|
||||
range(salience.shape[0])]).to(salience.device)
|
||||
|
||||
raise Exception("label should be either 1d or 2d ndarray")
|
||||
|
||||
def to_viterbi_cents(salience, thred=0.05):
|
||||
# Create viterbi transition matrix
|
||||
if not hasattr(to_viterbi_cents, 'transition'):
|
||||
xx, yy = torch.meshgrid(range(N_CLASS), range(N_CLASS)) # noqa: F405
|
||||
transition = torch.maximum(30 - abs(xx - yy), 0)
|
||||
transition = transition / transition.sum(axis=1, keepdims=True)
|
||||
to_viterbi_cents.transition = transition
|
||||
|
||||
# Convert to probability
|
||||
prob = salience.T
|
||||
prob = prob / prob.sum(axis=0)
|
||||
|
||||
# Perform viterbi decoding
|
||||
path = librosa.sequence.viterbi(prob.detach().cpu().numpy(), to_viterbi_cents.transition).astype(np.int64)
|
||||
|
||||
return torch.Tensor([to_local_average_cents(salience[i, :], path[i], thred) for i in
|
||||
range(len(path))]).to(salience.device)
|
||||
|
5
utils.py
5
utils.py
@ -96,7 +96,10 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
|
||||
f0_predictor_object = HarvestF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
|
||||
elif f0_predictor == "dio":
|
||||
from modules.F0Predictor.DioF0Predictor import DioF0Predictor
|
||||
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
|
||||
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
|
||||
elif f0_predictor == "rmvpe":
|
||||
from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
|
||||
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float16 ,device=kargs["device"],threshold=kargs["threshold"])
|
||||
else:
|
||||
raise Exception("Unknown f0 predictor")
|
||||
return f0_predictor_object
|
||||
|
Loading…
Reference in New Issue
Block a user