From fc8336fffd40c39bdb225c1b041ab4dd15fac4e9 Mon Sep 17 00:00:00 2001 From: ylzz1997 Date: Wed, 2 Aug 2023 01:09:46 +0800 Subject: [PATCH] Updata VITS2 part (Transformer Flow) --- configs_template/config_template.json | 4 +- configs_template/config_tiny_template.json | 4 +- models.py | 48 ++++++++++++++++++++- modules/attentions.py | 22 ++++++++-- modules/modules.py | 50 ++++++++++++++++++++++ 5 files changed, 122 insertions(+), 6 deletions(-) diff --git a/configs_template/config_template.json b/configs_template/config_template.json index 70a74a6..4b1b323 100644 --- a/configs_template/config_template.json +++ b/configs_template/config_template.json @@ -54,6 +54,7 @@ "upsample_initial_channel": 512, "upsample_kernel_sizes": [16,16, 4, 4, 4], "n_layers_q": 3, + "n_layers_trans_flow": 3, "n_flow_layer": 4, "use_spectral_norm": false, "gin_channels": 768, @@ -65,7 +66,8 @@ "vol_embedding":false, "use_depthwise_conv":false, "flow_share_parameter": false, - "use_automatic_f0_prediction": true + "use_automatic_f0_prediction": true, + "use_transformer_flow": false }, "spk": { "nyaru": 0, diff --git a/configs_template/config_tiny_template.json b/configs_template/config_tiny_template.json index 4865ec5..d0a4381 100644 --- a/configs_template/config_tiny_template.json +++ b/configs_template/config_tiny_template.json @@ -54,6 +54,7 @@ "upsample_initial_channel": 400, "upsample_kernel_sizes": [16,16, 4, 4, 4], "n_layers_q": 3, + "n_layers_trans_flow": 3, "n_flow_layer": 4, "use_spectral_norm": false, "gin_channels": 768, @@ -65,7 +66,8 @@ "vol_embedding":false, "use_depthwise_conv":true, "flow_share_parameter": true, - "use_automatic_f0_prediction": true + "use_automatic_f0_prediction": true, + "use_transformer_flow": false }, "spk": { "nyaru": 0, diff --git a/models.py b/models.py index 6974980..24338fa 100644 --- a/models.py +++ b/models.py @@ -51,6 +51,46 @@ class ResidualCouplingBlock(nn.Module): x = flow(x, x_mask, g=g, reverse=reverse) return x +class TransformerCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + share_parameter=False + ): + + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + + self.wn = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = self.gin_channels) if share_parameter else None + + for i in range(n_flows): + self.flows.append( + modules.TransformerCouplingLayer(channels, hidden_channels, kernel_size, n_layers, n_heads, p_dropout, filter_channels, mean_only=True, wn_sharing_parameter=self.wn, gin_channels = self.gin_channels)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + class Encoder(nn.Module): def __init__(self, @@ -327,6 +367,8 @@ class SynthesizerTrn(nn.Module): use_automatic_f0_prediction = True, flow_share_parameter = False, n_flow_layer = 4, + n_layers_trans_flow = 3, + use_transformer_flow = False, **kwargs): super().__init__() @@ -351,6 +393,7 @@ class SynthesizerTrn(nn.Module): self.emb_g = nn.Embedding(n_speakers, gin_channels) self.use_depthwise_conv = use_depthwise_conv self.use_automatic_f0_prediction = use_automatic_f0_prediction + self.n_layers_trans_flow = n_layers_trans_flow if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels) @@ -392,7 +435,10 @@ class SynthesizerTrn(nn.Module): self.dec = Generator(h=hps) self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter) + if use_transformer_flow: + self.flow = TransformerCouplingBlock(inter_channels, hidden_channels, filter_channels, n_heads, n_layers_trans_flow, 5, p_dropout, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter) + else: + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter) if self.use_automatic_f0_prediction: self.f0_decoder = F0Decoder( 1, diff --git a/modules/attentions.py b/modules/attentions.py index 9086e0e..f9d75bc 100644 --- a/modules/attentions.py +++ b/modules/attentions.py @@ -5,12 +5,13 @@ from torch import nn from torch.nn import functional as F import modules.commons as commons +from modules.DSConv import weight_norm_modules from modules.modules import LayerNorm class FFT(nn.Module): def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0., - proximal_bias=False, proximal_init=True, **kwargs): + proximal_bias=False, proximal_init=True, isflow = False, **kwargs): super().__init__() self.hidden_channels = hidden_channels self.filter_channels = filter_channels @@ -20,7 +21,11 @@ class FFT(nn.Module): self.p_dropout = p_dropout self.proximal_bias = proximal_bias self.proximal_init = proximal_init - + if isflow: + cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1) + self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) + self.cond_layer = weight_norm_modules(cond_layer, name='weight') + self.gin_channels = kwargs["gin_channels"] self.drop = nn.Dropout(p_dropout) self.self_attn_layers = nn.ModuleList() self.norm_layers_0 = nn.ModuleList() @@ -35,14 +40,25 @@ class FFT(nn.Module): FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) self.norm_layers_1.append(LayerNorm(hidden_channels)) - def forward(self, x, x_mask): + def forward(self, x, x_mask, g = None): """ x: decoder input h: encoder output """ + if g is not None: + g = self.cond_layer(g) + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) x = x * x_mask for i in range(self.n_layers): + if g is not None: + x = self.cond_pre(x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] + x = commons.fused_add_tanh_sigmoid_multiply( + x, + g_l, + torch.IntTensor([self.hidden_channels])) y = self.self_attn_layers[i](x, x, self_attn_mask) y = self.drop(y) x = self.norm_layers_0[i](x + y) diff --git a/modules/modules.py b/modules/modules.py index 2b9ad90..a622d4f 100644 --- a/modules/modules.py +++ b/modules/modules.py @@ -2,6 +2,7 @@ import torch from torch import nn from torch.nn import functional as F +import modules.attentions as attentions import modules.commons as commons from modules.commons import get_padding, init_weights from modules.DSConv import ( @@ -304,3 +305,52 @@ class ResidualCouplingLayer(nn.Module): x1 = (x1 - m) * torch.exp(-logs) * x_mask x = torch.cat([x0, x1], 1) return x + +class TransformerCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout=0, + filter_channels=0, + mean_only=False, + wn_sharing_parameter=None, + gin_channels = 0 + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels]*2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1,2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x