prior

Latent prior distributions for score-based diffusion models.

This module defines the abstract base class for latent prior distributions and provides concrete implementations, such as the uniform prior on [0,1].

Classes:

  • LatentPrior

    Abstract base class for latent priors.

  • UniformPrior

    Uniform prior over [0,1] with periodic support.

LatentPrior(*, is_periodic=False) dataclass

Bases: DefaultDataClass, ABC

Abstract base class for latent prior distributions.

Attributes:

  • is_periodic (bool) –

    If True, the support of the prior is treated as periodic.

Methods:

  • _prior : str

    Name identifier of the prior (abstract property).

  • prior_log_pdf

    Log probability density at x.

  • prior_pdf

    Probability density at x.

  • prior_sample

    Sample from the prior.

  • prior_force

    Force term (gradient of log-pdf) at x.

  • prior_x_t

    Diffuse x with noise and wrap if periodic.
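To make the abstract interface concrete, the following is a minimal, standalone sketch of a non-periodic standard normal prior that mirrors the documented method names. It deliberately does not subclass the real LatentPrior (the DefaultDataClass base and jaxtyping annotations are omitted), and its sigma method is a hypothetical placeholder: in the source, sigma(t) is supplied by the surrounding diffusion model rather than the prior.

from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class StandardNormalPriorSketch:
    """Illustrative non-periodic prior mirroring the documented interface."""

    is_periodic: bool = False

    @property
    def _prior(self) -> str:
        return 'standard_normal'

    def sigma(self, t):
        # Hypothetical noise schedule; in fpsl, sigma(t) comes from the
        # surrounding diffusion model, not from the prior itself.
        return jnp.sqrt(t)

    def prior_log_pdf(self, x):
        # log N(x; 0, 1), evaluated elementwise
        return -0.5 * (x**2 + jnp.log(2.0 * jnp.pi))

    def prior_pdf(self, x):
        return jnp.exp(self.prior_log_pdf(x))

    def prior_sample(self, key, shape):
        return jax.random.normal(key, shape)

    def prior_force(self, x):
        # Gradient of the log-density: d/dx [-x^2 / 2] = -x
        return -x

    def prior_x_t(self, x, t, eps):
        # Non-periodic support, so no modulo wrap is applied.
        return x + self.sigma(t) * eps

The concrete implementation shipped with the module, UniformPrior, is documented below.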

UniformPrior(*, is_periodic=True) dataclass

Bases: LatentPrior

Uniform prior over the unit interval [0, 1].

A periodic prior with constant density and zero force.

Attributes:

  • is_periodic (bool) –

    Always True for the uniform prior.

Methods:

  • prior_log_pdf

    Returns zero array (log-density).

  • prior_pdf

    Returns one array (density).

  • prior_sample

    Samples uniformly in [0,1].

  • prior_force

    Returns zero array (gradient of log-density).

  • prior_x_t

    Applies diffusion step with noise and wraps modulo 1.
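A hedged usage sketch of UniformPrior, assuming the package is installed and the class is importable from fpsl.ddm.prior as documented:

import jax

from fpsl.ddm.prior import UniformPrior

prior = UniformPrior()
key = jax.random.PRNGKey(0)

x0 = prior.prior_sample(key, (5,))  # five draws, uniform in [0, 1)

print(prior.prior_pdf(x0))      # ones: constant density on [0, 1]
print(prior.prior_log_pdf(x0))  # zeros: log of the constant density 1
print(prior.prior_force(x0))    # zeros: gradient of a constant log-density
print(prior.is_periodic)        # True

Note that prior_x_t is not called here: judging by its source below, it relies on a sigma(t) noise schedule provided by the enclosing diffusion model rather than by the prior itself.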

prior_log_pdf(x)

Log-probability density: zero everywhere on [0,1].

Source code in src/fpsl/ddm/prior.py
def prior_log_pdf(
    self,
    x: Float[ArrayLike, ' dim1'],
) -> Float[ArrayLike, '']:
    """Log-probability density: zero everywhere on [0,1]."""
    return jnp.zeros_like(x)

prior_pdf(x)

Probability density: one everywhere on [0,1].

Source code in src/fpsl/ddm/prior.py
def prior_pdf(
    self,
    x: Float[ArrayLike, ' dim1'],
) -> Float[ArrayLike, '']:
    """Probability density: one everywhere on [0,1]."""
    return jnp.ones_like(x)

prior_sample(key, shape)

Sample uniformly from [0,1] with given JAX PRNG key.

Source code in src/fpsl/ddm/prior.py
def prior_sample(
    self,
    key: JaxKey,
    shape: tuple[int],
) -> Float[ArrayLike, ' dim1']:
    """Sample uniformly from [0,1] with given JAX PRNG key."""
    return jax.random.uniform(key, shape)

prior_force(x)

Force term (gradient of log-pdf): zero everywhere.

Source code in src/fpsl/ddm/prior.py
def prior_force(
    self,
    x: Float[ArrayLike, ' dim1'],
) -> Float[ArrayLike, ' dim1']:
    """Force term (gradient of log-pdf): zero everywhere."""
    return jnp.zeros_like(x)

prior_x_t(x, t, eps)

Diffuse x with noise and wrap into [0,1].

Parameters:

  • x (array-like, shape (dim1,)) –

    Current latent variable.

  • t (array-like, shape (dim1,)) –

    Time embedding for noise scaling.

  • eps (array-like, shape (dim1,)) –

    Standard normal noise sample.

Returns:

  • x_next (ndarray, shape (dim1,)) –

    Noisy update of x wrapped modulo 1.

Source code in src/fpsl/ddm/prior.py
def prior_x_t(
    self,
    x: Float[ArrayLike, ' dim1'],
    t: Float[ArrayLike, ' dim1'],
    eps: Float[ArrayLike, ' dim1'],
) -> Float[ArrayLike, ' dim1']:
    """Diffuse x with noise and wrap into [0,1].

    Parameters
    ----------
    x : array-like, shape (dim1,)
        Current latent variable.
    t : array-like, shape (dim1,)
        Time embedding for noise scaling.
    eps : array-like, shape (dim1,)
        Standard normal noise sample.

    Returns
    -------
    x_next : jnp.ndarray, shape (dim1,)
        Noisy update of x wrapped modulo 1.
    """
    return (x + self.sigma(t) * eps) % 1
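A standalone illustration of this periodic diffusion step, using a hypothetical linear sigma(t) in place of the model's actual noise schedule:

import jax
import jax.numpy as jnp

def sigma(t):
    # Hypothetical noise schedule used only for this illustration.
    return 0.5 * t

key = jax.random.PRNGKey(1)
x = jnp.array([0.05, 0.50, 0.95])
t = jnp.ones_like(x)
eps = jax.random.normal(key, x.shape)

x_t = (x + sigma(t) * eps) % 1  # same update as prior_x_t above
print(x_t)  # every entry is wrapped back into [0, 1)

Because of the modulo wrap, probability mass that diffuses past 1 re-enters at 0, which is what makes the support of the uniform prior periodic.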