mirror of
https://github.com/svc-develop-team/so-vits-svc.git
synced 2025-01-08 11:57:43 +08:00
Accelerate up random slice segments
This commit is contained in:
parent
730930d337
commit
0ee0b0899e
@ -65,20 +65,19 @@ def rand_gumbel_like(x):
|
||||
|
||||
|
||||
def slice_segments(x, ids_str, segment_size=4):
|
||||
ret = torch.zeros_like(x[:, :, :segment_size])
|
||||
for i in range(x.size(0)):
|
||||
idx_str = ids_str[i]
|
||||
idx_end = idx_str + segment_size
|
||||
ret[i] = x[i, :, idx_str:idx_end]
|
||||
return ret
|
||||
# Slice segments
|
||||
gather_indices = ids_str[:, None, None] + torch.arange(
|
||||
segment_size, device=x.device
|
||||
)
|
||||
return torch.gather(x, 2, gather_indices)
|
||||
|
||||
|
||||
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||
b, d, t = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = t
|
||||
ids_str_max = x_lengths - segment_size + 1
|
||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
|
||||
ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ret = slice_segments(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user