mirror of
https://github.com/svc-develop-team/so-vits-svc.git
synced 2025-01-09 04:27:31 +08:00
feat: add compress_model
This commit is contained in:
parent
49a3b0f2ae
commit
565a6d58de
69
compress_model.py
Normal file
69
compress_model.py
Normal file
@ -0,0 +1,69 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
import utils
|
||||
from models import SynthesizerTrn
|
||||
|
||||
|
||||
def copyStateDict(state_dict):
|
||||
if list(state_dict.keys())[0].startswith('module'):
|
||||
start_idx = 1
|
||||
else:
|
||||
start_idx = 0
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
name = ','.join(k.split('.')[start_idx:])
|
||||
new_state_dict[name] = v
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def removeOptimizer(config: str, input_model: str, output_model: str):
|
||||
hps = utils.get_hparams_from_file(config)
|
||||
|
||||
net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
**hps.model)
|
||||
|
||||
optim_g = torch.optim.AdamW(net_g.parameters(),
|
||||
hps.train.learning_rate,
|
||||
betas=hps.train.betas,
|
||||
eps=hps.train.eps)
|
||||
|
||||
state_dict_g = torch.load(input_model, map_location="cpu")
|
||||
new_dict_g = copyStateDict(state_dict_g)
|
||||
keys = []
|
||||
for k, v in new_dict_g['model'].items():
|
||||
keys.append(k)
|
||||
|
||||
new_dict_g = {k: new_dict_g['model'][k] for k in keys}
|
||||
|
||||
torch.save(
|
||||
{
|
||||
'model': new_dict_g,
|
||||
'iteration': 0,
|
||||
'optimizer': optim_g.state_dict(),
|
||||
'learning_rate': 0.0001
|
||||
}, output_model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c",
|
||||
"--config",
|
||||
type=str,
|
||||
default='configs/config.json')
|
||||
parser.add_argument("-i", "--input", type=str)
|
||||
parser.add_argument("-o", "--output", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
output = args.output
|
||||
|
||||
if output is None:
|
||||
import os.path
|
||||
filename, ext = os.path.splitext(args.input)
|
||||
output = filename + "_release" + ext
|
||||
|
||||
removeOptimizer(args.config, args.input, output)
|
Loading…
Reference in New Issue
Block a user