mirror of
https://github.com/svc-develop-team/so-vits-svc.git
synced 2025-01-08 11:57:43 +08:00
Debug BF16 and RMVPE
This commit is contained in:
parent
36787124f4
commit
f808d8e60b
2
train.py
2
train.py
@ -140,7 +140,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
if writers is not None:
|
||||
writer, writer_eval = writers
|
||||
|
||||
half_type = torch.float16 if hps.train.half_type=="fp16" else torch.bfloat16
|
||||
half_type = torch.bfloat16 if hps.train.half_type=="bf16" else torch.float16
|
||||
|
||||
# train_loader.batch_sampler.set_epoch(epoch)
|
||||
global global_step
|
||||
|
2
utils.py
2
utils.py
@ -99,7 +99,7 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
|
||||
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
|
||||
elif f0_predictor == "rmvpe":
|
||||
from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
|
||||
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float16 ,device=kargs["device"],threshold=kargs["threshold"])
|
||||
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"])
|
||||
else:
|
||||
raise Exception("Unknown f0 predictor")
|
||||
return f0_predictor_object
|
||||
|
Loading…
Reference in New Issue
Block a user