
noiseschedule

Time-dependent noise schedules for DDPM diffusion models.

This module defines an abstract base class for noise schedules and several concrete implementations that map a normalized time \(t \in [0,1]\) to the noise parameters \(\beta(t)\), \(\sigma(t)\), and \(\gamma(t)\):

  • QuadraticVarianceNoiseSchedule: \(\sigma(t)\propto(\sqrt{\sigma_{\min}/\sigma_{\max}}+t)^2\).
  • LinearVarianceNoiseSchedule: \(\sigma(t)\propto(\sigma_{\min}/\sigma_{\max}+t)\).
  • ExponentialVarianceNoiseSchedule: \(\sigma(t)=\sigma_{\min}^{1-t}\,\sigma_{\max}^t\).
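
To compare the three formulas side by side, the sketch below re-implements them as standalone NumPy functions. The names `sigma_quadratic`, `sigma_linear`, and `sigma_exponential` are hypothetical and not part of the module, and all three use the bounds sigma_min=0.05, sigma_max=0.5 for comparability. All three reach \(\sigma_{\max}\) at \(t=1\), but only the exponential form starts exactly at \(\sigma_{\min}\).

```python
import numpy as np

# Hypothetical standalone re-implementations of the three sigma(t) formulas.
def sigma_quadratic(t, sigma_min=0.05, sigma_max=0.5):
    a = np.sqrt(sigma_min / sigma_max)
    return (a + t) ** 2 / (a + 1) ** 2 * sigma_max

def sigma_linear(t, sigma_min=0.05, sigma_max=0.5):
    r = sigma_min / sigma_max
    return (r + t) / (r + 1) * sigma_max

def sigma_exponential(t, sigma_min=0.05, sigma_max=0.5):
    return sigma_min ** (1 - t) * sigma_max**t

t = np.linspace(0.0, 1.0, 5)
for name, fn in [("quadratic", sigma_quadratic),
                 ("linear", sigma_linear),
                 ("exponential", sigma_exponential)]:
    # Print sigma(t) on a coarse grid for each schedule.
    print(f"{name:12s}", np.round(fn(t), 4))
```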

NoiseSchedule(*, sigma_min, sigma_max) dataclass

Bases: DefaultDataClass, ABC

Abstract base class for time-dependent noise schedules.

Defines the instantaneous noise rate \(\beta(t)\), the cumulative noise scale \(\sigma(t)\), and the mean drift \(\gamma(t)\) of a diffusion process at normalized time \(t \in [0,1]\).

Attributes:

  • sigma_min (float) –

    Lower noise-scale parameter, associated with \(t=0\).

  • sigma_max (float) –

    Maximum noise scale at \(t=1\).

Methods:

  • beta

    Instantaneous noise rate \(\beta(t)\).

  • sigma

    Noise scale \(\sigma(t)\).

  • gamma

    Mean drift \(\gamma(t)\).

beta(t) abstractmethod

Instantaneous noise rate \(\beta(t)\).

Parameters:

  • t (array_like, shape (dim1,)) –

    Normalized time steps \(t \in [0,1]\).

Returns:

  • beta_t (array_like, shape (dim1,)) –

    Noise rate at each time.

Source code in src/fpsl/ddm/noiseschedule.py
@abstractmethod
def beta(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Instantaneous noise rate $\beta(t)$.

    Parameters
    ----------
    t : array_like, shape (dim1,)
        Normalized time steps $t \in [0,1]$.

    Returns
    -------
    beta_t : array_like, shape (dim1,)
        Noise rate at each time.
    """
    raise NotImplementedError
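
Every concrete schedule in this module defines \(\beta(t) = \frac{d}{dt}\sigma(t)^2\). A central finite difference gives an independent check of any `beta` implementation against its `sigma`. The helper `beta_finite_difference` below is hypothetical and uses NumPy; evaluate it away from the endpoints, since it samples \(t \pm h\).

```python
import numpy as np

def beta_finite_difference(sigma_fn, t, h=1e-5):
    """Approximate beta(t) = d/dt sigma(t)^2 with a central difference."""
    t = np.asarray(t, dtype=float)
    return (sigma_fn(t + h) ** 2 - sigma_fn(t - h) ** 2) / (2.0 * h)

# Example: sigma(t) = 0.5 * t gives sigma^2 = 0.25 t^2, so beta(t) = 0.5 t exactly.
sigma_fn = lambda t: 0.5 * t
t = np.linspace(0.1, 0.9, 5)
print(beta_finite_difference(sigma_fn, t))
```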

sigma(t) abstractmethod

Noise scale \(\sigma(t)\).

Parameters:

  • t (array_like, shape (dim1,)) –

    Normalized time steps \(t \in [0,1]\).

Returns:

  • sigma_t (array_like, shape (dim1,)) –

    Noise magnitude at each time.

Source code in src/fpsl/ddm/noiseschedule.py
@abstractmethod
def sigma(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Noise scale $\sigma(t)$.

    Parameters
    ----------
    t : array_like, shape (dim1,)
        Normalized time steps $t \in [0,1]$.

    Returns
    -------
    sigma_t : array_like, shape (dim1,)
        Noise magnitude at each time.
    """
    raise NotImplementedError

gamma(t) abstractmethod

Mean drift \(\gamma(t)\).

Parameters:

  • t (array_like, shape (dim1,)) –

    Normalized time steps \(t \in [0,1]\).

Returns:

  • gamma_t (array_like, shape (dim1,)) –

    Mean drift at each time.

Source code in src/fpsl/ddm/noiseschedule.py
@abstractmethod
def gamma(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Mean drift $\gamma(t)$.

    Parameters
    ----------
    t : array_like, shape (dim1,)
        Normalized time steps $t \in [0,1]$.

    Returns
    -------
    gamma_t : array_like, shape (dim1,)
        Mean drift at each time.
    """
    raise NotImplementedError

QuadraticVarianceNoiseSchedule(*, sigma_min=0.07, sigma_max=0.5) dataclass

Bases: NoiseSchedule

Quadratic variance-exploding noise schedule.

Defines:

\[ \sigma(t) = \frac{(\sqrt{\sigma_{\min}/\sigma_{\max}} + t)^2} {(\sqrt{\sigma_{\min}/\sigma_{\max}} + 1)^2} \,\sigma_{\max}, \quad \beta(t) = \frac{d}{dt}\sigma(t)^2. \]

Attributes:

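  • sigma_min (float, default=0.07) –

    Sets the lower end of the schedule. The value actually reached at \(t=0\) is \(\sigma(0)=\sigma_{\min}/(\sqrt{\sigma_{\min}/\sigma_{\max}}+1)^2\), which is smaller than \(\sigma_{\min}\).

  • sigma_max (float, default=0.5) –

    Noise scale at \(t=1\), where \(\sigma(1)=\sigma_{\max}\).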

gamma(t)

Not implemented: this variance-exploding schedule defines no mean drift, and calling gamma raises NotImplementedError.

Source code in src/fpsl/ddm/noiseschedule.py
def gamma(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    """No drift; not implemented."""
    raise NotImplementedError
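
Because `gamma` raises `NotImplementedError` here rather than returning zeros, calling code that needs a drift term has to handle the exception. One possible pattern, sketched with a hypothetical helper `drift_or_zero` and a dummy `NoDriftSchedule` class (neither is part of the module), falls back to zero drift:

```python
import numpy as np

class NoDriftSchedule:
    """Dummy stand-in whose gamma behaves like the schedules in this module."""
    def gamma(self, t):
        raise NotImplementedError

def drift_or_zero(schedule, t):
    """Return schedule.gamma(t), or zeros if the schedule does not implement it."""
    t = np.asarray(t, dtype=float)
    try:
        return schedule.gamma(t)
    except NotImplementedError:
        # Treat a missing drift as zero drift for variance-exploding schedules.
        return np.zeros_like(t)

print(drift_or_zero(NoDriftSchedule(), [0.0, 0.5, 1.0]))
```

Raising instead of returning zeros forces callers to make this choice explicitly.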

sigma(t)

Compute \(\sigma(t)\) as above.

Source code in src/fpsl/ddm/noiseschedule.py
def sigma(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Compute $\sigma(t)$ as above."""
    factor = (jnp.sqrt(self.sigma_min / self.sigma_max) + t) ** 2
    norm = (jnp.sqrt(self.sigma_min / self.sigma_max) + 1) ** 2
    return factor / norm * self.sigma_max
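
Substituting \(t=0\) and \(t=1\) into this formula shows that `sigma_max` is reached exactly at \(t=1\), while the value at \(t=0\) is \(\sigma_{\min}/(\sqrt{\sigma_{\min}/\sigma_{\max}}+1)^2\) rather than `sigma_min`. The sketch below mirrors the method body with NumPy in place of `jax.numpy`. It is a re-implementation, not an import of the module.

```python
import numpy as np

sigma_min, sigma_max = 0.07, 0.5  # QuadraticVarianceNoiseSchedule defaults

def sigma(t):
    # Same arithmetic as QuadraticVarianceNoiseSchedule.sigma, using NumPy.
    factor = (np.sqrt(sigma_min / sigma_max) + t) ** 2
    norm = (np.sqrt(sigma_min / sigma_max) + 1) ** 2
    return factor / norm * sigma_max

print(sigma(0.0), sigma(1.0))
```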

beta(t)

Compute \(\beta(t) = \frac{d}{dt}\sigma(t)^2\) with JAX automatic differentiation (jax.grad, applied element-wise via jnp.vectorize).

Source code in src/fpsl/ddm/noiseschedule.py
def beta(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Compute $\beta(t) = \frac{d}{dt}\sigma(t)^2$ via autograd."""
    return jnp.vectorize(jax.grad(lambda tt: self.sigma(tt) ** 2))(t)
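
Writing \(\sigma(t) = c\,(a+t)^2\) with \(a = \sqrt{\sigma_{\min}/\sigma_{\max}}\) and \(c = \sigma_{\max}/(a+1)^2\), the derivative that `jax.grad` computes also has a closed form, \(\beta(t) = 4c^2(a+t)^3\). The NumPy sketch below is an independent check, not part of the module. It compares the closed form with a central finite difference.

```python
import numpy as np

sigma_min, sigma_max = 0.07, 0.5  # QuadraticVarianceNoiseSchedule defaults
a = np.sqrt(sigma_min / sigma_max)
c = sigma_max / (a + 1) ** 2

def sigma(t):
    return c * (a + t) ** 2

def beta_closed_form(t):
    # d/dt [c^2 (a + t)^4] = 4 c^2 (a + t)^3
    return 4.0 * c**2 * (a + t) ** 3

t = np.linspace(0.05, 0.95, 7)
h = 1e-5
beta_fd = (sigma(t + h) ** 2 - sigma(t - h) ** 2) / (2 * h)
print(np.max(np.abs(beta_fd - beta_closed_form(t))))
```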

LinearVarianceNoiseSchedule(*, sigma_min=0.05, sigma_max=0.5) dataclass

Bases: NoiseSchedule

Linear variance-exploding noise schedule.

Defines:

\[ \sigma(t) = \frac{(\sigma_{\min}/\sigma_{\max} + t)} {(\sigma_{\min}/\sigma_{\max} + 1)} \,\sigma_{\max}, \quad \beta(t) = \frac{d}{dt}\sigma(t)^2. \]

gamma(t)

Not implemented: this variance-exploding schedule defines no mean drift, and calling gamma raises NotImplementedError.

Source code in src/fpsl/ddm/noiseschedule.py
def gamma(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    """No drift; not implemented."""
    raise NotImplementedError

sigma(t)

Compute \(\sigma(t)\) as above.

Source code in src/fpsl/ddm/noiseschedule.py
def sigma(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Compute $\sigma(t)$ as above."""
    factor = self.sigma_min / self.sigma_max + t
    norm = self.sigma_min / self.sigma_max + 1
    return factor / norm * self.sigma_max
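
With \(r = \sigma_{\min}/\sigma_{\max}\), this \(\sigma(t)\) is an affine interpolation from \(\sigma(0) = \sigma_{\min}/(r+1)\), slightly below `sigma_min`, up to \(\sigma(1) = \sigma_{\max}\). The NumPy sketch below re-implements the method to check that identity. It is not an import of the module.

```python
import numpy as np

sigma_min, sigma_max = 0.05, 0.5  # LinearVarianceNoiseSchedule defaults
r = sigma_min / sigma_max

def sigma(t):
    # Same arithmetic as LinearVarianceNoiseSchedule.sigma, using NumPy.
    return (r + t) / (r + 1) * sigma_max

sigma0 = sigma_min / (r + 1)  # value at t = 0
t = np.linspace(0.0, 1.0, 11)
affine = sigma0 + (sigma_max - sigma0) * t
print(np.max(np.abs(sigma(t) - affine)))
```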

beta(t)

Compute \(\beta(t) = \frac{d}{dt}\sigma(t)^2\) with JAX automatic differentiation (jax.grad, applied element-wise via jnp.vectorize).

Source code in src/fpsl/ddm/noiseschedule.py
def beta(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Compute $\beta(t) = \frac{d}{dt}\sigma(t)^2$ via autograd."""
    return jnp.vectorize(jax.grad(lambda tt: self.sigma(tt) ** 2))(t)
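
For this schedule the autograd result has a simple closed form. Since \(\sigma'(t) = \sigma_{\max}/(r+1)\) is constant, \(\beta(t) = 2\sigma(t)\sigma'(t) = 2\sigma_{\max}^2 (r+t)/(r+1)^2\), which is itself linear in \(t\). The NumPy sketch below, an independent check rather than module code, confirms that the two expressions agree.

```python
import numpy as np

sigma_min, sigma_max = 0.05, 0.5  # LinearVarianceNoiseSchedule defaults
r = sigma_min / sigma_max

def sigma(t):
    return (r + t) / (r + 1) * sigma_max

def beta_closed_form(t):
    # d/dt [sigma_max^2 (r + t)^2 / (r + 1)^2] = 2 sigma_max^2 (r + t) / (r + 1)^2
    return 2.0 * sigma_max**2 * (r + t) / (r + 1) ** 2

t = np.linspace(0.0, 1.0, 11)
# Chain rule form: beta(t) = 2 * sigma(t) * sigma'(t), with sigma'(t) = sigma_max / (r + 1)
beta_chain = 2.0 * sigma(t) * sigma_max / (r + 1)
print(beta_closed_form(t))
```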

ExponetialVarianceNoiseSchedule(*, sigma_min=0.05, sigma_max=0.5) dataclass

Bases: NoiseSchedule

Exponential variance-exploding noise schedule.

Defines:

\[ \sigma(t) = \sigma_{\min}^{1-t}\,\sigma_{\max}^t, \quad \beta(t) = \frac{d}{dt}\sigma(t)^2. \]

gamma(t)

Not implemented: this variance-exploding schedule defines no mean drift, and calling gamma raises NotImplementedError.

Source code in src/fpsl/ddm/noiseschedule.py
def gamma(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    """No drift; not implemented."""
    raise NotImplementedError

sigma(t)

Compute \(\sigma(t) = \sigma_{\min}^{1-t}\,\sigma_{\max}^t\).

Source code in src/fpsl/ddm/noiseschedule.py
def sigma(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Compute $\sigma(t) = \sigma_{\min}^{1-t}\,\sigma_{\max}^t$."""
    return self.sigma_min ** (1 - t) * self.sigma_max**t
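
This form is a geometric interpolation. \(\log\sigma(t) = (1-t)\log\sigma_{\min} + t\log\sigma_{\max}\) is linear in \(t\), so \(\sigma(0) = \sigma_{\min}\) and \(\sigma(1) = \sigma_{\max}\) exactly, and \(\sigma(1/2) = \sqrt{\sigma_{\min}\sigma_{\max}}\). The NumPy re-implementation below, not an import of the module, checks these properties.

```python
import numpy as np

sigma_min, sigma_max = 0.05, 0.5  # exponential schedule defaults

def sigma(t):
    # Same arithmetic as the exponential schedule's sigma, using NumPy.
    return sigma_min ** (1 - t) * sigma_max**t

t = np.linspace(0.0, 1.0, 11)
log_sigma = np.log(sigma(t))
# Equal steps in t give equal steps in log sigma.
print(np.diff(log_sigma))
```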

beta(t)

Compute \(\beta(t) = \frac{d}{dt}\sigma(t)^2\) with JAX automatic differentiation (jax.grad, applied element-wise via jnp.vectorize).

Source code in src/fpsl/ddm/noiseschedule.py
def beta(self, t: Float[ArrayLike, ' dim1']) -> Float[ArrayLike, ' dim1']:
    r"""Compute $\beta(t) = \frac{d}{dt}\sigma(t)^2$ via autograd."""
    return jnp.vectorize(jax.grad(lambda tt: self.sigma(tt) ** 2))(t)
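
Differentiating \(\sigma(t)^2 = \exp\!\big(2[(1-t)\log\sigma_{\min} + t\log\sigma_{\max}]\big)\) gives the closed form \(\beta(t) = 2\log(\sigma_{\max}/\sigma_{\min})\,\sigma(t)^2\), so \(\beta(t)/\sigma(t)^2\) is constant in \(t\). The NumPy sketch below is an independent check rather than module code. It compares this closed form with a central finite difference.

```python
import numpy as np

sigma_min, sigma_max = 0.05, 0.5  # exponential schedule defaults

def sigma(t):
    return sigma_min ** (1 - t) * sigma_max**t

def beta_closed_form(t):
    # d/dt sigma(t)^2 = 2 * log(sigma_max / sigma_min) * sigma(t)^2
    return 2.0 * np.log(sigma_max / sigma_min) * sigma(t) ** 2

t = np.linspace(0.05, 0.95, 7)
h = 1e-5
beta_fd = (sigma(t + h) ** 2 - sigma(t - h) ** 2) / (2 * h)
print(beta_fd / sigma(t) ** 2)  # approximately constant
```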