fix(preprocess): pass device

This commit is contained in:
magic-akari 2023-07-23 22:12:04 +08:00
parent 43fb46e64d
commit 9bea608c98
No known key found for this signature in database
GPG Key ID: EC005B1159285BDD

View File

@ -104,7 +104,8 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
if not os.path.exists(aug_vol_path):
np.save(aug_vol_path,aug_vol.to('cpu').numpy())
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu"):
print("Loading speech encoder for content...")
rank = mp.current_process()._identity
rank = rank[0] if len(rank) > 0 else 0
@ -117,19 +118,20 @@ def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
for filename in tqdm(file_chunk):
process_one(filename, hmodel, f0p, diff, mel_extractor)
def parallel_process(filenames, num_processes, f0p, diff, mel_extractor):
def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
with ProcessPoolExecutor(max_workers=num_processes) as executor:
tasks = []
for i in range(num_processes):
start = int(i * len(filenames) / num_processes)
end = int((i + 1) * len(filenames) / num_processes)
file_chunk = filenames[start:end]
tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor))
tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device))
for task in tqdm(tasks):
task.result()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--device', type=str, default=None)
parser.add_argument(
"--in_dir", type=str, default="dataset/44k", help="path to input dir"
)
@ -144,13 +146,18 @@ if __name__ == "__main__":
)
args = parser.parse_args()
f0p = args.f0_predictor
device = args.device
if device is None:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(speech_encoder)
print(f0p)
print(args.use_diff)
print("use_diff: ", args.use_diff)
print("device: ", device)
if args.use_diff:
print("use_diff")
print("Loading Mel Extractor...")
mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = "cuda:0")
mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device=device)
print("Loaded Mel Extractor.")
else:
mel_extractor = None
@ -161,5 +168,5 @@ if __name__ == "__main__":
num_processes = args.num_processes
if num_processes == 0:
num_processes = os.cpu_count()
parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)
parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor, device)