prior
Latent prior distributions for score-based diffusion models.
This module defines the abstract base class for latent prior distributions and provides concrete implementations, such as the uniform prior on [0,1].
Classes:
- LatentPrior – Abstract base class for latent priors.
- UniformPrior – Uniform prior over [0,1] with periodic support.
LatentPrior(*, is_periodic=False)
dataclass
Bases: DefaultDataClass, ABC
Abstract base class for latent prior distributions.
Attributes:
- is_periodic (bool) – If True, the support of the prior is treated as periodic.
Methods:
- _prior (str) – Name identifier of the prior (abstract property).
- prior_log_pdf – Log probability density at x.
- prior_pdf – Probability density at x.
- prior_sample – Sample from the prior.
- prior_force – Force term (gradient of log-pdf) at x.
- prior_x_t – Diffuse x with noise and wrap if periodic.
UniformPrior(*, is_periodic=True)
dataclass
Bases: LatentPrior
Uniform prior over the unit interval [0, 1].
A periodic prior with constant density and zero force.
Attributes:
- is_periodic (bool) – Always True for the uniform prior.
Methods:
- prior_log_pdf – Returns an array of zeros (log-density).
- prior_pdf – Returns an array of ones (density).
- prior_sample – Samples uniformly in [0,1].
- prior_force – Returns an array of zeros (gradient of log-density).
- prior_x_t – Applies a diffusion step with noise and wraps modulo 1.
prior_log_pdf(x)
Log-probability density: zero everywhere on [0,1].
Source code in src/fpsl/ddm/prior.py
prior_pdf(x)
Probability density: one everywhere on [0,1].
Source code in src/fpsl/ddm/prior.py
prior_sample(key, shape)
Sample uniformly from [0,1] with given JAX PRNG key.
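As a minimal sketch of uniform sampling with an explicit PRNG key, the snippet below uses `jax.random.uniform`. The helper name `sample_uniform` and the sample shape are illustrative assumptions, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def sample_uniform(key, shape):
    # jax.random.uniform draws from the half-open interval [minval, maxval).
    return jax.random.uniform(key, shape, minval=0.0, maxval=1.0)


key = jax.random.PRNGKey(0)
# Split the key so the parent key is never reused for another draw.
key, subkey = jax.random.split(key)
samples = sample_uniform(subkey, (1000, 1))
```

Because JAX random functions are deterministic in the key, calling the helper twice with the same key returns identical samples; splitting the key before each draw avoids that.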
Source code in src/fpsl/ddm/prior.py
prior_force(x)
Force term (gradient of log-pdf): zero everywhere.
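The zero force follows directly from the constant log-density. The sketch below checks this with automatic differentiation; `uniform_log_pdf` and `force_from_log_pdf` are hypothetical helper names, and the library may simply return zeros without differentiating.

```python
import jax
import jax.numpy as jnp


def uniform_log_pdf(x):
    # Log-density of the uniform prior on [0, 1]: log(1) = 0.
    return jnp.zeros_like(x)


def force_from_log_pdf(log_pdf, x):
    # For an element-wise log_pdf, the gradient of the summed log-density
    # is the per-element derivative, i.e. the force at each point.
    return jax.grad(lambda y: jnp.sum(log_pdf(y)))(x)


x = jnp.array([0.1, 0.5, 0.9])
force = force_from_log_pdf(uniform_log_pdf, x)
```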
Source code in src/fpsl/ddm/prior.py
prior_x_t(x, t, eps)
Diffuse x with noise and wrap into [0,1].
Parameters:
- x (array-like, shape (dim1)) – Current latent variable.
- t (array-like, shape (dim1)) – Time embedding for noise scaling.
- eps (array-like, shape (dim1)) – Standard normal noise sample.
Returns:
- x_next (ndarray, shape (dim1)) – Noisy update of x wrapped modulo 1.
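A minimal sketch of a noisy update followed by periodic wrapping is given below. The noise schedule `sigma` (a geometric interpolation between `sigma_min` and `sigma_max`) and the helper name `diffuse_and_wrap` are assumptions for illustration; this page does not state how the library scales the noise with t.

```python
import jax
import jax.numpy as jnp


def sigma(t, sigma_min=0.01, sigma_max=1.0):
    # Hypothetical geometric noise schedule; the actual schedule is not
    # documented on this page.
    return sigma_min * (sigma_max / sigma_min) ** t


def diffuse_and_wrap(x, t, eps):
    # Add noise scaled by sigma(t), then wrap back onto the unit interval.
    return jnp.mod(x + sigma(t) * eps, 1.0)


key = jax.random.PRNGKey(0)
x = jnp.array([0.05, 0.5, 0.95])
t = jnp.full_like(x, 0.5)
eps = jax.random.normal(key, x.shape)
x_t = diffuse_and_wrap(x, t, eps)
```

`jnp.mod` takes the sign of the divisor, so values pushed below 0 by the noise re-enter from the upper end of the interval rather than staying negative.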
Source code in src/fpsl/ddm/prior.py