2023-03-10 18:11:04 +08:00
import argparse
2023-06-26 14:57:53 +08:00
import json
import os
2023-03-10 18:11:04 +08:00
import re
2023-06-26 14:57:53 +08:00
import wave
from random import shuffle
2023-03-10 18:11:04 +08:00
2023-07-22 21:56:02 +08:00
from loguru import logger
2023-03-10 18:11:04 +08:00
from tqdm import tqdm
2023-05-16 13:17:51 +08:00
import diffusion . logger . utils as du
2023-03-10 18:11:04 +08:00
# Matches file names made up solely of ASCII letters, digits, '.', '_' and '/'.
pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')


def get_wav_duration(file_path):
    """Return the playback length of a WAV file in seconds.

    The length is the frame count divided by the sample rate, both taken
    from the file header via the standard-library ``wave`` module.

    Args:
        file_path: Path of the WAV file to inspect.

    Returns:
        Duration in seconds as a float.

    Raises:
        Whatever exception opening/reading the file or computing the duration
        raises; the offending path is logged first, then the same exception
        is re-raised.
    """
    try:
        with wave.open(file_path, 'rb') as reader:
            frame_count = reader.getnframes()   # total number of audio frames
            sample_rate = reader.getframerate()  # frames per second
        duration = frame_count / float(sample_rate)
    except Exception as err:
        # Record which file failed so a single bad clip in a large dataset is easy to locate.
        logger.error(f"Reading {file_path}")
        raise err
    return duration
# Clips shorter than this many seconds are left out of both file lists.
_MIN_WAV_SECONDS = 0.3
# Number of shuffled clips per speaker that go to the validation list.
_VAL_CLIPS_PER_SPEAKER = 2

# speech encoder -> (feature dimension, whether model.filter_channels is set to it).
# vec256l9 / hubertsoft only set ssl_dim and gin_channels, leaving
# filter_channels at the template's value.
_ENCODER_DIMS = {
    "vec768l12": (768, True),
    "dphubert": (768, True),
    "wavlmbase+": (768, True),
    "vec256l9": (256, False),
    "hubertsoft": (256, False),
    "whisper-ppg": (1024, True),
    "cnhubertlarge": (1024, True),
    "whisper-ppg-large": (1280, True),
}


def _parse_args():
    """Parse the command-line options of this preprocessing script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
    parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
    parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir")
    parser.add_argument("--speech_encoder", type=str, default="vec768l12", help="choice a speech encoder|'vec768l12','vec256l9','hubertsoft','whisper-ppg','cnhubertlarge','dphubert','whisper-ppg-large','wavlmbase+'")
    parser.add_argument("--vol_aug", action="store_true", help="Whether to use volume embedding and volume augmentation")
    parser.add_argument("--tiny", action="store_true", help="Whether to train sovits tiny")
    return parser.parse_args()


def _load_json(path):
    """Read and return the JSON document at *path*, closing the file afterwards."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _list_speakers(source_dir):
    """Return the speaker sub-directories of *source_dir*, sorted by name.

    Plain files and hidden entries (names starting with '.') are ignored so
    they are neither treated as speakers nor passed to ``os.listdir``.
    Sorting makes the speaker IDs derived from this list reproducible.
    """
    return sorted(
        name
        for name in os.listdir(source_dir)
        if not name.startswith(".") and os.path.isdir(os.path.join(source_dir, name))
    )


def _collect_speaker_wavs(source_dir, speaker):
    """Return the usable ``.wav`` paths for one speaker directory.

    Hidden files and files not ending in ``.wav`` are ignored; clips shorter
    than ``_MIN_WAV_SECONDS`` are skipped with an info log; file names that
    do not match ``pattern`` only produce a warning and are still kept.
    """
    wavs = []
    for file_name in os.listdir(os.path.join(source_dir, speaker)):
        if file_name.startswith(".") or not file_name.endswith(".wav"):
            continue
        # Components are joined with '/' regardless of the host OS.
        file_path = "/".join([source_dir, speaker, file_name])
        if not pattern.match(file_name):
            logger.warning("Detected non-ASCII or special-character file name: " + file_path)
        if get_wav_duration(file_path) < _MIN_WAV_SECONDS:
            logger.info("Skip too short audio: " + file_path)
            continue
        wavs.append(file_path)
    return wavs


