diff --git a/utils.py b/utils.py index b9fe21a..62ca6ee 100644 --- a/utils.py +++ b/utils.py @@ -7,12 +7,15 @@ import re import subprocess import sys +import traceback import faiss +from sklearn.cluster import MiniBatchKMeans import librosa import numpy as np import torch from scipy.io.wavfile import read from torch.nn import functional as F +from multiprocessing import cpu_count MATPLOTLIB_FLAG = False @@ -447,6 +450,7 @@ def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出 return data2 def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI + n_cpu = cpu_count() print("The feature index is constructing.") exp_dir = os.path.join(root_dir,spk_name) listdir_res = [] @@ -463,6 +467,25 @@ def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github. big_npy_idx = np.arange(big_npy.shape[0]) np.random.shuffle(big_npy_idx) big_npy = big_npy[big_npy_idx] + if big_npy.shape[0] > 2e5: + # if(1): + info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0] + print(info) + try: + big_npy = ( + MiniBatchKMeans( + n_clusters=10000, + verbose=True, + batch_size=256 * n_cpu, + compute_labels=False, + init="random", + ) + .fit(big_npy) + .cluster_centers_ + ) + except: + info = traceback.format_exc() + print(info) n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39) index = faiss.index_factory(big_npy.shape[1] , "IVF%s,Flat" % n_ivf) index_ivf = faiss.extract_index_ivf(index) #