Mirror of https://github.com/svc-develop-team/so-vits-svc.git, commit 39b0befef5
@@ -136,7 +136,7 @@ class FCPE(nn.Module):
         B, N, _ = y.size()
         ci = self.cent_table[None, None, :].expand(B, N, -1)
         confident, max_index = torch.max(y, dim=-1, keepdim=True)
-        local_argmax_index = torch.arange(0,8).to(max_index.device) + (max_index - 4)
+        local_argmax_index = torch.arange(0,9).to(max_index.device) + (max_index - 4)
         local_argmax_index[local_argmax_index<0] = 0
         local_argmax_index[local_argmax_index>=self.n_out] = self.n_out - 1
         ci_l = torch.gather(ci,-1,local_argmax_index)
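Note on the hunk above: with torch.arange(0, 9) the gathered window is nine bins, i.e. the argmax bin plus four bins on each side, whereas arange(0, 8) gave an asymmetric eight-bin window. A minimal sketch of that local decoding step, with hypothetical sizes, a made-up cent table, and a weighted-average decode at the end that is an assumption rather than the repository's exact code:

    import torch

    B, N, n_out = 2, 5, 360                              # hypothetical batch, frames, cent bins
    y = torch.rand(B, N, n_out)                          # stand-in for the model's per-bin scores
    cent_table = torch.linspace(0.0, 7180.0, n_out)      # made-up cent grid

    ci = cent_table[None, None, :].expand(B, N, -1)
    confident, max_index = torch.max(y, dim=-1, keepdim=True)

    # nine indices centred on the argmax: argmax - 4 .. argmax + 4
    local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
    local_argmax_index = local_argmax_index.clamp(0, n_out - 1)  # same effect as the two boundary assignments

    ci_l = torch.gather(ci, -1, local_argmax_index)      # cents of the nine local bins
    y_l = torch.gather(y, -1, local_argmax_index)        # their scores
    cent = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
    print(cent.shape)                                    # torch.Size([2, 5, 1])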
@@ -13,14 +13,17 @@ import diffusion.logger.utils as du
 pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
 
 def get_wav_duration(file_path):
-    with wave.open(file_path, 'rb') as wav_file:
-        # get the number of audio frames
-        n_frames = wav_file.getnframes()
-        # get the sample rate
-        framerate = wav_file.getframerate()
-        # compute the duration in seconds
-        duration = n_frames / float(framerate)
-    return duration
+    try:
+        with wave.open(file_path, 'rb') as wav_file:
+            # get the number of audio frames
+            n_frames = wav_file.getnframes()
+            # get the sample rate
+            framerate = wav_file.getframerate()
+            # compute the duration in seconds
+            return n_frames / float(framerate)
+    except Exception as e:
+        logger.error(f"Reading {file_path}")
+        raise e
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
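The rewritten helper above now names the file it failed on before re-raising. A self-contained sketch of the same idea for quick testing; the stdlib logger here stands in for the script's own logger, and the broken input file is made up:

    import logging
    import wave

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("preprocess_flist")


    def get_wav_duration(file_path):
        try:
            with wave.open(file_path, 'rb') as wav_file:
                n_frames = wav_file.getnframes()        # number of audio frames
                framerate = wav_file.getframerate()     # sample rate in Hz
                return n_frames / float(framerate)      # duration in seconds
        except Exception:
            logger.error(f"Reading {file_path}")        # point at the offending file
            raise


    if __name__ == "__main__":
        with open("not_audio.txt", "w") as f:           # hypothetical broken input
            f.write("this is not a wav file")
        try:
            get_wav_duration("not_audio.txt")
        except wave.Error:
            pass                                        # the error log above already names the bad file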
@@ -42,32 +45,39 @@ if __name__ == "__main__":
     for speaker in tqdm(os.listdir(args.source_dir)):
         spk_dict[speaker] = spk_id
         spk_id += 1
-        wavs = ["/".join([args.source_dir, speaker, i]) for i in os.listdir(os.path.join(args.source_dir, speaker))]
-        new_wavs = []
-        for file in wavs:
-            if not file.endswith("wav"):
-                continue
-            if not pattern.match(file):
-                logger.warning(f"The file name {file} contains characters other than letters, digits and underscores, which may (or may not) cause errors.")
-            if get_wav_duration(file) < 0.3:
-                logger.info("Skip too short audio:" + file)
-                continue
-            new_wavs.append(file)
-        wavs = new_wavs
+        wavs = []
+
+        for file_name in os.listdir(os.path.join(args.source_dir, speaker)):
+            if not file_name.endswith("wav"):
+                continue
+            if file_name.startswith("."):
+                continue
+
+            file_path = "/".join([args.source_dir, speaker, file_name])
+
+            if not pattern.match(file_name):
+                logger.warning("Detected non-ASCII file name: " + file_path)
+
+            if get_wav_duration(file_path) < 0.3:
+                logger.info("Skip too short audio: " + file_path)
+                continue
+
+            wavs.append(file_path)
+
         shuffle(wavs)
         train += wavs[2:]
         val += wavs[:2]
 
     shuffle(train)
     shuffle(val)
 
-    logger.info("Writing" + args.train_list)
+    logger.info("Writing " + args.train_list)
     with open(args.train_list, "w") as f:
         for fname in tqdm(train):
             wavpath = fname
             f.write(wavpath + "\n")
 
-    logger.info("Writing" + args.val_list)
+    logger.info("Writing " + args.val_list)
     with open(args.val_list, "w") as f:
         for fname in tqdm(val):
             wavpath = fname
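For readability, here is the new per-speaker collection loop from the hunk above restated as a standalone function; the function name is hypothetical, and pattern / get_wav_duration mirror the definitions earlier in the script:

    import os
    import re
    import wave

    pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')


    def get_wav_duration(file_path):
        with wave.open(file_path, 'rb') as wav_file:
            return wav_file.getnframes() / float(wav_file.getframerate())


    def collect_speaker_wavs(source_dir, speaker, min_seconds=0.3):
        wavs = []
        for file_name in os.listdir(os.path.join(source_dir, speaker)):
            if not file_name.endswith("wav"):
                continue                                 # skip non-wav files
            if file_name.startswith("."):
                continue                                 # skip hidden files such as .DS_Store
            file_path = "/".join([source_dir, speaker, file_name])
            if not pattern.match(file_name):
                print("Detected non-ASCII file name: " + file_path)  # warn, but keep the file
            if get_wav_duration(file_path) < min_seconds:
                continue                                 # drop clips shorter than min_seconds
            wavs.append(file_path)
        return wavs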
@@ -113,7 +113,7 @@ def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu")
     logger.info(f"Rank {rank} uses device {device}")
     hmodel = utils.get_speech_encoder(speech_encoder, device=device)
     logger.info(f"Loaded speech encoder for rank {rank}")
-    for filename in tqdm(file_chunk):
+    for filename in tqdm(file_chunk, position = rank):
         process_one(filename, hmodel, f0p, device, diff, mel_extractor)
 
 def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
@@ -124,7 +124,7 @@ def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device)
             end = int((i + 1) * len(filenames) / num_processes)
             file_chunk = filenames[start:end]
             tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device))
-        for task in tqdm(tasks):
+        for task in tqdm(tasks, position = 0):
             task.result()
 
 if __name__ == "__main__":
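Both hunks above only add tqdm's position argument. A minimal sketch (with made-up work, and the rank passed in explicitly rather than derived inside the worker as the script does) of why that matters: each worker process draws its progress bar on its own terminal row, while the parent keeps row 0 for the overall chunk progress, instead of every bar overwriting the same line:

    from concurrent.futures import ProcessPoolExecutor

    from tqdm import tqdm


    def process_batch(file_chunk, rank):
        for _ in tqdm(file_chunk, position=rank, desc=f"rank {rank}"):
            pass                                         # placeholder for per-file feature extraction


    if __name__ == "__main__":
        filenames = list(range(100))                     # stand-in for the real file list
        num_processes = 4
        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            tasks = []
            for i in range(num_processes):
                start = int(i * len(filenames) / num_processes)
                end = int((i + 1) * len(filenames) / num_processes)
                tasks.append(executor.submit(process_batch, filenames[start:end], i + 1))
            for task in tqdm(tasks, position=0, desc="chunks"):
                task.result()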
@@ -149,10 +149,10 @@ if __name__ == "__main__":
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
     print(speech_encoder)
-    logger.info("Using device: ", device)
+    logger.info("Using device: " + str(device))
     logger.info("Using SpeechEncoder: " + speech_encoder)
     logger.info("Using extractor: " + f0p)
-    logger.info("Using diff Mode: " + str( args.use_diff))
+    logger.info("Using diff Mode: " + str(args.use_diff))
 
     if args.use_diff:
         print("use_diff")
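On the first change in this hunk: assuming the script's logger is loguru's (the surrounding f-string calls suggest it), extra positional arguments are only interpolated when the message contains {} placeholders, so the old call silently dropped the device from the output. A small sketch of the difference:

    import torch
    from loguru import logger

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    logger.info("Using device: ", device)            # logs "Using device: " and ignores `device`
    logger.info("Using device: " + str(device))      # logs "Using device: cpu" (or cuda:0)
    logger.info("Using device: {}", device)          # equivalent loguru-style alternative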
pretrain/__init__.py: new empty file (0 lines)
@@ -78,10 +78,9 @@
 "#@markdown\n",
 "\n",
 "!git clone https://github.com/svc-develop-team/so-vits-svc -b 4.1-Stable\n",
-"%pip uninstall -y torchdata torchtext\n",
-"%pip install --upgrade pip setuptools numpy numba\n",
-"%pip install pyworld praat-parselmouth fairseq tensorboardX torchcrepe librosa==0.9.1 pyyaml pynvml pyloudnorm faiss-gpu\n",
-"%pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118\n",
+"%cd /content/so-vits-svc\n",
+"%pip install --upgrade pip setuptools\n",
+"%pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118\n",
 "exit()"
 ]
 },
@@ -163,8 +162,8 @@
 "#@markdown Although the pretrained model generally does not cause any copyright problems, please pay attention to it. For example, ask the author in advance, or the author has indicated the feasible use in the description clearly.\n",
 "\n",
 "download_pretrained_model = True #@param {type:\"boolean\"}\n",
-"D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\"] {allow-input: true}\n",
-"G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\"] {allow-input: true}\n",
+"D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_D_320000.pth\"] {allow-input: true}\n",
+"G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_G_320000.pth\"] {allow-input: true}\n",
 "\n",
 "download_pretrained_diffusion_model = True #@param {type:\"boolean\"}\n",
 "diff_model_URL = \"https://huggingface.co/datasets/ms903/Diff-SVC-refactor-pre-trained-model/resolve/main/fix_pitch_add_vctk_600k/model_0.pt\" #@param {type:\"string\"}\n",
@@ -317,13 +316,17 @@
 "#@markdown\n",
 "%cd /content/so-vits-svc\n",
 "\n",
-"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n",
+"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n",
 "use_diff = True #@param {type:\"boolean\"}\n",
 "\n",
 "import os\n",
 "if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
 " !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
 "\n",
+"if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n",
+" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n",
+"\n",
+"\n",
 "diff_param = \"\"\n",
 "if use_diff:\n",
 " diff_param = \"--use_diff\"\n",
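The added cell lines above mirror the existing rmvpe.pt guard for the new fcpe predictor. A plain-Python equivalent of that download guard, shown only to spell out the logic (the notebook itself uses a !curl cell):

    import os
    import urllib.request

    f0_predictor = "fcpe"
    FCPE_URL = "https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt"

    if f0_predictor == "fcpe" and not os.path.exists("./pretrain/fcpe.pt"):
        os.makedirs("pretrain", exist_ok=True)
        urllib.request.urlretrieve(FCPE_URL, "pretrain/fcpe.pt")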
@@ -624,7 +627,7 @@
 "if auto_predict_f0:\n",
 " apf = \" -a \"\n",
 "\n",
-"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n",
+"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n",
 "\n",
 "enhance = False #@param {type:\"boolean\"}\n",
 "ehc = \"\"\n",
@@ -644,6 +647,9 @@
 "if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
 " !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
 "\n",
+"if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n",
+" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n",
+"\n",
 "if not os.path.exists(output):\n",
 " !curl -L {url} -o {output}\n",
 "\n",