def _apply_encoder_dims(config, d_config, speech_encoder):
    """Set the feature-size fields of both configs for *speech_encoder*.

    For an encoder missing from ``_ENCODER_DIMS`` a warning is logged and the
    template values are left untouched.
    """
    spec = _ENCODER_DIMS.get(speech_encoder)
    if spec is None:
        logger.warning(
            f"Unknown speech encoder {speech_encoder!r}; "
            "keeping the template's ssl_dim/gin_channels/encoder_out_channels"
        )
        return
    dim, sets_filter_channels = spec
    model = config["model"]
    model["ssl_dim"] = dim
    if sets_filter_channels:
        model["filter_channels"] = dim
    model["gin_channels"] = dim
    d_config["data"]["encoder_out_channels"] = dim


def _write_filelist(path, wav_paths):
    """Write one path per line to *path* as UTF-8, creating its directory if needed."""
    logger.info("Writing " + path)
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Explicit UTF-8 so non-ASCII paths do not depend on the platform's default encoding.
    with open(path, "w", encoding="utf-8") as f:
        for wav_path in tqdm(wav_paths):
            f.write(wav_path + "\n")


def main():
    """Build the train/val file lists and write configs/config.json and configs/diffusion.yaml."""
    args = _parse_args()

    # Load both templates first so a missing template fails before anything is written.
    template_path = (
        "configs_template/config_tiny_template.json"
        if args.tiny
        else "configs_template/config_template.json"
    )
    config_template = _load_json(template_path)
    d_config_template = du.load_config("configs_template/diffusion_template.yaml")

    train = []
    val = []
    spk_dict = {}
    for speaker in tqdm(_list_speakers(args.source_dir)):
        spk_dict[speaker] = len(spk_dict)  # speaker IDs are 0, 1, 2, ... in sorted order
        wavs = _collect_speaker_wavs(args.source_dir, speaker)
        if len(wavs) <= _VAL_CLIPS_PER_SPEAKER:
            logger.warning(
                f"Speaker {speaker!r} has only {len(wavs)} usable clip(s); "
                "all of them go to the validation list and none to training"
            )
        shuffle(wavs)
        train += wavs[_VAL_CLIPS_PER_SPEAKER:]
        val += wavs[:_VAL_CLIPS_PER_SPEAKER]
    shuffle(train)
    shuffle(val)
    n_speakers = len(spk_dict)

    _write_filelist(args.train_list, train)
    _write_filelist(args.val_list, val)

    d_config_template["model"]["n_spk"] = n_speakers
    d_config_template["data"]["encoder"] = args.speech_encoder
    d_config_template["spk"] = spk_dict

    config_template["spk"] = spk_dict
    config_template["model"]["n_speakers"] = n_speakers
    config_template["model"]["speech_encoder"] = args.speech_encoder

    _apply_encoder_dims(config_template, d_config_template, args.speech_encoder)

    if args.vol_aug:
        config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True

    # Applied after the encoder dims, so --tiny always wins for filter_channels.
    if args.tiny:
        config_template["model"]["filter_channels"] = 512

    logger.info("Writing to configs/config.json")
    os.makedirs("configs", exist_ok=True)
    with open("configs/config.json", "w", encoding="utf-8") as f:
        json.dump(config_template, f, indent=2)
    logger.info("Writing to configs/diffusion.yaml")
    du.save_config("configs/diffusion.yaml", d_config_template)


if __name__ == "__main__":
    main()