mirror of
https://github.com/svc-develop-team/so-vits-svc.git
synced 2025-01-09 04:27:31 +08:00
128 lines
3.7 KiB
Python
128 lines
3.7 KiB
Python
import json
|
|
import os
|
|
|
|
import torch
|
|
import yaml
|
|
|
|
|
|
def traverse_dir(
        root_dir,
        extensions,
        amount=None,
        str_include=None,
        str_exclude=None,
        is_pure=False,
        is_sort=False,
        is_ext=True):
    """Recursively collect files under *root_dir* that match *extensions*.

    Args:
        root_dir: directory to walk.
        extensions: iterable of extensions (without the dot), e.g. ['wav'].
        amount: stop after this many matches (None = unlimited).
        str_include: keep only paths containing this substring.
        str_exclude: drop paths containing this substring.
        is_pure: return paths relative to root_dir instead of full paths.
        is_sort: sort the result before returning.
        is_ext: keep the file extension in the returned paths.

    Returns:
        List of path strings.
    """
    suffixes = tuple(f".{ext}" for ext in extensions)
    collected = []
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            if not filename.endswith(suffixes):
                continue
            full_path = os.path.join(dirpath, filename)
            # Strip "root_dir/" prefix when a relative path is requested.
            path = full_path[len(root_dir) + 1:] if is_pure else full_path

            # Cap reached: finish up and return early.
            if (amount is not None) and (len(collected) == amount):
                if is_sort:
                    collected.sort()
                return collected

            # Substring include/exclude filters.
            if (str_include is not None) and (str_include not in path):
                continue
            if (str_exclude is not None) and (str_exclude in path):
                continue

            if not is_ext:
                # Drop the final ".<ext>" suffix.
                suffix = path.split('.')[-1]
                path = path[:-(len(suffix) + 1)]
            collected.append(path)
    if is_sort:
        collected.sort()
    return collected
|
|
|
|
|
|
|
|
class DotDict(dict):
    """A dict whose keys can also be read as attributes.

    Missing keys read as None (dict.get semantics). Plain-dict values are
    wrapped in DotDict on access so nested lookups chain: d.a.b.c.
    """

    def __getattr__(self, key):
        value = dict.get(self, key)
        # Only wrap exact dicts; subclasses (incl. DotDict) pass through as-is.
        if type(value) is dict:
            return DotDict(value)
        return value

    # Attribute writes/deletes go straight to the underlying dict entries.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
|
|
|
|
|
|
def get_network_paras_amount(model_dict):
    """Count trainable parameters for each network in *model_dict*.

    Args:
        model_dict: mapping of name -> torch.nn.Module.

    Returns:
        Dict mapping each name to its number of parameters with
        requires_grad=True (frozen parameters are excluded).
    """
    counts = {}
    for net_name, net in model_dict.items():
        n_trainable = sum(
            p.numel() for p in net.parameters() if p.requires_grad)
        counts[net_name] = n_trainable
    return counts
|
|
|
|
|
|
def load_config(path_config):
    """Load a YAML config file and return it wrapped in a DotDict.

    Args:
        path_config: path to the YAML file.

    Returns:
        DotDict with the parsed configuration (attribute access enabled).
    """
    with open(path_config, "r") as fp:
        parsed = yaml.safe_load(fp)
    return DotDict(parsed)
|
|
|
|
def save_config(path_config, config):
    """Serialize *config* to a YAML file at *path_config*.

    The config is converted to a plain dict first so DotDict instances
    dump as ordinary mappings.
    """
    plain = dict(config)
    with open(path_config, "w") as fp:
        yaml.dump(plain, fp)
|
|
|
|
def to_json(path_params, path_json):
    """Export a torch checkpoint's tensors to a JSON file.

    Each tensor in the checkpoint is flattened to a 1-D list of floats,
    so tensor shapes are NOT preserved in the JSON output.

    Args:
        path_params: path to a torch-saved state dict (.pt file).
        path_json: destination path for the JSON dump.
    """
    state = torch.load(path_params, map_location=torch.device('cpu'))
    serializable = {
        key: tensor.flatten().numpy().tolist()
        for key, tensor in state.items()
    }
    with open(path_json, 'w') as outfile:
        json.dump(serializable, outfile, indent="\t")
|
|
|
|
|
|
def convert_tensor_to_numpy(tensor, is_squeeze=True):
    """Convert a torch tensor to a numpy array on the host.

    Args:
        tensor: input torch.Tensor (may be on GPU or require grad).
        is_squeeze: if True, drop all size-1 dimensions first.

    Returns:
        numpy.ndarray with the tensor's data.
    """
    out = tensor.squeeze() if is_squeeze else tensor
    # Autograd-tracked tensors cannot be exported directly.
    if out.requires_grad:
        out = out.detach()
    # numpy() requires host memory.
    if out.is_cuda:
        out = out.cpu()
    return out.numpy()
|
|
|
|
|
|
def load_model(
        expdir,
        model,
        optimizer,
        name='model',
        postfix='',
        device='cpu'):
    """Restore the latest checkpoint from *expdir* into model/optimizer.

    Scans expdir for .pt files named "<name><postfix><step>.pt", picks the
    highest numeric step, and loads that checkpoint. If no .pt files exist,
    returns with global_step == 0 and the inputs unchanged.

    Returns:
        (global_step, model, optimizer) tuple.
    """
    # An empty postfix becomes '_', so checkpoints are expected to be
    # named like "model_1000.pt". NOTE(review): a non-empty postfix is
    # used verbatim (no underscore inserted) — confirm callers rely on this.
    if postfix == '':
        postfix = '_' + postfix
    path = os.path.join(expdir, name+postfix)
    # Full paths without extension, e.g. ".../model_1000" (is_ext=False).
    path_pt = traverse_dir(expdir, ['pt'], is_ext=False)
    global_step = 0
    if len(path_pt) > 0:
        # Suffix after the "<expdir>/<name><postfix>" prefix — the step number
        # for files matching the naming scheme, arbitrary text otherwise.
        steps = [s[len(path):] for s in path_pt]
        # Non-numeric suffixes count as step 0.
        maxstep = max([int(s) if s.isdigit() else 0 for s in steps])
        # NOTE(review): maxstep is always >= 0 here (every candidate maps to
        # a non-negative int), so the 'best.pt' branch looks unreachable.
        if maxstep >= 0:
            path_pt = path+str(maxstep)+'.pt'
        else:
            path_pt = path+'best.pt'
        print(' [*] restoring model from', path_pt)
        ckpt = torch.load(path_pt, map_location=torch.device(device))
        global_step = ckpt['global_step']
        # strict=False tolerates missing/unexpected keys in the state dict.
        model.load_state_dict(ckpt['model'], strict=False)
        # Optimizer state is optional in the checkpoint.
        if ckpt.get("optimizer") is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
    return global_step, model, optimizer
|