In [None]:
!pip install torch accelerate matplotlib numpy tqdm smalldiffusion

In [None]:
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from accelerate import Accelerator
from itertools import pairwise
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from torch.utils.data import DataLoader
from torch.nn.functional import mse_loss
from typing import Optional
from tqdm import tqdm

from smalldiffusion import (
    ScheduleLogLinear, training_loop, samples, Swissroll, TimeInputMLP, Schedule,
    ModelMixin, get_sigma_embeds
)

# Raise matplotlib's cap on the size of embedded animations to an effectively unlimited
# value so the `to_jshtml()` animations below are not cut off.
# NOTE(review): the limit is presumably expressed in MB — confirm against the matplotlib rcParams docs.
matplotlib.rcParams['animation.embed_limit'] = 2**128

# Introduction

For this exercise you will build a diffusion model on a 2D toy dataset from scratch. We will
mostly recreate the implementation in the
[`smalldiffusion` library](https://github.com/yuanchenyang/smalldiffusion/).


# 1. Load data
First we create and load the 2D spiral dataset with 200 points.

In [None]:
# Spiral dataset with 200 points (see the text above); the first two arguments are
# presumably the start and end angles of the spiral — confirm against smalldiffusion.
dataset = Swissroll(np.pi/2, 5*np.pi, 200)
# batch_size (2000) is larger than the 200 points mentioned above, so presumably every
# epoch consists of a single batch holding the whole dataset.
loader  = DataLoader(dataset, batch_size=2000)

Next we define `plot_batch`, which visualizes samples from this dataset.

In [None]:
def plot_batch(batch, ax=None, **kwargs):
    """Scatter-plot a batch of 2D points.

    Args:
        batch: Tensor whose first two columns are used as the x and y
            coordinates. It may live on any device and may require grad; it is
            detached and copied to the CPU before plotting.
        ax: Object with a ``scatter`` method (e.g. a matplotlib ``Axes``).
            When ``None`` the pyplot state machine (``plt``) is used.
        **kwargs: Forwarded to ``scatter``. ``marker`` defaults to ``'.'`` but
            can be overridden by the caller.

    Returns:
        Whatever ``scatter`` returns.
    """
    # detach() so tensors that are part of an autograd graph can be converted too.
    points = batch.detach().cpu().numpy()
    # Compare against None explicitly instead of relying on the truthiness of `ax`.
    target = plt if ax is None else ax
    # setdefault avoids a duplicate-keyword error when the caller passes `marker`.
    kwargs.setdefault('marker', '.')
    return target.scatter(points[:, 0], points[:, 1], **kwargs)

# Visualize the first batch produced by the loader.
plot_batch(next(iter(loader)))

# 2. Define Schedule
Here we define a simple log-linear noise schedule with 200 steps.

In [None]:
# Log-linear noise schedule with N=200 levels between sigma_min and sigma_max.
schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
# Plot sigma against its index in the schedule, using a log-scaled y axis.
plt.plot(schedule.sigmas)
plt.xlabel('$t$')
plt.ylabel('$\\sigma_t$')
plt.yscale('log')

Since our neural network takes the noise level $\sigma_t$ as input, we need to
encode $\sigma_t$ to ensure it has bounded norm (otherwise $\sigma_t$ ranges
from very small to very large values, making training the neural network
ill-conditioned). A simple encoding scheme is to use $[\sin(\log(\sigma)/2), \cos(\log(\sigma)/2)]$.
`get_sigma_embeds` is used later in `TimeInputMLP` to encode $\sigma$ before
passing it into a neural network.

In [None]:
# Embed every sigma of the schedule; transposing the result lets us unpack its two
# components (labelled sin and cos below) into separate curves.
sx, sy = get_sigma_embeds(len(schedule), schedule.sigmas).T
plt.plot(sx, label='$\\sin(\\log(\\sigma_t)/2)$')
plt.plot(sy, label='$\\cos(\\log(\\sigma_t)/2)$')
plt.xlabel('$t$')
plt.legend()
plt.show()

# 3. Define Model
Next we define a simple diffusion model using an MLP. The 4-dimensional input to this MLP
is the (2-dimensional) $\sigma_t$ encoding concatenated with $x$. The MLP has a 2-dimensional output,
the predicted noise $\epsilon$.

In [None]:
# Build the MLP with the given hidden-layer widths and print its layer structure.
model = TimeInputMLP(hidden_dims=(16,128,128,128,128,16))
print(model)

# 4. Train Model

Next we write code to train the model, optimizing the loss function

$$\mathcal{L}(\theta) =
\mathbb{E}[\Vert\epsilon_\theta(x_0 + \sigma \epsilon, \sigma) - \epsilon\Vert^2]$$

**Question 1 (1 point):** Complete the code in the training loop to implement
this loss function. Train for 20000 epochs and plot the loss over epochs.

In [None]:
def moving_average(x, w):
    """Return the sliding-window mean of the 1-D sequence ``x`` with window size ``w``.

    Only windows that lie completely inside ``x`` are kept ('valid'
    convolution), so the result has ``len(x) - w + 1`` entries when
    ``w <= len(x)``.
    """
    window = np.ones(w)
    window_sums = np.convolve(x, window, mode='valid')
    return window_sums / w

def training_loop(loader      : DataLoader,
                  model       : torch.nn.Module,
                  schedule    : Schedule,
                  epochs      : int = 10000,
                  lr          : float = 1e-3):
    """Train ``model`` with AdamW and yield the loss of every batch.

    This is a generator: nothing is trained until it is iterated. For every
    batch it yields ``loss.item()`` and only then runs the backward pass and
    the optimizer step for that batch.

    NOTE(review): this definition shadows the ``training_loop`` imported from
    ``smalldiffusion`` at the top of the notebook.
    NOTE(review): the ``loss = ...`` line below is the Question 1 exercise
    placeholder; it is not valid Python until it has been completed.

    Args:
        loader: Yields batches ``x0`` of training data.
        model: Network to optimize.
        schedule: Noise schedule; ``schedule.sample_batch(x0)`` provides the
            noise levels for a batch.
        epochs: Number of passes over ``loader``.
        lr: Learning rate passed to ``torch.optim.AdamW``.

    Yields:
        The loss of the current batch as a Python number.
    """
    accelerator = Accelerator()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Let accelerate wrap the model, optimizer and loader for the current device setup.
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
    for _ in tqdm(range(epochs)):
        for x0 in loader:
            model.train()
            optimizer.zero_grad()

            # Sample random sigmas with same number of batches as x0
            sigma = schedule.sample_batch(x0)
            # Append trailing singleton dimensions until sigma has as many dimensions as x0.
            while len(sigma.shape) < len(x0.shape):
                sigma = sigma.unsqueeze(-1)

            # Sample noise with same shape as x0
            eps = torch.randn_like(x0)
            loss = # mse_loss(..., ...) ### YOUR CODE HERE ###
            # The loss is reported before the parameters are updated for this batch.
            yield loss.item()
            accelerator.backward(loss)
            optimizer.step()

# Start from a freshly initialized model so that re-running this cell restarts training.
model = TimeInputMLP(hidden_dims=(16,128,128,128,128,16))
# Exhaust the training generator, collecting one loss value per batch.
losses = list(training_loop(loader, model, schedule, epochs=20000, lr=1e-3))
# Smooth the loss curve with a 100-step moving average before plotting.
plt.plot(moving_average(losses, 100))

# 5. Sample from Model

First, we plot the predictions of the trained model $\epsilon_\theta$ for
different values of $\sigma$.  The following code plots the direction of
$\epsilon_\theta(x, \sigma)$ for a fixed $\sigma$ and different values of $x$
from a grid.

**Question 2 (1 point):** Play around with different values of $\sigma$, and describe
qualitatively how the direction of the predicted noise $\epsilon_\theta(x, \sigma)$ changes
with $\sigma$.

In [None]:
def plot_eps_field(model, sigma, ax=None, plot_width=3, mesh_width=21, color='0.5', scale=1,
                   scale_by_sigma=True, **kwargs):
    """Draw the negated noise prediction of ``model`` as arrows on a square grid.

    The grid covers ``[-plot_width/2, plot_width/2]`` in both coordinates with
    ``mesh_width`` points per axis. At each grid point ``x`` an arrow pointing
    along ``-model.predict_eps(x, sigma)`` is drawn.

    Args:
        model: Object with a ``predict_eps(x, sigma)`` method. If it also has
            ``parameters()`` (e.g. a ``torch.nn.Module``), the grid is
            evaluated on the device of its first parameter.
        sigma: Noise level passed to the model (number or tensor).
        ax: Axes to draw on; defaults to the current axes.
        plot_width: Side length of the square region covered by the grid.
        mesh_width: Number of grid points along each axis.
        color: Arrow color.
        scale: Forwarded to ``quiver`` as ``scale``.
        scale_by_sigma: If true, arrows are the predictions multiplied by
            ``sigma``; otherwise every arrow is normalized to unit length.
        **kwargs: Forwarded to ``ax.quiver``.

    Returns:
        Whatever ``ax.quiver`` returns.
    """
    ax = plt.gca() if ax is None else ax
    mesh_x = np.linspace(-plot_width/2, plot_width/2, mesh_width)
    x0s, x1s = np.meshgrid(mesh_x, mesh_x, indexing="ij")
    grid = np.stack((x0s.ravel(), x1s.ravel()), axis=1)

    # Evaluate on the device the model lives on; after training the model may
    # have been moved to an accelerator, while the grid is created on the CPU.
    X = torch.tensor(grid, dtype=torch.float32)
    params = model.parameters() if hasattr(model, 'parameters') else ()
    first_param = next(iter(params), None)
    if first_param is not None:
        X = X.to(first_param.device)
    sigma_in = sigma.to(X.device) if torch.is_tensor(sigma) else sigma
    with torch.no_grad():
        Y = model.predict_eps(X, sigma_in)
    # From here on work in NumPy on the CPU, which is what matplotlib consumes.
    Y = Y.detach().cpu().numpy()

    if scale_by_sigma:
        scaling = sigma.detach().cpu().numpy() if torch.is_tensor(sigma) else sigma
    else:
        norms = np.linalg.norm(Y, axis=1)
        # Leave zero-length predictions as zero arrows instead of dividing by zero.
        scaling = 1 / np.where(norms > 0, norms, 1.0)
    return ax.quiver(grid[:, 0], grid[:, 1], -Y[:, 0]*scaling, -Y[:, 1]*scaling,
                     angles='xy', scale_units='xy', scale=scale, color=color, **kwargs)

# Draw the data, then overlay the unit-length (scale_by_sigma=False) prediction field
# for the first entry of schedule.sigmas.
plot_batch(next(iter(loader)))
plot_eps_field(model, schedule.sigmas[0], plot_width=3, scale_by_sigma=False, scale=4)

Next, we start with a deterministic DDIM sampler to sample from the model,
using the update step:

$$x_{t-1} = x_t - (\sigma_t - \sigma_{t-1})\epsilon_\theta(x_t, \sigma_t)$$

The following code plots the result of sampling 2000 points using 10 sampling steps.

In [None]:
%matplotlib inline
@torch.no_grad()
def samples_ddim(model      : torch.nn.Module,
                 sigmas     : torch.FloatTensor, # Iterable with N+1 values for N sampling steps
                 xt         : Optional[torch.FloatTensor] = None,
                 batchsize  : int = 1):
    """Deterministic DDIM sampler; yields the iterate after every sampling step.

    Each step applies ``xt <- xt - (sig - sig_prev) * model.predict_eps(xt, sig)``
    for consecutive pairs ``(sig, sig_prev)`` of ``sigmas``.

    Args:
        model: Noise-prediction model; must provide ``predict_eps`` and, when
            ``xt`` is not given, ``rand_input``.
        sigmas: Sequence of N+1 noise levels for N sampling steps.
        xt: Optional starting point. When ``None`` the start is
            ``model.rand_input(batchsize) * sigmas[0]``.
        batchsize: Number of samples to draw when ``xt`` is ``None``.

    Yields:
        The current iterate after each of the N steps.
    """
    model.eval()
    accelerator = Accelerator()
    if xt is None:
        xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0]
    else:
        # Put a caller-supplied start on the same device as the random
        # initialization would use, so both entry paths behave alike.
        xt = xt.to(accelerator.device)
    for sig, sig_prev in pairwise(sigmas):
        xt = xt - (sig - sig_prev) * model.predict_eps(xt, sig.to(xt))
        yield xt

sigmas_10 = schedule.sample_sigmas(10) # Subsample 10 steps from training schedule
# Run the sampler to completion: the last yielded batch is the final sample x0, the
# earlier ones are the intermediate iterates.
*xts, x0 = samples_ddim(model, sigmas_10, batchsize=2000)
plot_batch(x0)

We can also plot the sampling trajectory as an animation.

In [None]:
def get_anim(model, loader, sigmas, xts, start_from=0, width=2, quiver_args=None):
    """Animate sampling trajectories on top of the model's prediction field.

    Args:
        model: Model handed to ``plot_eps_field`` for every frame.
        loader: Data loader; its first batch is drawn as a static scatter plot.
        sigmas: Noise levels; frame ``t`` shows the field for ``sigmas[t]``
            (after dropping the first ``start_from`` entries).
        xts: Sequence of tensors, one per sampling step, indexed here as
            (step, point, coordinate) once stacked.
        start_from: Number of leading entries of ``sigmas`` and ``xts`` to skip.
        width: Side length of the square plotting window, centred at the origin.
        quiver_args: Keyword arguments for ``plot_eps_field``; a default set is
            used when this is ``None`` or empty.

    Returns:
        A ``FuncAnimation`` with one frame per remaining entry of ``sigmas``.
    """
    if not quiver_args:
        quiver_args = {
            'mesh_width': 41, 'color': str(0.6), 'scale_by_sigma': False, 'scale': 15,
            'headwidth': 2, 'headlength': 3, 'width': 0.002,
        }
    # Stack all steps into one array, then drop the skipped prefix.
    path = np.array([step.cpu().numpy() for step in xts])[start_from:]
    sigmas = sigmas[start_from:]

    half_width = width / 2
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)
    plot_batch(next(iter(loader)), s=8, ax=ax)

    # One red line per trajectory, initialised with its first two positions.
    lines = []
    for point in range(path.shape[1]):
        lines.append(ax.plot(*path[:2, point, :].T, color='red')[0])

    drawn = []  # artists created for the previous frame

    def get_quiver(t=0):
        # Remove the previous frame's arrows before drawing new ones.
        while drawn:
            drawn.pop().remove()
        # Extend every trajectory up to (and including) position t+1.
        for point, line in enumerate(lines):
            xs, ys = path[:t+2, point, :].T
            line.set_xdata(xs)
            line.set_ydata(ys)
        quiver = plot_eps_field(model, sigmas[t], ax=ax, plot_width=width, **quiver_args)
        drawn.append(quiver)
        return (quiver, *lines)

    return FuncAnimation(fig, get_quiver, frames=len(sigmas), interval=200, blit=True)

# Builds N starting points that alternate between an inner and an outer circle
def get_xT(N=20, outer_radii=1.3, inner_radii=0.6):
    """Return an (N, 2) tensor of points alternating between two radii.

    The angles are N evenly spaced values from 0 to 2*pi (both ends included).
    Points with an even index lie at ``inner_radii``, points with an odd index
    at ``outer_radii``; each point is ``radius * (sin(angle), cos(angle))``.
    """
    angles = torch.linspace(0, 2*np.pi, N)
    directions = torch.stack((torch.sin(angles), torch.cos(angles)), dim=1)
    radii = torch.tensor([outer_radii if k % 2 else inner_radii for k in range(N)])
    return directions * radii[:, None]

# Use a 50-step subsampled schedule for a smoother trajectory.
sigmas_50 = schedule.sample_sigmas(50)
# Scale the two-circle starting points by the first sigma of the sampling schedule.
xT = get_xT() * sigmas_50[0]
# Collect every intermediate iterate so the whole trajectory can be animated.
xts = list(samples_ddim(model, sigmas_50, xt=xT))
# Skip the first 10 steps of the trajectory in the animation.
ani = get_anim(model, loader, sigmas_50, xts, start_from=10)
HTML(ani.to_jshtml())

**Question 3 (1 point):** Implement an accelerated deterministic sampler (with parameter $\gamma > 1$) that uses the following update step:

$$
\begin{align*}
\bar{\epsilon_t} &= \gamma \epsilon_\theta(x_t, \sigma_t) + (1-\gamma)\epsilon_\theta(x_{t+1}, \sigma_{t+1}) \\
x_{t-1} &= x_t - (\sigma_t - \sigma_{t-1})\bar{\epsilon_t}
\end{align*}
$$

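At the very first sampling step there is no previous prediction $\epsilon_\theta(x_{t+1}, \sigma_{t+1})$ yet; use $\bar{\epsilon_t} = \epsilon_\theta(x_t, \sigma_t)$ for that step. Similar to above, plot the result of 10-step sampling, as well as the animation of sampling trajectories, for different values of $\gamma$. Describe qualitatively the difference in final samples as well as sampling trajectories when varying $\gamma$.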


In [None]:
@torch.no_grad()
def samples(model      : torch.nn.Module,
            sigmas     : torch.FloatTensor, # Iterable with N+1 values for N sampling steps
            xt         : Optional[torch.FloatTensor] = None,
            gam        : float = 1.,
            batchsize  : int = 1):
    """Template for the accelerated deterministic sampler of Question 3.

    As written, the loop body is an unimplemented placeholder: ``gam`` is not
    used yet and every iteration yields ``xt`` unchanged.

    NOTE(review): this definition shadows the ``samples`` imported from
    ``smalldiffusion`` at the top of the notebook.

    Args:
        model: Noise-prediction model; must provide ``rand_input`` when ``xt``
            is not given.
        sigmas: Sequence of N+1 noise levels for N sampling steps.
        xt: Optional starting point. When ``None`` the start is
            ``model.rand_input(batchsize) * sigmas[0]``.
        gam: The parameter gamma of the accelerated update step.
        batchsize: Number of samples to draw when ``xt`` is ``None``.

    Yields:
        The current iterate once per consecutive pair of ``sigmas``.
    """
    model.eval()
    accelerator = Accelerator()
    # Random start scaled by the first sigma, unless the caller supplied a starting point.
    xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0] if xt is None else xt
    for i, (sig, sig_prev) in enumerate(pairwise(sigmas)):
        ### YOUR CODE HERE ###
        yield xt

# 10-step sampling with gamma = 2: the last yielded batch is the final sample x0.
*xts, x0 = samples(model, sigmas_10, batchsize=2000, gam=2.0)
plot_batch(x0)

In [None]:
# Reuse the starting points xT and the 50-step schedule from the DDIM animation above so
# the trajectories can be compared directly.
xts = list(samples(model, sigmas_50, xt=xT, gam=2.0))
ani = get_anim(model, loader, sigmas_50, xts, start_from=10)
HTML(ani.to_jshtml())

# 6. Optional questions (0 points)

For students who wish to learn more about diffusion models, here are some optional questions designed to further your understanding of diffusion models.


## 6.1 Generalization

The swissroll dataset consists of 200 discrete datapoints. A diffusion model trained on these points generalizes in the sense that the samples converge to the underlying spiral manifold, not to the discrete points used in training.

Reduce the number of points in the spiral (i.e. change `dataset = Swissroll(np.pi/2, 5*np.pi, 200)`) and determine when the diffusion model fails to generalize to the spiral manifold. Play with the training/noise schedule parameters and determine which parameters enable better generalization.
   
## 6.2 Experimenting with samplers
Implement DDPM sampling, where noise is added between each diffusion step. Figure out how to vary the noise level to interpolate between DDPM and DDIM. Plot sampling trajectories showing the effect of added noise.

## 6.3 Flow matching
Implement flow matching training and sampling, where $t$ ranges from 0 to 1 and the training loss is given by:

$$\mathcal{L}(\theta) =
\mathbb{E}[\Vert v_\theta((1-t) x_0 + t \epsilon, t) - (\epsilon-x_0)\Vert^2]$$

For sampling, use the following update step for $0 \le t' < t \le 1$:

$$
\begin{align*}
u_1 &\sim N(0, I) \\
u_{t'} &= u_t + (t'-t)v_\theta(u_t, t)
\end{align*}
$$

## 6.4 Larger datasets
Practice training diffusion models on datasets of different sizes and modalities. To get started, follow the [examples](https://github.com/yuanchenyang/smalldiffusion/tree/main/examples) in the `smalldiffusion` library.

