2023-06-26 14:57:53 +08:00
import argparse
import logging
2023-03-10 18:11:04 +08:00
import os
2023-05-16 13:17:51 +08:00
import random
2023-06-26 14:57:53 +08:00
from concurrent . futures import ProcessPoolExecutor
from glob import glob
from random import shuffle
2023-07-22 14:55:20 +08:00
2023-06-20 03:38:13 +08:00
import librosa
import numpy as np
2023-06-26 14:57:53 +08:00
import torch
2023-07-22 20:18:20 +08:00
import torch . multiprocessing as mp
2023-07-22 23:02:52 +08:00
from loguru import logger
2023-03-10 18:11:04 +08:00
from tqdm import tqdm
2023-06-26 14:57:53 +08:00
import diffusion . logger . utils as du
import utils
2023-06-20 03:38:13 +08:00
from diffusion . vocoder import Vocoder
2023-03-24 13:00:14 +08:00
from modules . mel_processing import spectrogram_torch
2023-03-10 18:11:04 +08:00
2023-03-24 13:00:14 +08:00
# Raise the log threshold of two third-party loggers to WARNING.
for _quiet_logger in ("numba", "matplotlib"):
    logging.getLogger(_quiet_logger).setLevel(logging.WARNING)

# Configuration is read once at import time, so every process that imports
# this module loads the same files relative to its working directory.
hps = utils.get_hparams_from_file("configs/config.json")
dconfig = du.load_config("configs/diffusion.yaml")

