Update preprocess_hubert_f0.py

This commit is contained in:
Stardust·减 2023-07-22 22:01:44 +08:00 committed by GitHub
parent d07d92b61a
commit 1cdccce44a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,6 +5,7 @@ import random
from concurrent.futures import ProcessPoolExecutor
from glob import glob
from random import shuffle
from loguru import logger
import librosa
import numpy as np
@ -28,7 +29,6 @@ speech_encoder = hps["model"]["speech_encoder"]
def process_one(filename, hmodel,f0p,rank,diff=False,mel_extractor=None):
# print(filename)
wav, sr = librosa.load(filename, sr=sampling_rate)
audio_norm = torch.FloatTensor(wav)
audio_norm = audio_norm.unsqueeze(0)
@ -104,15 +104,15 @@ def process_one(filename, hmodel,f0p,rank,diff=False,mel_extractor=None):
np.save(aug_vol_path,aug_vol.to('cpu').numpy())
# Per-worker routine: obtains a speech encoder via utils.get_speech_encoder and then calls
# process_one on every entry of file_chunk, forwarding f0p, rank, diff and mel_extractor.
# NOTE(review): this span looks like a flattened diff hunk -- indentation is gone and each
# print(...) line appears to be the removed counterpart of the logger.info(...) line that
# follows it. Confirm against the real file before treating both calls as live code.
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
print("Loading speech encoder for content...")
logger.info("Loading speech encoder for content...")
# rank becomes the first element of the current process's `_identity`, or 0 when it is empty.
rank = mp.current_process()._identity
rank = rank[0] if len(rank) > 0 else 0
# When CUDA is available, choose a GPU index by taking rank modulo the CUDA device count.
if torch.cuda.is_available():
gpu_id = rank % torch.cuda.device_count()
device = torch.device(f"cuda:{gpu_id}")
# NOTE(review): within this span `device` is assigned only on the CUDA path; presumably a
# default `device` is defined elsewhere in the file -- confirm, otherwise the lines below
# would fail on a machine without CUDA.
print(f"Rank {rank} uses device {device}")
logger.info(f"Rank {rank} uses device {device}")
hmodel = utils.get_speech_encoder(speech_encoder, device=device)
print("Loaded speech encoder.")
logger.info(f"Loaded speech encoder for rank {rank}")
# tqdm wraps the iteration over file_chunk; the same encoder instance is reused for each file.
for filename in tqdm(file_chunk):
process_one(filename, hmodel, f0p, rank, diff, mel_extractor)
@ -144,7 +144,9 @@ if __name__ == "__main__":
args = parser.parse_args()
f0p = args.f0_predictor
print(speech_encoder)
print(f0p)
logger.info("Using " + speech_encoder + " SpeechEncoder")
logger.info("Using " + f0p + "f0 extractor")
logger.info("Using diff Mode:")
print(args.use_diff)
if args.use_diff:
print("use_diff")