mirror of
https://github.com/svc-develop-team/so-vits-svc.git
synced 2025-01-09 04:27:31 +08:00
Merge pull request #289 from xdedss/add-local-checkpoint
Add local checkpoint selection in webui
This commit is contained in:
commit
3147c85203
2
.gitignore
vendored
2
.gitignore
vendored
@ -163,3 +163,5 @@ filelists/val.txt
|
||||
.idea/inspectionProfiles/Project_Default.xml
|
||||
pretrain/
|
||||
.vscode/launch.json
|
||||
|
||||
trained/**/
|
||||
|
0
trained/put_trained_checkpoints_here
Normal file
0
trained/put_trained_checkpoints_here
Normal file
56
webUI.py
56
webUI.py
@ -5,6 +5,7 @@ import re
|
||||
import subprocess
|
||||
import time
|
||||
import traceback
|
||||
import glob
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
@ -30,6 +31,8 @@ model = None
|
||||
spk = None
|
||||
debug = False
|
||||
|
||||
local_model_root = './trained'
|
||||
|
||||
cuda = {}
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
@ -75,14 +78,23 @@ def updata_mix_info(files):
|
||||
traceback.print_exc()
|
||||
raise gr.Error(e)
|
||||
|
||||
def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix):
|
||||
def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix,local_model_enabled,local_model_selection):
|
||||
global model
|
||||
try:
|
||||
device = cuda[device] if "CUDA" in device else device
|
||||
cluster_filepath = os.path.split(cluster_model_path.name) if cluster_model_path is not None else "no_cluster"
|
||||
# get model and config path
|
||||
if (local_model_enabled):
|
||||
# local path
|
||||
model_path = glob.glob(os.path.join(local_model_selection, '*.pth'))[0]
|
||||
config_path = glob.glob(os.path.join(local_model_selection, '*.json'))[0]
|
||||
else:
|
||||
# upload from webpage
|
||||
model_path = model_path.name
|
||||
config_path = config_path.name
|
||||
fr = ".pkl" in cluster_filepath[1]
|
||||
model = Svc(model_path.name,
|
||||
config_path.name,
|
||||
model = Svc(model_path,
|
||||
config_path,
|
||||
device=device if device != "Auto" else None,
|
||||
cluster_model_path = cluster_model_path.name if cluster_model_path is not None else "",
|
||||
nsf_hifigan_enhance=enhance,
|
||||
@ -239,6 +251,22 @@ def model_compression(_model):
|
||||
removeOptimizer(_model.name, output_path)
|
||||
return f"模型已成功被保存在了{output_path}"
|
||||
|
||||
def scan_local_models():
|
||||
res = []
|
||||
candidates = glob.glob(os.path.join(local_model_root, '**', '*.json'), recursive=True)
|
||||
candidates = set([os.path.dirname(c) for c in candidates])
|
||||
for candidate in candidates:
|
||||
jsons = glob.glob(os.path.join(candidate, '*.json'))
|
||||
pths = glob.glob(os.path.join(candidate, '*.pth'))
|
||||
if (len(jsons) == 1 and len(pths) == 1):
|
||||
# must contain exactly one json and one pth file
|
||||
res.append(candidate)
|
||||
return res
|
||||
|
||||
def local_model_refresh_fn():
|
||||
choices = scan_local_models()
|
||||
return gr.Dropdown.update(choices=choices)
|
||||
|
||||
def debug_change():
|
||||
global debug
|
||||
debug = debug_button.value
|
||||
@ -260,9 +288,17 @@ with gr.Blocks(
|
||||
gr.Markdown(value="""
|
||||
<font size=2> 模型设置</font>
|
||||
""")
|
||||
with gr.Row():
|
||||
model_path = gr.File(label="选择模型文件")
|
||||
config_path = gr.File(label="选择配置文件")
|
||||
with gr.Tabs():
|
||||
# invisible checkbox that tracks tab status
|
||||
local_model_enabled = gr.Checkbox(value=False, visible=False)
|
||||
with gr.TabItem('上传') as local_model_tab_upload:
|
||||
with gr.Row():
|
||||
model_path = gr.File(label="选择模型文件")
|
||||
config_path = gr.File(label="选择配置文件")
|
||||
with gr.TabItem('本地') as local_model_tab_local:
|
||||
gr.Markdown(f'模型应当放置于f{local_model_root}文件夹下')
|
||||
local_model_refresh_btn = gr.Button('刷新本地模型列表')
|
||||
local_model_selection = gr.Dropdown(label='选择模型文件夹', choices=[], interactive=True)
|
||||
with gr.Row():
|
||||
diff_model_path = gr.File(label="选择扩散模型文件")
|
||||
diff_config_path = gr.File(label="选择扩散模型配置文件")
|
||||
@ -374,11 +410,17 @@ with gr.Blocks(
|
||||
<font size=2> WebUI设置</font>
|
||||
""")
|
||||
debug_button = gr.Checkbox(label="Debug模式,如果向社区反馈BUG需要打开,打开后控制台可以显示具体错误提示", value=debug)
|
||||
# refresh local model list
|
||||
local_model_refresh_btn.click(local_model_refresh_fn, outputs=local_model_selection)
|
||||
# set local enabled/disabled on tab switch
|
||||
local_model_tab_upload.select(lambda: False, outputs=local_model_enabled)
|
||||
local_model_tab_local.select(lambda: True, outputs=local_model_enabled)
|
||||
|
||||
vc_submit.click(vc_fn, [sid, vc_input3, output_format, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,f0_predictor,enhancer_adaptive_key,cr_threshold,k_step,use_spk_mix,second_encoding,loudness_envelope_adjustment], [vc_output1, vc_output2])
|
||||
vc_submit2.click(vc_fn2, [text2tts, tts_lang, tts_gender, tts_rate, tts_volume, sid, output_format, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,f0_predictor,enhancer_adaptive_key,cr_threshold,k_step,use_spk_mix,second_encoding,loudness_envelope_adjustment], [vc_output1, vc_output2])
|
||||
|
||||
debug_button.change(debug_change,[],[])
|
||||
model_load_button.click(modelAnalysis,[model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix],[sid,sid_output])
|
||||
model_load_button.click(modelAnalysis,[model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix,local_model_enabled,local_model_selection],[sid,sid_output])
|
||||
model_unload_button.click(modelUnload,[],[sid,sid_output])
|
||||
os.system("start http://127.0.0.1:7860")
|
||||
app.launch()
|
||||
|
Loading…
Reference in New Issue
Block a user