# Values from the config that the functions below read as module globals.
sampling_rate = hps.data.sampling_rate
hop_length = hps.data.hop_length
speech_encoder = hps["model"]["speech_encoder"]
2023-03-10 18:11:04 +08:00
2023-05-16 13:17:51 +08:00
2023-07-23 23:27:58 +08:00
def process_one(filename, hmodel, f0p, device, diff=False, mel_extractor=None):
    """Extract and cache the training features for a single audio file.

    Every feature is stored in its own file derived from ``filename`` and is
    only computed when that file does not exist yet, so re-running the script
    resumes where a previous run stopped.

    Files written (when missing):
        ``<filename>.soft.pt``     output of ``hmodel.encoder`` on the 16 kHz audio
        ``<filename>.f0.npy``      ``(f0, uv)`` from the selected F0 predictor
        ``<name>.spec.pt``         spectrogram (``.wav`` replaced by ``.spec.pt``)
        ``<filename>.vol.npy``     volume, if ``diff`` or ``hps.model.vol_embedding``
        ``<filename>.mel.npy``     mel, if ``diff`` and a mel extractor is given
        ``<filename>.aug_mel.npy`` ``(mel, keyshift)`` of a randomly volume- and
                                   key-shifted copy, if ``diff`` and a mel
                                   extractor is given
        ``<filename>.aug_vol.npy`` volume of the same volume-shifted copy, if ``diff``

    Args:
        filename: Path of the audio file to process.
        hmodel: Speech encoder object; its ``encoder`` method is called.
        f0p: F0 predictor name handed to ``utils.get_f0_predictor``.
        device: Device the encoder / mel-extractor inputs are moved to.
        diff: Also produce the mel and augmented features when true.
        mel_extractor: Object exposing ``extract(audio, sample_rate, keyshift=...)``
            or ``None``; with ``None`` all mel outputs are skipped.

    Raises:
        ValueError: If the loaded sample rate differs from the configured one.
    """
    wav, sr = librosa.load(filename, sr=sampling_rate)
    # Add a leading dimension of size 1; the waveform is used exactly as
    # returned by librosa (no division by ``max_wav_value`` is applied).
    audio_norm = torch.FloatTensor(wav).unsqueeze(0)

    # --- speech-encoder content features ---------------------------------
    soft_path = filename + ".soft.pt"
    if not os.path.exists(soft_path):
        wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000)
        wav16k = torch.from_numpy(wav16k).to(device)
        c = hmodel.encoder(wav16k)
        torch.save(c.cpu(), soft_path)

    # --- F0 and voiced/unvoiced flags -------------------------------------
    f0_path = filename + ".f0.npy"
    if not os.path.exists(f0_path):
        f0_predictor = utils.get_f0_predictor(
            f0p,
            sampling_rate=sampling_rate,
            hop_length=hop_length,
            device=None,
            threshold=0.05,
        )
        f0, uv = f0_predictor.compute_f0_uv(wav)
        np.save(f0_path, np.asanyarray((f0, uv), dtype=object))

    # --- spectrogram -------------------------------------------------------
    spec_path = filename.replace(".wav", ".spec.pt")
    if not os.path.exists(spec_path):
        if sr != hps.data.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(
                    sr, hps.data.sampling_rate
                )
            )
        spec = spectrogram_torch(
            audio_norm,
            hps.data.filter_length,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            center=False,
        )
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_path)

    # --- volume ------------------------------------------------------------
    if diff or hps.model.vol_embedding:
        volume_path = filename + ".vol.npy"
        # Also reused by the augmentation step below (``diff`` implies this branch).
        volume_extractor = utils.Volume_Extractor(hop_length)
        if not os.path.exists(volume_path):
            volume = volume_extractor.extract(audio_norm)
            np.save(volume_path, volume.to('cpu').numpy())

    # --- mel and augmented features (diffusion only) ------------------------
    if diff:
        mel_path = filename + ".mel.npy"
        if not os.path.exists(mel_path) and mel_extractor is not None:
            mel_t = mel_extractor.extract(audio_norm.to(device), sampling_rate)
            mel = mel_t.squeeze().to('cpu').numpy()
            np.save(mel_path, mel)

        aug_mel_path = filename + ".aug_mel.npy"
        aug_vol_path = filename + ".aug_vol.npy"
        # Without a mel extractor the augmented mel cannot be produced; the
        # original code crashed here on an unassigned ``aug_mel_t``.
        need_aug_mel = mel_extractor is not None and not os.path.exists(aug_mel_path)
        need_aug_vol = not os.path.exists(aug_vol_path)

        # Skip the (expensive) augmentation entirely when nothing is missing.
        if need_aug_mel or need_aug_vol:
            max_amp = float(torch.max(torch.abs(audio_norm))) + 1e-5
            # Cap the gain so that the shifted peak amplitude stays at or below 1.
            max_shift = min(1, np.log10(1 / max_amp))
            log10_vol_shift = random.uniform(-1, max_shift)
            keyshift = random.uniform(-5, 5)
            aug_audio = audio_norm * (10 ** log10_vol_shift)

            if need_aug_mel:
                aug_mel_t = mel_extractor.extract(aug_audio, sampling_rate, keyshift=keyshift)
                aug_mel = aug_mel_t.squeeze().to('cpu').numpy()
                np.save(aug_mel_path, np.asanyarray((aug_mel, keyshift), dtype=object))
            if need_aug_vol:
                aug_vol = volume_extractor.extract(aug_audio)
                np.save(aug_vol_path, aug_vol.to('cpu').numpy())
2023-07-23 22:12:04 +08:00
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu"):
    """Load a speech encoder in this worker and process every file of a chunk.

    Args:
        file_chunk: Iterable of audio file paths handled by this worker.
        f0p: F0 predictor name forwarded to ``process_one``.
        diff: Forwarded to ``process_one``.
        mel_extractor: Forwarded to ``process_one``.
        device: Device to use when CUDA is not available.

    NOTE(review): whenever CUDA is available, ``device`` is replaced by
    ``cuda:<rank % device_count>``, so an explicitly requested device is not
    honoured in that case -- confirm this is intended.
    """
    logger.info("Loading speech encoder for content...")

    # Worker index of the current process; an empty identity tuple maps to 0.
    identity = mp.current_process()._identity
    rank = identity[0] if identity else 0

    if torch.cuda.is_available():
        # Spread workers round-robin over the visible GPUs.
        gpu_id = rank % torch.cuda.device_count()
        device = torch.device(f"cuda:{gpu_id}")
    logger.info(f"Rank {rank} uses device {device}")

    hmodel = utils.get_speech_encoder(speech_encoder, device=device)
    logger.info(f"Loaded speech encoder for rank {rank}")

    # One progress bar per worker, placed on the row given by its rank.
    progress = tqdm(file_chunk, position=rank)
    for filename in progress:
        process_one(filename, hmodel, f0p, device, diff, mel_extractor)
