import math
import torch


class NoiseScheduleVP:
    def __init__(
        self,
        schedule='discrete',
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.,
        dtype=torch.float32,
    ):
""" Create a wrapper class for the forward SDE (VP type).
* * *
Update : We support discrete - time diffusion models by implementing a picewise linear interpolation for log_alpha_t .
We recommend to use schedule = ' discrete ' for the discrete - time diffusion models , especially for high - resolution images .
* * *
The forward SDE ensures that the condition distribution q_ { t | 0 } ( x_t | x_0 ) = N ( alpha_t * x_0 , sigma_t ^ 2 * I ) .
We further define lambda_t = log ( alpha_t ) - log ( sigma_t ) , which is the half - logSNR ( described in the DPM - Solver paper ) .
Therefore , we implement the functions for computing alpha_t , sigma_t and lambda_t . For t in [ 0 , T ] , we have :
log_alpha_t = self . marginal_log_mean_coeff ( t )
sigma_t = self . marginal_std ( t )
lambda_t = self . marginal_lambda ( t )
Moreover , as lambda ( t ) is an invertible function , we also support its inverse function :
t = self . inverse_lambda ( lambda_t )
== == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == =
We support both discrete - time DPMs ( trained on n = 0 , 1 , . . . , N - 1 ) and continuous - time DPMs ( trained on t in [ t_0 , T ] ) .
1. For discrete - time DPMs :
For discrete - time DPMs trained on n = 0 , 1 , . . . , N - 1 , we convert the discrete steps to continuous time steps by :
t_i = ( i + 1 ) / N
e . g . for N = 1000 , we have t_0 = 1e-3 and T = t_ { N - 1 } = 1.
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3 .
Args :
betas : A ` torch . Tensor ` . The beta array for the discrete - time DPM . ( See the original DDPM paper for details )
alphas_cumprod : A ` torch . Tensor ` . The cumprod alphas for the discrete - time DPM . ( See the original DDPM paper for details )
Note that we always have alphas_cumprod = cumprod ( 1 - betas ) . Therefore , we only need to set one of ` betas ` and ` alphas_cumprod ` .
* * Important * * : Please pay special attention for the args for ` alphas_cumprod ` :
The ` alphas_cumprod ` is the \hat { alpha_n } arrays in the notations of DDPM . Specifically , DDPMs assume that
q_ { t_n | 0 } ( x_ { t_n } | x_0 ) = N ( \sqrt { \hat { alpha_n } } * x_0 , ( 1 - \hat { alpha_n } ) * I ) .
Therefore , the notation \hat { alpha_n } is different from the notation alpha_t in DPM - Solver . In fact , we have
alpha_ { t_n } = \sqrt { \hat { alpha_n } } ,
and
log ( alpha_ { t_n } ) = 0.5 * log ( \hat { alpha_n } ) .
2. For continuous - time DPMs :
We support two types of VPSDEs : linear ( DDPM ) and cosine ( improved - DDPM ) . The hyperparameters for the noise
schedule are the default settings in DDPM and improved - DDPM :
Args :
beta_min : A ` float ` number . The smallest beta for the linear schedule .
beta_max : A ` float ` number . The largest beta for the linear schedule .
cosine_s : A ` float ` number . The hyperparameter in the cosine schedule .
cosine_beta_max : A ` float ` number . The hyperparameter in the cosine schedule .
T : A ` float ` number . The ending time of the forward process .
== == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == =
Args :
schedule : A ` str ` . The noise schedule of the forward SDE . ' discrete ' for discrete - time DPMs ,
' linear ' or ' cosine ' for continuous - time DPMs .
Returns :
A wrapper object of the forward SDE ( VP type ) .
== == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == =
Example :
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
>> > ns = NoiseScheduleVP ( ' discrete ' , betas = betas )
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
>> > ns = NoiseScheduleVP ( ' discrete ' , alphas_cumprod = alphas_cumprod )
# For continuous-time DPMs (VPSDE), linear schedule:
>> > ns = NoiseScheduleVP ( ' linear ' , continuous_beta_0 = 0.1 , continuous_beta_1 = 20. )
"""
        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
            self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype)
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            def log_alpha_fn(s):
                return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0 ** 2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            def t_fn(log_alpha_t):
                return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2.0 * (1.0 + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
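

# Illustrative sketch (not part of the original code): how the marginal quantities of
# NoiseScheduleVP relate to each other. The `betas` array below is a made-up
# DDPM-style linear schedule, used only for this example.
#
#   betas = torch.linspace(1e-4, 2e-2, 1000)
#   ns = NoiseScheduleVP('discrete', betas=betas)
#   t = torch.tensor([0.5])
#   alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
#   # lambda_t is the half-logSNR: log(alpha_t) - log(sigma_t)
#   lambda_t = ns.marginal_lambda(t)
#   # inverse_lambda maps the half-logSNR back to (approximately) the same time
#   t_recovered = ns.inverse_lambda(lambda_t)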


def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs={},
):
    """Create a wrapper function for the noise prediction model.
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - alpha_t * output) / sigma_t
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return alpha_t * output + sigma_t * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -sigma_t * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * sigma_t * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    assert model_type in ["noise", "x_start", "v"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
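

# Illustrative sketch (not part of the original code): wrapping a hypothetical
# epsilon-prediction UNet `unet(x, t, cond)` with classifier-free guidance.
# `unet`, `cond` and `uncond` are placeholders, not defined in this file.
#
#   model_fn = model_wrapper(
#       unet, ns, model_type="noise",
#       guidance_type="classifier-free",
#       condition=cond, unconditional_condition=uncond,
#       guidance_scale=7.5,
#   )
#   # model_fn(x, t_continuous) now returns the guided noise prediction, with
#   # t_continuous in (0, 1] regardless of the model's own time discretization.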


class UniPC:
    def __init__(
        self,
        model_fn,
        noise_schedule,
        algorithm_type="data_prediction",
        correcting_x0_fn=None,
        correcting_xt_fn=None,
        thresholding_max_val=1.,
        dynamic_thresholding_ratio=0.995,
        variant='bh1'
    ):
        """Construct a UniPC.

        We support both data_prediction and noise_prediction.
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["data_prediction", "noise_prediction"]

        if correcting_x0_fn == "dynamic_thresholding":
            self.correcting_x0_fn = self.dynamic_thresholding_fn
        else:
            self.correcting_x0_fn = correcting_x0_fn

        self.correcting_xt_fn = correcting_xt_fn
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.thresholding_max_val = thresholding_max_val
        self.variant = variant
        self.predict_x0 = algorithm_type == "data_prediction"

    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0
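
    # Illustrative note (not part of the original code): dynamic thresholding clamps each
    # predicted x0 to a per-sample range derived from its own value distribution. For example,
    # with dynamic_thresholding_ratio=0.995 and thresholding_max_val=1., a sample whose
    # 99.5th percentile of |x0| is 1.8 is clamped to [-1.8, 1.8] and rescaled by 1.8 (back into
    # [-1, 1]), while a sample whose percentile is below 1. is only clamped to [-1, 1].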

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction of the model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction of the model (with the corrector).
        """
        noise = self.noise_prediction_fn(x, t)
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0)
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
            t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
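
    # Illustrative sketch (not part of the original code): the three skip types discretize
    # [t_0, t_T] differently. For t_T = 1., t_0 = 1e-3 and N = 4, roughly:
    #   'time_uniform'   -> 1.0, ~0.75, ~0.50, ~0.25, 0.001   (uniform in t)
    #   'time_quadratic' -> uniform in sqrt(t), then squared (denser near t_0)
    #   'logSNR'         -> uniform in the half-logSNR lambda_t, mapped back through inverse_lambda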

    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3, ] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3, ] * (K - 1) + [1]
            else:
                orders = [3, ] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2, ] * K
            else:
                K = steps // 2 + 1
                orders = [2, ] * (K - 1) + [1]
        elif order == 1:
            K = steps
            orders = [1, ] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)]
        return timesteps_outer, orders
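
    # Worked example (not part of the original code): for steps=8 and order=3 we get
    # K = 8 // 3 + 1 = 3 and steps % 3 == 2, so orders = [3, 3, 2], i.e. 8 model
    # evaluations split across 3 outer steps of decreasing order.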

    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
        if len(t.shape) == 0:
            t = t.view(-1)
        if 'bh' in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
        else:
            assert self.variant == 'vary_coeff'
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        # print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_t = ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        K = len(rks)
        # build C matrix
        C = []

        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            C_inv_p = torch.linalg.inv(C[:-1, :-1])
            A_p = C_inv_p

        if use_corrector:
            # print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        else:
            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
            x_t_ = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * h_phi_1) * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        return x_t, model_t

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.cat(b)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            # print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )

            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - alpha_t * B_h * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = (
                torch.exp(log_alpha_t - log_alpha_prev_0) * x
                - sigma_t * h_phi_1 * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - sigma_t * B_h * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t

    def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
               method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False,
               ):
        """
        Compute the sample at time `t_end` by UniPC, given the initial `x` at time `t_start`.
        """
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
        if return_intermediate:
            assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
        if self.correcting_xt_fn is not None:
            assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
        device = x.device
        intermediates = []
        with torch.no_grad():
            if method == 'multistep':
                assert steps >= order
                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
                assert timesteps.shape[0] - 1 == steps
                # Init the initial values.
                step = 0
                t = timesteps[step]
                t_prev_list = [t]
                model_prev_list = [self.model_fn(x, t)]
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)

                # Init the first `order` values by lower order multistep UniPC.
                for step in range(1, order):
                    t = timesteps[step]
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True)
                    if model_x is None:
                        model_x = self.model_fn(x, t)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    t_prev_list.append(t)
                    model_prev_list.append(model_x)

                # Compute the remaining values by `order`-th order multistep DPM-Solver.
                for step in range(order, steps + 1):
                    t = timesteps[step]
                    if lower_order_final:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    if step == steps:
                        # print('do not run corrector at the last step')
                        use_corrector = False
                    else:
                        use_corrector = True
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        if model_x is None:
                            model_x = self.model_fn(x, t)
                        model_prev_list[-1] = model_x
            else:
                raise ValueError("Got wrong method {}".format(method))

            if denoise_to_zero:
                t = torch.ones((1,)).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step + 1)
                if return_intermediate:
                    intermediates.append(x)
        if return_intermediate:
            return x, intermediates
        else:
            return x
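

# Illustrative end-to-end sketch (not part of the original code): putting the pieces together.
# `unet`, `betas` and `x_T` are placeholders for a trained epsilon-prediction model, its
# training beta schedule, and a batch of Gaussian noise, respectively.
#
#   ns = NoiseScheduleVP('discrete', betas=betas)
#   model_fn = model_wrapper(unet, ns, model_type="noise", guidance_type="uncond")
#   uni_pc = UniPC(model_fn, ns, algorithm_type="data_prediction", variant='bh1')
#   x_0 = uni_pc.sample(x_T, steps=10, order=2, skip_type='time_uniform', method='multistep')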
#############################################################
# other utility functions
#############################################################

def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x (for x beyond the bounds of xp, we use the outermost points of xp to define the linear function).

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand
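

# Illustrative sketch (not part of the original code): a tiny numeric check of interpolate_fn.
# With keypoints (0, 0) and (1, 2) and a query at x = 0.5, linear interpolation gives y = 1:
#
#   xp = torch.tensor([[0., 1.]])       # [C=1, K=2]
#   yp = torch.tensor([[0., 2.]])       # [C=1, K=2]
#   x = torch.tensor([[0.5]])           # [N=1, C=1]
#   y = interpolate_fn(x, xp, yp)       # tensor([[1.]])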

def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
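

# Illustrative sketch (not part of the original code): expand_dims appends trailing singleton
# dimensions so a per-sample scalar can broadcast against image-shaped tensors.
#
#   v = torch.tensor([1., 2.])          # shape [2]
#   expand_dims(v, 4).shape             # torch.Size([2, 1, 1, 1])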