From 9bea608c98966f92934c287ced7fd8fdaa4abc61 Mon Sep 17 00:00:00 2001
From: magic-akari
Date: Sun, 23 Jul 2023 22:12:04 +0800
Subject: [PATCH] fix(preprocess): pass device

---
 preprocess_hubert_f0.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/preprocess_hubert_f0.py b/preprocess_hubert_f0.py
index 0e4a0c9..0ace6de 100644
--- a/preprocess_hubert_f0.py
+++ b/preprocess_hubert_f0.py
@@ -104,7 +104,8 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
         if not os.path.exists(aug_vol_path):
             np.save(aug_vol_path,aug_vol.to('cpu').numpy())
 
-def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
+
+def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu"):
     print("Loading speech encoder for content...")
     rank = mp.current_process()._identity
     rank = rank[0] if len(rank) > 0 else 0
@@ -117,19 +118,20 @@ def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
     for filename in tqdm(file_chunk):
         process_one(filename, hmodel, f0p, diff, mel_extractor)
 
-def parallel_process(filenames, num_processes, f0p, diff, mel_extractor):
+def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
     with ProcessPoolExecutor(max_workers=num_processes) as executor:
         tasks = []
         for i in range(num_processes):
             start = int(i * len(filenames) / num_processes)
             end = int((i + 1) * len(filenames) / num_processes)
             file_chunk = filenames[start:end]
-            tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor))
+            tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device))
         for task in tqdm(tasks):
             task.result()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument('-d', '--device', type=str, default=None)
     parser.add_argument(
         "--in_dir", type=str, default="dataset/44k", help="path to input dir"
     )
@@ -144,13 +146,18 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
     f0p = args.f0_predictor
+    device = args.device
+    if device is None:
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
     print(speech_encoder)
     print(f0p)
-    print(args.use_diff)
+    print("use_diff: ", args.use_diff)
+    print("device: ", device)
     if args.use_diff:
         print("use_diff")
         print("Loading Mel Extractor...")
-        mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = "cuda:0")
+        mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device=device)
         print("Loaded Mel Extractor.")
     else:
         mel_extractor = None
@@ -161,5 +168,5 @@ if __name__ == "__main__":
     num_processes = args.num_processes
     if num_processes == 0:
         num_processes = os.cpu_count()
-
-    parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)
+
+    parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor, device)