2023-06-20 03:38:13 +08:00
2023-07-23 22:12:04 +08:00
def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
    """Split ``filenames`` into contiguous chunks and process them in parallel.

    The list is cut into at most ``num_processes`` nearly equal chunks; each
    non-empty chunk is handed to ``process_batch`` in its own worker process.
    Empty chunks are not submitted, so no worker is started just to load a
    model and do nothing (e.g. when there are fewer files than processes).

    Args:
        filenames: Sequence of audio file paths.
        num_processes: Maximum number of worker processes (must be >= 1).
        f0p: F0 predictor name forwarded to ``process_batch``.
        diff: Forwarded to ``process_batch``.
        mel_extractor: Forwarded to ``process_batch``.
        device: Forwarded to ``process_batch``.

    Raises:
        ValueError: If ``num_processes`` is smaller than 1.
        Exception: Any exception raised inside a worker is re-raised here.
    """
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")

    total = len(filenames)
    if total == 0:
        # Nothing to do; avoid spawning workers at all.
        return

    # Integer chunk boundaries: chunk i covers [bounds[i], bounds[i + 1]).
    bounds = [i * total // num_processes for i in range(num_processes + 1)]
    chunks = [
        filenames[start:end]
        for start, end in zip(bounds, bounds[1:])
        if end > start
    ]

    with ProcessPoolExecutor(max_workers=len(chunks)) as executor:
        tasks = [
            executor.submit(process_batch, chunk, f0p, diff, mel_extractor, device=device)
            for chunk in chunks
        ]
        # Waiting on the futures surfaces worker exceptions in the parent.
        for task in tqdm(tasks, position=0):
            task.result()
2023-03-10 18:11:04 +08:00
if __name__ == "__main__":
    # Command-line entry point: parse the options, optionally load the mel
    # extractor, collect the input files and fan the work out to workers.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--device', type=str, default=None)
    parser.add_argument(
        "--in_dir", type=str, default="dataset/44k", help="path to input dir"
    )
    parser.add_argument(
        '--use_diff', action='store_true', help='Whether to use the diffusion model'
    )
    parser.add_argument(
        # The help text previously claimed "default: pm" although the default is rmvpe.
        '--f0_predictor', type=str, default="rmvpe", help='Select F0 predictor, can select crepe,pm,dio,harvest,rmvpe,fcpe|default: rmvpe(note: crepe is original F0 using mean filter)'
    )
    parser.add_argument(
        '--num_processes', type=int, default=1, help='You are advised to set the number of processes to the same as the number of CPU cores'
    )
    args = parser.parse_args()
    f0p = args.f0_predictor
    device = args.device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(speech_encoder)
    logger.info("Using device: " + str(device))
    logger.info("Using SpeechEncoder: " + speech_encoder)
    logger.info("Using extractor: " + f0p)
    logger.info("Using diff Mode: " + str(args.use_diff))

    if args.use_diff:
        print("use_diff")
        print("Loading Mel Extractor...")
        mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device=device)
        print("Loaded Mel Extractor.")
    else:
        mel_extractor = None

    # Input layout: <in_dir>/<sub directory>/<file>.wav
    filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True)  # [:10]
    shuffle(filenames)
    mp.set_start_method("spawn", force=True)

    num_processes = args.num_processes
    if num_processes == 0:
        # 0 selects the CPU count; os.cpu_count() may return None.
        num_processes = os.cpu_count() or 1

    if filenames:
        parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor, device)
    else:
        # Avoid starting workers (and loading models) when there is no input.
        logger.warning(f"No .wav files found under {args.in_dir}/*/ - nothing to do")