In the previous post we saw the VP diffusion model, DDPM. Another popular type of diffusion model is the Variance Exploding (VE) SDE:
\[d\boldsymbol{x}=\sqrt{\left[\frac{d}{dt}\sigma^2(t)\right]}d\boldsymbol{w}\]
with the following reverse process:
\[d\boldsymbol{x}=\left[-\frac{d}{dt}\left(\sigma^2(t)\right)\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})\right]dt+\sqrt{\left[\frac{d}{dt}\sigma^2(t)\right]}d\bar{\boldsymbol{w}}\]
A common choice of the variance function is \(\sigma(t)=\sigma_{\text{min}}\left(\frac{\sigma_{\text{max}}}{\sigma_{\text{min}}}\right)^t\), with \(\sigma_{\text{min}}=0.01\) and \(\sigma_{\text{max}}=50\). Since \(\frac{d\sigma(t)}{dt}=\sigma(t)\left(\log(\sigma_{\text{max}})-\log(\sigma_{\text{min}})\right)\), we have that:
\[\frac{d\sigma^2(t)}{dt}=2\sigma(t)\frac{d\sigma(t)}{dt}=2\sigma^2(t)\left(\log(\sigma_{\text{max}})-\log(\sigma_{\text{min}})\right)\]
Let’s first understand how \(\sigma(t)\) and the diffusion coefficient \(\sqrt{\frac{d\sigma^2(t)}{dt}}\) change with time.

```python
import numpy as np
import matplotlib.pyplot as plt

num_timesteps = 1000
dt = 1 / num_timesteps
sigma_min = 0.01
sigma_max = 50
# Geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min)**t on t in [0, 1]
sigmas = sigma_min * (sigma_max / sigma_min) ** np.linspace(0, 1, 1000)
# d(sigma^2)/dt = 2 * sigma(t)^2 * (log(sigma_max) - log(sigma_min))
d2sigma_dt = 2 * sigmas ** 2 * (np.log(sigma_max) - np.log(sigma_min))

# Create a figure and two subplots side by side
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))

# First subplot: sigma(t)
ax1 = axes[0]
ax1.plot(np.linspace(0, 1, 1000), sigmas)
ax1.set_title(r'Plot of $\sigma(t)$ with respect to time', fontsize=17)
ax1.set_xlabel(r'$t$', fontsize=17)
ax1.set_ylabel(r'$\sigma(t)$', fontsize=17)

# Second subplot: the diffusion coefficient sqrt(d(sigma^2)/dt)
ax2 = axes[1]
ax2.plot(np.linspace(0, 1, 1000), np.sqrt(d2sigma_dt))
ax2.set_title(r'Plot of $\sqrt{\frac{d\sigma^2(t)}{dt}}$ with respect to time', fontsize=17)
ax2.set_xlabel(r'$t$', fontsize=17)
ax2.set_ylabel(r'$\sqrt{\frac{d\sigma^2(t)}{dt}}$', fontsize=17)

# Adjust layout to prevent overlap
plt.tight_layout()
# Display the plots
plt.show()
```
Now let’s consider \(p_0(x)=\mathcal{N}(5,4)\) and propagate data samples \(x(0)\) to \(x(1)\). The distribution of \(x(1)\), \(p_1(x)\), is a noise distribution that we will later use to generate data samples by propagating in reverse time.
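Before the code: the forward SDE is simulated with the Euler–Maruyama scheme, i.e. each step draws fresh Gaussian noise scaled by the diffusion coefficient,
\[x_{t+\Delta t}=x_t+\sqrt{\frac{d\sigma^2(t)}{dt}\,\Delta t}\;\varepsilon,\qquad \varepsilon\sim\mathcal{N}(0,1),\]
with \(\Delta t=1/1000\). This is exactly the `forward_term[t] * torch.randn(B)` update in `forward_diff` below.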
```python
import torch
import numpy as np
import matplotlib.pyplot as plt

class VE_SDE_DIFF:
    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps
        self.dt = 1 / num_timesteps
        self.sigma_min = torch.tensor(0.01)
        self.sigma_max = torch.tensor(50)
        # Geometric noise schedule sigma(t) on a uniform grid over [0, 1]
        self.sigmas = self.sigma_min * (self.sigma_max / self.sigma_min) ** torch.linspace(0, 1, num_timesteps)
        self.d2sigma_dt = 2 * self.sigmas ** 2 * (torch.log(self.sigma_max) - torch.log(self.sigma_min))
        # Standard deviation of one Euler-Maruyama step: sqrt(d(sigma^2)/dt * dt)
        self.forward_term = torch.sqrt(self.d2sigma_dt * self.dt)

    def forward_diff(self, x0):
        B = x0.shape[0]
        x = torch.zeros((self.num_timesteps + 1, B))
        x[0] = x0
        for t in range(self.num_timesteps):
            x[t + 1] = x[t] + self.forward_term[t] * torch.randn(B)
        return x

ve_sde = VE_SDE_DIFF(1000)
x0 = torch.randn(100000) * 2 + 5   # samples from p_0(x) = N(5, 4)
x = ve_sde.forward_diff(x0).numpy()
```
```python
plot_sample = x[1000]

print(f"Estimated mean is {np.mean(plot_sample)}")
print(f"Estimated std is {np.std(plot_sample)}")

plt.hist(plot_sample, density=True)

# Define parameters for the normal distribution
mean = 0.0
std = 50.0

# Create a range of x values
xmin, xmax = plt.xlim()
x = torch.linspace(xmin, xmax, 200).numpy()

# Compute the normal PDF: (1/(σ√(2π))) exp(-(x-μ)²/(2σ²))
pdf = (1/(std * np.sqrt(2*np.pi))) * np.exp(-0.5*((x - mean)/std)**2)

# Plot the PDF line
plt.plot(x, pdf, 'r', linewidth=2)

plt.show()
```
```
Estimated mean is 5.073878288269043
Estimated std is 50.15424346923828
```
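As a quick consistency check (an addition, using the constants from the code above): the forward process should end at \(p_1(x)=\mathcal{N}(5,\,4+\sigma_{\text{max}}^2-\sigma_{\text{min}}^2)\), so the analytic standard deviation is

```python
import numpy as np

# std of p_1(x) = sqrt(data variance + sigma_max^2 - sigma_min^2)
print(np.sqrt(4 + 50.0**2 - 0.01**2))  # ~50.04, close to the estimate above
```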
Note that for a large variance, small differences in the mean are negligible. Thus, \(\mathcal{N}(5,50^2)\approx\mathcal{N}(0,50^2)\). Consider the case of images, where pixels are normalised between -1 and 1; there the approximation is even better.
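One way to quantify "negligible" (a quick check, not from the original post): for two Gaussians with equal variance, \(\mathrm{KL}\left(\mathcal{N}(\mu_1,\sigma^2)\,\|\,\mathcal{N}(\mu_2,\sigma^2)\right)=\frac{(\mu_1-\mu_2)^2}{2\sigma^2}\), which here is tiny:

```python
import numpy as np

# KL(N(5, 50^2) || N(0, 50^2)) = (5 - 0)^2 / (2 * 50^2)
print(5.0**2 / (2 * 50.0**2))  # 0.005 nats -- essentially indistinguishable
```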
```python
x = np.linspace(-200, 200, 200)
# Compute the normal PDF: (1/(σ√(2π))) exp(-(x-μ)²/(2σ²))
pdf_1 = (1/(std * np.sqrt(2*np.pi))) * np.exp(-0.5*((x - 0.0)/std)**2)
pdf_2 = (1/(std * np.sqrt(2*np.pi))) * np.exp(-0.5*((x - 5.0)/std)**2)

# Plot the PDF lines
plt.plot(x, pdf_1, 'r', linewidth=2, label=r'$\mathcal{N}(0,50^2)$')
plt.plot(x, pdf_2, 'g', linewidth=2, label=r'$\mathcal{N}(5,50^2)$')
plt.legend()
plt.show()
```
Now we will use the exact score to perform the reverse diffusion. We assume that \(p_1(x)=\mathcal{N}(5,50^2)\) (since we know the mean of \(p_0(x)\), we can use it instead of 0). Then,
\[p_t(x)=\mathcal{N}(5, 4 + \sigma_{t}^2 - \sigma_{\text{min}}^2)\]
and
\[\nabla_x\log p_t(x)=\frac{5-x}{4 + \sigma_{t}^2 - \sigma_{\text{min}}^2}\]
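Both formulas follow from two standard Gaussian facts, spelled out here since the post uses them implicitly: the VE perturbation kernel adds variance \(\sigma_{t}^2-\sigma_{\text{min}}^2\) to \(x(0)\), and convolving Gaussians adds their variances,
\[p_t(x)=\int \mathcal{N}(x_0;5,4)\,\mathcal{N}(x;x_0,\sigma_{t}^2-\sigma_{\text{min}}^2)\,dx_0=\mathcal{N}(5,\,4+\sigma_{t}^2-\sigma_{\text{min}}^2);\]
the score of any Gaussian is then linear in \(x\):
\[\nabla_x\log\mathcal{N}(x;\mu,s^2)=\frac{\mu-x}{s^2}\]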
We use the exact score to reverse the diffusion process.
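Concretely, `reverse_diff` below integrates the reverse SDE with Euler–Maruyama, stepping backwards from \(t\) to \(t-\Delta t\):
\[x_{t-\Delta t}=x_t+\frac{d\sigma^2(t)}{dt}\,\nabla_x\log p_t(x_t)\,\Delta t+\sqrt{\frac{d\sigma^2(t)}{dt}\,\Delta t}\;\varepsilon,\qquad\varepsilon\sim\mathcal{N}(0,1)\]
The drift sign flips relative to the reverse SDE because \(dt\) is negative when integrating backwards in time.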
```python
import torch
import numpy as np
import matplotlib.pyplot as plt

class VE_SDE_DIFF:
    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps
        self.dt = 1 / num_timesteps
        self.sigma_min = torch.tensor(0.01)
        self.sigma_min2 = self.sigma_min ** 2
        self.sigma_max = torch.tensor(50)
        self.sigmas = self.sigma_min * (self.sigma_max / self.sigma_min) ** torch.linspace(0, 1, num_timesteps)
        self.sigmas2 = self.sigmas ** 2
        self.d2sigma_dt = 2 * self.sigmas ** 2 * (torch.log(self.sigma_max) - torch.log(self.sigma_min))
        self.forward_term = torch.sqrt(self.d2sigma_dt * self.dt)

    def forward_diff(self, x0):
        B = x0.shape[0]
        x = torch.zeros((self.num_timesteps + 1, B))
        x[0] = x0
        for t in range(self.num_timesteps):
            x[t + 1] = x[t] + self.forward_term[t] * torch.randn(B)
        return x

    def reverse_diff(self, num_samples):
        samples = torch.zeros((self.num_timesteps + 1, num_samples))
        # Start from p_1(x) = N(5, 50^2)
        samples[self.num_timesteps] = torch.randn(num_samples) * 50 + 5
        for t in range(self.num_timesteps - 1, -1, -1):
            drift = self.d2sigma_dt[t] * self.score(samples[t + 1], t + 1) * self.dt
            diff = self.forward_term[t] * torch.randn(num_samples)
            samples[t] = samples[t + 1] + drift + diff
        return samples

    def score(self, x, t):
        # Exact score of p_t(x) = N(5, 4 + sigma_t^2 - sigma_min^2)
        numerator = 5 - x
        denominator = 4 + self.sigmas2[t - 1] - self.sigma_min2
        return numerator / denominator

ve_sde = VE_SDE_DIFF(1000)
samples = ve_sde.reverse_diff(10000).numpy()
```
```python
plot_sample = samples[0]

print(f"Estimated mean is {np.mean(plot_sample)}")
print(f"Estimated std is {np.std(plot_sample)}")

plt.hist(plot_sample, density=True)

# Define parameters for the normal distribution
mean = 5.0
std = 2.0

# Create a range of x values
xmin, xmax = plt.xlim()
x = torch.linspace(xmin, xmax, 200).numpy()

# Compute the normal PDF: (1/(σ√(2π))) exp(-(x-μ)²/(2σ²))
pdf = (1/(std * np.sqrt(2*np.pi))) * np.exp(-0.5*((x - mean)/std)**2)

# Plot the PDF line
plt.plot(x, pdf, 'r', linewidth=2)

plt.show()
```
```
Estimated mean is 5.011462688446045
Estimated std is 1.9858821630477905
```
If you want a sanity check, try modifying the code above so that we sample from \(\mathcal{N}(0,50^2)\) at \(t=1\), as sketched below. In a real-world example we will not know the mean of \(p_0(x)\) (which is also the mean of \(p_1(x)\)), so we use \(0\) instead. Note that the score function should not be modified (in a real-world example the score is estimated with a neural network)!
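A minimal version of that check might look like this (the subclass name `VE_SDE_DIFF_ZERO_INIT` is mine, added for illustration; it reuses the `VE_SDE_DIFF` class above):

```python
class VE_SDE_DIFF_ZERO_INIT(VE_SDE_DIFF):
    def reverse_diff(self, num_samples):
        samples = torch.zeros((self.num_timesteps + 1, num_samples))
        # Initialise from N(0, 50^2) instead of N(5, 50^2); the score is untouched.
        samples[self.num_timesteps] = torch.randn(num_samples) * 50
        for t in range(self.num_timesteps - 1, -1, -1):
            drift = self.d2sigma_dt[t] * self.score(samples[t + 1], t + 1) * self.dt
            diff = self.forward_term[t] * torch.randn(num_samples)
            samples[t] = samples[t + 1] + drift + diff
        return samples

zero_init = VE_SDE_DIFF_ZERO_INIT(1000).reverse_diff(10000).numpy()
# The mean and std should still come out close to 5 and 2
print(np.mean(zero_init[0]), np.std(zero_init[0]))
```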
You can also check what happens if you ignore \(\sigma_{\text{min}}^2\) in the denominator of the score function. Spoiler alert: it is so small that it makes no difference.
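For a sense of scale (a quick check using the constants from the code above):

```python
# sigma_min^2 = 1e-4 is dwarfed by the data variance of 4 (the smallest
# denominator, at t = 0), so dropping it barely changes the score.
print(0.01**2 / 4)  # 2.5e-05 relative change in the denominator, at worst
```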