Merge pull request #289 from xdedss/add-local-checkpoint

Add local checkpoint selection in webui
This commit is contained in:
YuriHead 2023-07-11 23:13:34 +08:00 committed by GitHub
commit 3147c85203
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 7 deletions

2
.gitignore vendored
View File

@ -163,3 +163,5 @@ filelists/val.txt
.idea/inspectionProfiles/Project_Default.xml
pretrain/
.vscode/launch.json
trained/**/

View File

View File

@ -5,6 +5,7 @@ import re
import subprocess
import time
import traceback
import glob
from itertools import chain
from pathlib import Path
@ -30,6 +31,8 @@ model = None
spk = None
debug = False
local_model_root = './trained'
cuda = {}
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
@ -75,14 +78,23 @@ def updata_mix_info(files):
traceback.print_exc()
raise gr.Error(e)
def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix):
def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix,local_model_enabled,local_model_selection):
global model
try:
device = cuda[device] if "CUDA" in device else device
cluster_filepath = os.path.split(cluster_model_path.name) if cluster_model_path is not None else "no_cluster"
# get model and config path
if (local_model_enabled):
# local path
model_path = glob.glob(os.path.join(local_model_selection, '*.pth'))[0]
config_path = glob.glob(os.path.join(local_model_selection, '*.json'))[0]
else:
# upload from webpage
model_path = model_path.name
config_path = config_path.name
fr = ".pkl" in cluster_filepath[1]
model = Svc(model_path.name,
config_path.name,
model = Svc(model_path,
config_path,
device=device if device != "Auto" else None,
cluster_model_path = cluster_model_path.name if cluster_model_path is not None else "",
nsf_hifigan_enhance=enhance,
@ -239,6 +251,22 @@ def model_compression(_model):
removeOptimizer(_model.name, output_path)
return f"模型已成功被保存在了{output_path}"
def scan_local_models():
res = []
candidates = glob.glob(os.path.join(local_model_root, '**', '*.json'), recursive=True)
candidates = set([os.path.dirname(c) for c in candidates])
for candidate in candidates:
jsons = glob.glob(os.path.join(candidate, '*.json'))
pths = glob.glob(os.path.join(candidate, '*.pth'))
if (len(jsons) == 1 and len(pths) == 1):
# must contain exactly one json and one pth file
res.append(candidate)
return res
def local_model_refresh_fn():
choices = scan_local_models()
return gr.Dropdown.update(choices=choices)
def debug_change():
global debug
debug = debug_button.value
@ -260,9 +288,17 @@ with gr.Blocks(
gr.Markdown(value="""
<font size=2> 模型设置</font>
""")
with gr.Row():
model_path = gr.File(label="选择模型文件")
config_path = gr.File(label="选择配置文件")
with gr.Tabs():
# invisible checkbox that tracks tab status
local_model_enabled = gr.Checkbox(value=False, visible=False)
with gr.TabItem('上传') as local_model_tab_upload:
with gr.Row():
model_path = gr.File(label="选择模型文件")
config_path = gr.File(label="选择配置文件")
with gr.TabItem('本地') as local_model_tab_local:
gr.Markdown(f'模型应当放置于f{local_model_root}文件夹下')
local_model_refresh_btn = gr.Button('刷新本地模型列表')
local_model_selection = gr.Dropdown(label='选择模型文件夹', choices=[], interactive=True)
with gr.Row():
diff_model_path = gr.File(label="选择扩散模型文件")
diff_config_path = gr.File(label="选择扩散模型配置文件")
@ -374,11 +410,17 @@ with gr.Blocks(
<font size=2> WebUI设置</font>
""")
debug_button = gr.Checkbox(label="Debug模式如果向社区反馈BUG需要打开打开后控制台可以显示具体错误提示", value=debug)
# refresh local model list
local_model_refresh_btn.click(local_model_refresh_fn, outputs=local_model_selection)
# set local enabled/disabled on tab switch
local_model_tab_upload.select(lambda: False, outputs=local_model_enabled)
local_model_tab_local.select(lambda: True, outputs=local_model_enabled)
vc_submit.click(vc_fn, [sid, vc_input3, output_format, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,f0_predictor,enhancer_adaptive_key,cr_threshold,k_step,use_spk_mix,second_encoding,loudness_envelope_adjustment], [vc_output1, vc_output2])
vc_submit2.click(vc_fn2, [text2tts, tts_lang, tts_gender, tts_rate, tts_volume, sid, output_format, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,f0_predictor,enhancer_adaptive_key,cr_threshold,k_step,use_spk_mix,second_encoding,loudness_envelope_adjustment], [vc_output1, vc_output2])
debug_button.change(debug_change,[],[])
model_load_button.click(modelAnalysis,[model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix],[sid,sid_output])
model_load_button.click(modelAnalysis,[model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix,local_model_enabled,local_model_selection],[sid,sid_output])
model_unload_button.click(modelUnload,[],[sid,sid_output])
os.system("start http://127.0.0.1:7860")
app.launch()