mirror of
https://github.com/svc-develop-team/so-vits-svc.git
synced 2025-01-08 11:57:43 +08:00
commit
39b0befef5
@ -136,7 +136,7 @@ class FCPE(nn.Module):
|
||||
B, N, _ = y.size()
|
||||
ci = self.cent_table[None, None, :].expand(B, N, -1)
|
||||
confident, max_index = torch.max(y, dim=-1, keepdim=True)
|
||||
local_argmax_index = torch.arange(0,8).to(max_index.device) + (max_index - 4)
|
||||
local_argmax_index = torch.arange(0,9).to(max_index.device) + (max_index - 4)
|
||||
local_argmax_index[local_argmax_index<0] = 0
|
||||
local_argmax_index[local_argmax_index>=self.n_out] = self.n_out - 1
|
||||
ci_l = torch.gather(ci,-1,local_argmax_index)
|
||||
|
@ -13,14 +13,17 @@ import diffusion.logger.utils as du
|
||||
pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, 'rb') as wav_file:
|
||||
# 获取音频帧数
|
||||
n_frames = wav_file.getnframes()
|
||||
# 获取采样率
|
||||
framerate = wav_file.getframerate()
|
||||
# 计算时长(秒)
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
try:
|
||||
with wave.open(file_path, 'rb') as wav_file:
|
||||
# 获取音频帧数
|
||||
n_frames = wav_file.getnframes()
|
||||
# 获取采样率
|
||||
framerate = wav_file.getframerate()
|
||||
# 计算时长(秒)
|
||||
return n_frames / float(framerate)
|
||||
except Exception as e:
|
||||
logger.error(f"Reading {file_path}")
|
||||
raise e
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -42,32 +45,39 @@ if __name__ == "__main__":
|
||||
for speaker in tqdm(os.listdir(args.source_dir)):
|
||||
spk_dict[speaker] = spk_id
|
||||
spk_id += 1
|
||||
wavs = ["/".join([args.source_dir, speaker, i]) for i in os.listdir(os.path.join(args.source_dir, speaker))]
|
||||
new_wavs = []
|
||||
for file in wavs:
|
||||
if not file.endswith("wav"):
|
||||
wavs = []
|
||||
|
||||
for file_name in os.listdir(os.path.join(args.source_dir, speaker)):
|
||||
if not file_name.endswith("wav"):
|
||||
continue
|
||||
if not pattern.match(file):
|
||||
logger.warning(f"文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
|
||||
if get_wav_duration(file) < 0.3:
|
||||
logger.info("Skip too short audio:" + file)
|
||||
if file_name.startswith("."):
|
||||
continue
|
||||
new_wavs.append(file)
|
||||
wavs = new_wavs
|
||||
|
||||
file_path = "/".join([args.source_dir, speaker, file_name])
|
||||
|
||||
if not pattern.match(file_name):
|
||||
logger.warning("Detected non-ASCII file name: " + file_path)
|
||||
|
||||
if get_wav_duration(file_path) < 0.3:
|
||||
logger.info("Skip too short audio: " + file_path)
|
||||
continue
|
||||
|
||||
wavs.append(file_path)
|
||||
|
||||
shuffle(wavs)
|
||||
train += wavs[2:]
|
||||
val += wavs[:2]
|
||||
|
||||
shuffle(train)
|
||||
shuffle(val)
|
||||
|
||||
logger.info("Writing" + args.train_list)
|
||||
|
||||
logger.info("Writing " + args.train_list)
|
||||
with open(args.train_list, "w") as f:
|
||||
for fname in tqdm(train):
|
||||
wavpath = fname
|
||||
f.write(wavpath + "\n")
|
||||
|
||||
logger.info("Writing" + args.val_list)
|
||||
|
||||
logger.info("Writing " + args.val_list)
|
||||
with open(args.val_list, "w") as f:
|
||||
for fname in tqdm(val):
|
||||
wavpath = fname
|
||||
|
@ -113,7 +113,7 @@ def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu")
|
||||
logger.info(f"Rank {rank} uses device {device}")
|
||||
hmodel = utils.get_speech_encoder(speech_encoder, device=device)
|
||||
logger.info(f"Loaded speech encoder for rank {rank}")
|
||||
for filename in tqdm(file_chunk):
|
||||
for filename in tqdm(file_chunk, position = rank):
|
||||
process_one(filename, hmodel, f0p, device, diff, mel_extractor)
|
||||
|
||||
def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
|
||||
@ -124,7 +124,7 @@ def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device)
|
||||
end = int((i + 1) * len(filenames) / num_processes)
|
||||
file_chunk = filenames[start:end]
|
||||
tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device))
|
||||
for task in tqdm(tasks):
|
||||
for task in tqdm(tasks, position = 0):
|
||||
task.result()
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -149,10 +149,10 @@ if __name__ == "__main__":
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
print(speech_encoder)
|
||||
logger.info("Using device: ", device)
|
||||
logger.info("Using device: " + str(device))
|
||||
logger.info("Using SpeechEncoder: " + speech_encoder)
|
||||
logger.info("Using extractor: " + f0p)
|
||||
logger.info("Using diff Mode: " + str( args.use_diff))
|
||||
logger.info("Using diff Mode: " + str(args.use_diff))
|
||||
|
||||
if args.use_diff:
|
||||
print("use_diff")
|
||||
|
0
pretrain/__init__.py
Normal file
0
pretrain/__init__.py
Normal file
@ -78,10 +78,9 @@
|
||||
"#@markdown\n",
|
||||
"\n",
|
||||
"!git clone https://github.com/svc-develop-team/so-vits-svc -b 4.1-Stable\n",
|
||||
"%pip uninstall -y torchdata torchtext\n",
|
||||
"%pip install --upgrade pip setuptools numpy numba\n",
|
||||
"%pip install pyworld praat-parselmouth fairseq tensorboardX torchcrepe librosa==0.9.1 pyyaml pynvml pyloudnorm faiss-gpu\n",
|
||||
"%pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118\n",
|
||||
"%cd /content/so-vits-svc\n",
|
||||
"%pip install --upgrade pip setuptools\n",
|
||||
"%pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118\n",
|
||||
"exit()"
|
||||
]
|
||||
},
|
||||
@ -163,8 +162,8 @@
|
||||
"#@markdown Although the pretrained model generally does not cause any copyright problems, please pay attention to it. For example, ask the author in advance, or the author has indicated the feasible use in the description clearly.\n",
|
||||
"\n",
|
||||
"download_pretrained_model = True #@param {type:\"boolean\"}\n",
|
||||
"D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\"] {allow-input: true}\n",
|
||||
"G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\"] {allow-input: true}\n",
|
||||
"D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_D_320000.pth\"] {allow-input: true}\n",
|
||||
"G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_G_320000.pth\"] {allow-input: true}\n",
|
||||
"\n",
|
||||
"download_pretrained_diffusion_model = True #@param {type:\"boolean\"}\n",
|
||||
"diff_model_URL = \"https://huggingface.co/datasets/ms903/Diff-SVC-refactor-pre-trained-model/resolve/main/fix_pitch_add_vctk_600k/model_0.pt\" #@param {type:\"string\"}\n",
|
||||
@ -317,13 +316,17 @@
|
||||
"#@markdown\n",
|
||||
"%cd /content/so-vits-svc\n",
|
||||
"\n",
|
||||
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n",
|
||||
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n",
|
||||
"use_diff = True #@param {type:\"boolean\"}\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
|
||||
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
|
||||
"\n",
|
||||
"if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n",
|
||||
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"diff_param = \"\"\n",
|
||||
"if use_diff:\n",
|
||||
" diff_param = \"--use_diff\"\n",
|
||||
@ -624,7 +627,7 @@
|
||||
"if auto_predict_f0:\n",
|
||||
" apf = \" -a \"\n",
|
||||
"\n",
|
||||
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n",
|
||||
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n",
|
||||
"\n",
|
||||
"enhance = False #@param {type:\"boolean\"}\n",
|
||||
"ehc = \"\"\n",
|
||||
@ -644,6 +647,9 @@
|
||||
"if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
|
||||
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
|
||||
"\n",
|
||||
"if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n",
|
||||
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n",
|
||||
"\n",
|
||||
"if not os.path.exists(output):\n",
|
||||
" !curl -L {url} -o {output}\n",
|
||||
"\n",
|
||||
|
Loading…
Reference in New Issue
Block a user