Mirror of https://github.com/svc-develop-team/so-vits-svc.git (synced 2025-01-08 11:57:43 +08:00)
fix(preprocess): pass device
parent 43fb46e64d
commit 9bea608c98
@@ -104,7 +104,8 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
         if not os.path.exists(aug_vol_path):
             np.save(aug_vol_path,aug_vol.to('cpu').numpy())
 
-def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
+
+def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu"):
     print("Loading speech encoder for content...")
     rank = mp.current_process()._identity
     rank = rank[0] if len(rank) > 0 else 0
@@ -117,19 +118,20 @@ def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
     for filename in tqdm(file_chunk):
         process_one(filename, hmodel, f0p, diff, mel_extractor)
 
-def parallel_process(filenames, num_processes, f0p, diff, mel_extractor):
+def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
     with ProcessPoolExecutor(max_workers=num_processes) as executor:
         tasks = []
         for i in range(num_processes):
             start = int(i * len(filenames) / num_processes)
             end = int((i + 1) * len(filenames) / num_processes)
             file_chunk = filenames[start:end]
-            tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor))
+            tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device))
         for task in tqdm(tasks):
             task.result()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument('-d', '--device', type=str, default=None)
     parser.add_argument(
         "--in_dir", type=str, default="dataset/44k", help="path to input dir"
     )
@@ -144,13 +146,18 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
     f0p = args.f0_predictor
+    device = args.device
+    if device is None:
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
     print(speech_encoder)
     print(f0p)
-    print(args.use_diff)
+    print("use_diff: ", args.use_diff)
+    print("device: ", device)
     if args.use_diff:
-        print("use_diff")
         print("Loading Mel Extractor...")
-        mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = "cuda:0")
+        mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device=device)
         print("Loaded Mel Extractor.")
     else:
         mel_extractor = None
@@ -161,5 +168,5 @@ if __name__ == "__main__":
     num_processes = args.num_processes
     if num_processes == 0:
         num_processes = os.cpu_count()
 
-    parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)
+    parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor, device)
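For context, the sketch below (not part of this commit; the worker body, file list, and names are placeholders) isolates the device-passing pattern the diff completes: the value of the -d/--device option, or an auto-detected fallback, is threaded from the argument parser through parallel_process into each process_batch worker as a keyword argument, instead of being hard-coded to "cuda:0". A plain string is used for the device here to keep the argument trivially picklable across processes.

# Minimal standalone sketch of the device-passing pattern; not repository code.
import argparse
from concurrent.futures import ProcessPoolExecutor

import torch


def process_batch(file_chunk, device="cpu"):
    # Each worker receives the device explicitly instead of assuming "cuda:0".
    print(f"processing {len(file_chunk)} files on {device}")


def parallel_process(filenames, num_processes, device):
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        tasks = []
        for i in range(num_processes):
            start = int(i * len(filenames) / num_processes)
            end = int((i + 1) * len(filenames) / num_processes)
            # Forward the device as a keyword argument, mirroring the diff.
            tasks.append(executor.submit(process_batch, filenames[start:end], device=device))
        for task in tasks:
            task.result()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--device", type=str, default=None)
    args = parser.parse_args()
    device = args.device
    if device is None:
        # Same fallback as the diff: prefer cuda:0 only when CUDA is available.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    parallel_process(["a.wav", "b.wav", "c.wav", "d.wav"], num_processes=2, device=device)

Invoked with -d cuda:0, every worker reports that device; with no flag, the fallback picks cuda:0 or cpu automatically.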