Merge pull request #365 from svc-develop-team/4.1-Stable

To Latest
This commit is contained in:
YuriHead 2023-08-02 00:43:07 +08:00 committed by GitHub
commit 39b0befef5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 35 deletions

View File

@ -136,7 +136,7 @@ class FCPE(nn.Module):
B, N, _ = y.size() B, N, _ = y.size()
ci = self.cent_table[None, None, :].expand(B, N, -1) ci = self.cent_table[None, None, :].expand(B, N, -1)
confident, max_index = torch.max(y, dim=-1, keepdim=True) confident, max_index = torch.max(y, dim=-1, keepdim=True)
local_argmax_index = torch.arange(0,8).to(max_index.device) + (max_index - 4) local_argmax_index = torch.arange(0,9).to(max_index.device) + (max_index - 4)
local_argmax_index[local_argmax_index<0] = 0 local_argmax_index[local_argmax_index<0] = 0
local_argmax_index[local_argmax_index>=self.n_out] = self.n_out - 1 local_argmax_index[local_argmax_index>=self.n_out] = self.n_out - 1
ci_l = torch.gather(ci,-1,local_argmax_index) ci_l = torch.gather(ci,-1,local_argmax_index)

View File

@ -13,14 +13,17 @@ import diffusion.logger.utils as du
pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$') pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
def get_wav_duration(file_path): def get_wav_duration(file_path):
with wave.open(file_path, 'rb') as wav_file: try:
# 获取音频帧数 with wave.open(file_path, 'rb') as wav_file:
n_frames = wav_file.getnframes() # 获取音频帧数
# 获取采样率 n_frames = wav_file.getnframes()
framerate = wav_file.getframerate() # 获取采样率
# 计算时长(秒) framerate = wav_file.getframerate()
duration = n_frames / float(framerate) # 计算时长(秒)
return duration return n_frames / float(framerate)
except Exception as e:
logger.error(f"Reading {file_path}")
raise e
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -42,32 +45,39 @@ if __name__ == "__main__":
for speaker in tqdm(os.listdir(args.source_dir)): for speaker in tqdm(os.listdir(args.source_dir)):
spk_dict[speaker] = spk_id spk_dict[speaker] = spk_id
spk_id += 1 spk_id += 1
wavs = ["/".join([args.source_dir, speaker, i]) for i in os.listdir(os.path.join(args.source_dir, speaker))] wavs = []
new_wavs = []
for file in wavs: for file_name in os.listdir(os.path.join(args.source_dir, speaker)):
if not file.endswith("wav"): if not file_name.endswith("wav"):
continue continue
if not pattern.match(file): if file_name.startswith("."):
logger.warning(f"文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
if get_wav_duration(file) < 0.3:
logger.info("Skip too short audio:" + file)
continue continue
new_wavs.append(file)
wavs = new_wavs file_path = "/".join([args.source_dir, speaker, file_name])
if not pattern.match(file_name):
logger.warning("Detected non-ASCII file name: " + file_path)
if get_wav_duration(file_path) < 0.3:
logger.info("Skip too short audio: " + file_path)
continue
wavs.append(file_path)
shuffle(wavs) shuffle(wavs)
train += wavs[2:] train += wavs[2:]
val += wavs[:2] val += wavs[:2]
shuffle(train) shuffle(train)
shuffle(val) shuffle(val)
logger.info("Writing" + args.train_list) logger.info("Writing " + args.train_list)
with open(args.train_list, "w") as f: with open(args.train_list, "w") as f:
for fname in tqdm(train): for fname in tqdm(train):
wavpath = fname wavpath = fname
f.write(wavpath + "\n") f.write(wavpath + "\n")
logger.info("Writing" + args.val_list) logger.info("Writing " + args.val_list)
with open(args.val_list, "w") as f: with open(args.val_list, "w") as f:
for fname in tqdm(val): for fname in tqdm(val):
wavpath = fname wavpath = fname

View File

@ -113,7 +113,7 @@ def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu")
logger.info(f"Rank {rank} uses device {device}") logger.info(f"Rank {rank} uses device {device}")
hmodel = utils.get_speech_encoder(speech_encoder, device=device) hmodel = utils.get_speech_encoder(speech_encoder, device=device)
logger.info(f"Loaded speech encoder for rank {rank}") logger.info(f"Loaded speech encoder for rank {rank}")
for filename in tqdm(file_chunk): for filename in tqdm(file_chunk, position = rank):
process_one(filename, hmodel, f0p, device, diff, mel_extractor) process_one(filename, hmodel, f0p, device, diff, mel_extractor)
def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device): def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
@ -124,7 +124,7 @@ def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device)
end = int((i + 1) * len(filenames) / num_processes) end = int((i + 1) * len(filenames) / num_processes)
file_chunk = filenames[start:end] file_chunk = filenames[start:end]
tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device)) tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device))
for task in tqdm(tasks): for task in tqdm(tasks, position = 0):
task.result() task.result()
if __name__ == "__main__": if __name__ == "__main__":
@ -149,10 +149,10 @@ if __name__ == "__main__":
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(speech_encoder) print(speech_encoder)
logger.info("Using device: ", device) logger.info("Using device: " + str(device))
logger.info("Using SpeechEncoder: " + speech_encoder) logger.info("Using SpeechEncoder: " + speech_encoder)
logger.info("Using extractor: " + f0p) logger.info("Using extractor: " + f0p)
logger.info("Using diff Mode: " + str( args.use_diff)) logger.info("Using diff Mode: " + str(args.use_diff))
if args.use_diff: if args.use_diff:
print("use_diff") print("use_diff")

0
pretrain/__init__.py Normal file
View File

View File

@ -78,10 +78,9 @@
"#@markdown\n", "#@markdown\n",
"\n", "\n",
"!git clone https://github.com/svc-develop-team/so-vits-svc -b 4.1-Stable\n", "!git clone https://github.com/svc-develop-team/so-vits-svc -b 4.1-Stable\n",
"%pip uninstall -y torchdata torchtext\n", "%cd /content/so-vits-svc\n",
"%pip install --upgrade pip setuptools numpy numba\n", "%pip install --upgrade pip setuptools\n",
"%pip install pyworld praat-parselmouth fairseq tensorboardX torchcrepe librosa==0.9.1 pyyaml pynvml pyloudnorm faiss-gpu\n", "%pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118\n",
"%pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118\n",
"exit()" "exit()"
] ]
}, },
@ -163,8 +162,8 @@
"#@markdown Although the pretrained model generally does not cause any copyright problems, please pay attention to it. For example, ask the author in advance, or the author has indicated the feasible use in the description clearly.\n", "#@markdown Although the pretrained model generally does not cause any copyright problems, please pay attention to it. For example, ask the author in advance, or the author has indicated the feasible use in the description clearly.\n",
"\n", "\n",
"download_pretrained_model = True #@param {type:\"boolean\"}\n", "download_pretrained_model = True #@param {type:\"boolean\"}\n",
"D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\"] {allow-input: true}\n", "D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_D_320000.pth\"] {allow-input: true}\n",
"G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\"] {allow-input: true}\n", "G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_G_320000.pth\"] {allow-input: true}\n",
"\n", "\n",
"download_pretrained_diffusion_model = True #@param {type:\"boolean\"}\n", "download_pretrained_diffusion_model = True #@param {type:\"boolean\"}\n",
"diff_model_URL = \"https://huggingface.co/datasets/ms903/Diff-SVC-refactor-pre-trained-model/resolve/main/fix_pitch_add_vctk_600k/model_0.pt\" #@param {type:\"string\"}\n", "diff_model_URL = \"https://huggingface.co/datasets/ms903/Diff-SVC-refactor-pre-trained-model/resolve/main/fix_pitch_add_vctk_600k/model_0.pt\" #@param {type:\"string\"}\n",
@ -317,13 +316,17 @@
"#@markdown\n", "#@markdown\n",
"%cd /content/so-vits-svc\n", "%cd /content/so-vits-svc\n",
"\n", "\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n", "f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n",
"use_diff = True #@param {type:\"boolean\"}\n", "use_diff = True #@param {type:\"boolean\"}\n",
"\n", "\n",
"import os\n", "import os\n",
"if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n", "if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n", " !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
"\n", "\n",
"if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n",
"\n",
"\n",
"diff_param = \"\"\n", "diff_param = \"\"\n",
"if use_diff:\n", "if use_diff:\n",
" diff_param = \"--use_diff\"\n", " diff_param = \"--use_diff\"\n",
@ -624,7 +627,7 @@
"if auto_predict_f0:\n", "if auto_predict_f0:\n",
" apf = \" -a \"\n", " apf = \" -a \"\n",
"\n", "\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n", "f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n",
"\n", "\n",
"enhance = False #@param {type:\"boolean\"}\n", "enhance = False #@param {type:\"boolean\"}\n",
"ehc = \"\"\n", "ehc = \"\"\n",
@ -644,6 +647,9 @@
"if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n", "if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n", " !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
"\n", "\n",
"if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n",
"\n",
"if not os.path.exists(output):\n", "if not os.path.exists(output):\n",
" !curl -L {url} -o {output}\n", " !curl -L {url} -o {output}\n",
"\n", "\n",