
ddm

Denoising Diffusion Models (DDM) for score-based generative modeling.

This submodule provides a comprehensive suite of components for building and training score-based denoising diffusion models, with a focus on periodic data and force-conditioned sampling.

It is organized into the following submodules:

  • models: Core FPSL (Fokker-Planck Score Learning) diffusion model implementation for learning score functions and generating samples.
  • network: Neural network architectures including MLPs with Fourier feature embeddings for score function approximation on periodic domains.
  • noiseschedule: Time-dependent noise scheduling functions that control the variance and diffusion coefficients during the forward/reverse processes.
  • prior: Latent prior distribution definitions, including uniform priors with periodic boundary conditions for circular/toroidal data.
  • priorschedule: Interpolation schedules between data and prior distributions that control the mixing coefficient \(\alpha(t)\) during diffusion.
  • forceschedule: Force conditioning schedules that control how external forces influence the diffusion process through time-dependent scaling factors.
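For orientation, the components can be imported directly from these submodules. The paths below mirror the layout listed above and the source path src/fpsl/ddm/models.py; treat them as assumptions if your installed version differs, while the top-level FPSL import matches the class example further down.

>>> from fpsl.ddm import models, network, noiseschedule, prior, priorschedule, forceschedule
>>> from fpsl import FPSL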

FPSL(*, sigma_min=0.05, sigma_max=0.5, is_periodic=True, mlp_network, key, n_sample_steps=100, n_epochs=100, batch_size=128, wandb_log=False, gamma_energy_regulariztion=1e-05, fourier_features=1, warmup_steps=5, box_size=1.0, symmetric=False, diffusion=lambda x: 1.0, pbc_bins=0) dataclass

Bases: LinearForceSchedule, LinearPriorSchedule, UniformPrior, ExponetialVarianceNoiseSchedule, DefaultDataClass

Fokker-Planck Score Learning (FPSL) model for periodic data.

An energy-based denoising diffusion model for learning probability distributions on the periodic domain [0, 1]. Via multiple inheritance, the model combines the schedule and prior classes listed above into a complete diffusion-modeling framework with force-scheduling capabilities.

This implementation uses JAX for efficient computation and supports both symmetric and asymmetric periodic MLPs for score function approximation.

Parameters:

  • mlp_network (tuple[int]) –

    Architecture of the MLP network as a tuple specifying the number of units in each hidden layer.

  • key (JaxKey) –

    JAX random key for reproducible random number generation.

  • n_sample_steps (int, default: 100 ) –

    Number of integration steps for the sampling process.

  • n_epochs (int, default: 100 ) –

    Number of training epochs.

  • batch_size (int, default: 128 ) –

    Batch size for training.

  • wandb_log (bool, default: False ) –

    Whether to log training metrics to Weights & Biases.

  • gamma_energy_regulariztion (float, default: 1e-5 ) –

    Regularization coefficient for energy term in the loss function.

  • fourier_features (int, default: 1 ) –

    Number of Fourier features to use in the network.

  • warmup_steps (int, default: 5 ) –

    Number of warmup steps for learning rate scheduling.

  • box_size (float, default: 1.0 ) –

    Size of the periodic box domain. Currently, this is not used to scale the input data.

  • symmetric (bool, default: False ) –

    Whether to use the symmetric (cosine-only) periodic MLP architecture instead of the standard periodic (sine + cosine) MLP architecture.

  • diffusion (Callable[[Float[ArrayLike, ' n_features']], Float[ArrayLike, '']], default: lambda x: 1.0 ) –

    Position-dependent diffusion function. Defaults to constant diffusion; see the sketch after this parameter list.

  • pbc_bins (int, default: 0 ) –

    Number of bins for periodic boundary condition corrections. If 0, no PBC corrections are applied.
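A minimal sketch of a position-dependent diffusion function for the diffusion parameter, assuming only the documented signature (an array of shape (n_features,) mapped to a scalar); the functional form and variable names are purely illustrative:

>>> import jax.numpy as jnp
>>> def my_diffusion(x):
...     # hypothetical: higher mobility around x = 0.5, periodic on [0, 1]
...     return 1.0 + 0.5 * jnp.cos(2 * jnp.pi * (x.mean() - 0.5))
>>> model = FPSL(mlp_network=(64, 64), key=key, diffusion=my_diffusion)  # key as in the example below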

Attributes:

  • params (Params) –

    Trained model parameters (available after training).

  • dim (int) –

    Dimensionality of the data (set during training).

  • score_model (ScorePeriodicMLP or ScoreSymmetricPeriodicMLP) –

    The neural network used for score function approximation.

Methods:

  • train

    Train the model on provided data (\(X\in[0, 1]\)) with forces.

  • sample

    Generate samples from the learned distribution.

  • evaluate

    Evaluate the model loss on held-out data.

  • score

    Compute the score function at given positions and time.

  • energy

    Compute the energy function at given positions and time.

Notes

The model implements the Fokker-Planck score learning approach for diffusion models on periodic domains. It combines:

  • Linear force scheduling for non-equilibrium dynamics
  • Linear prior scheduling for interpolation between prior and data
  • Exponential variance noise scheduling
  • Uniform prior distribution on [0, 1]

The training objective includes both score matching and energy regularization terms, with support for periodic boundary conditions.

Examples:

>>> import jax.random as jr
>>> from fpsl import FPSL
>>>
>>> # Create model
>>> key = jr.PRNGKey(42)
>>> model = FPSL(
...     mlp_network=(64, 64, 64),
...     key=key,
...     n_epochs=50,
...     batch_size=64,
... )
>>>
>>> # Train on data
>>> X = jr.uniform(key, (1000, 1))  # periodic data
>>> y = jr.normal(key, (1000, 1))   # force data
>>> lrs = [1e-6, 1e-4]  # Learning rate range
>>> loss_hist = model.train(X, y, lrs)
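Continuing the example above with the documented sample and evaluate methods:

>>> # Draw samples from the learned distribution and re-evaluate the loss
>>> samples = model.sample(key, n_samples=500)   # shape (500, 1), values in [0, 1]
>>> loss = model.evaluate(X, y)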

score_model cached property

Create and cache the neural network.

score(x, t, y=None)

Compute the diffusion score function at given positions and time.

The score function represents the gradient of the log probability density with respect to the input coordinates: \(\nabla_x \ln p_t(x)\).

Parameters:

  • x (Float[ArrayLike, 'n_samples n_features']) –

    Input positions at which to evaluate the score function.

  • t (float) –

    Time parameter in [0, 1], where t=1 is pure noise and t=0 is data.

  • y (Float[ArrayLike, ''] or None, default: None ) –

    Optional force/conditioning variable. If None, uses equilibrium score.

Returns:

  • Float[ArrayLike, 'n_samples n_features']

    Score function values at each input position.

Notes

The score function is computed as:

\[ s_\theta(x, t) = \nabla_x \ln p_t(x) = -\frac{\nabla_x E_\theta(x, t)}{\sigma(t)} \]
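A usage sketch, assuming a trained 1D model such as the one from the class-level example; the grid resolution and the force value y are illustrative:

>>> import jax.numpy as jnp
>>> x_grid = jnp.linspace(0, 1, 101).reshape(-1, 1)   # positions on the periodic domain
>>> s_eq = model.score(x_grid, t=0.1)                 # equilibrium score, shape (101, 1)
>>> s_f = model.score(x_grid, t=0.1, y=0.5)           # force-conditioned score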
Source code in src/fpsl/ddm/models.py
def score(
    self,
    x: Float[ArrayLike, 'n_samples n_features'],
    t: float,
    y: None | Float[ArrayLike, ''] = None,
) -> Float[ArrayLike, 'n_samples n_features']:
    r"""Compute the diffusion score function at given positions and time.

    The score function represents the gradient of the log probability density
    with respect to the input coordinates: $\nabla_x \ln p_t(x)$.

    Parameters
    ----------
    x : Float[ArrayLike, 'n_samples n_features']
        Input positions where to evaluate the score function.
    t : float
        Time parameter in [0, 1], where t=1 is pure noise and t=0 is data.
    y : Float[ArrayLike, ''] or None, default=None
        Optional force/conditioning variable. If None, uses equilibrium score.

    Returns
    -------
    Float[ArrayLike, 'n_samples n_features']
        Score function values at each input position.

    Notes
    -----
    The score function is computed as:

    $$
        s_\theta(x, t) = \nabla_x \ln p_t(x) = -\frac{\nabla_x E_\theta(x, t)}{\sigma(t)}
    $$
    """
    if self.sigma(t) == 0:  # catch division by zero
        return np.zeros_like(x)

    score_times_minus_sigma = jax.vmap(
        self._score_eq, in_axes=(None, 0, 0),
    )(self.params, x, jnp.full((len(x), 1), t)) if y is None else jax.vmap(
        self._score, in_axes=(None, 0, 0, 0),
    )(self.params, x, jnp.full((len(x), 1), t), jnp.full((len(x), 1), y))  # fmt: skip

    return -score_times_minus_sigma / self.sigma(t)

energy(x, t, y=None)

Compute the energy function at given positions and time.

The energy function represents the negative log probability density up to a constant: \(E_\theta(x, t) = -\ln p_t(x) + C\).

Parameters:

  • x (Float[ArrayLike, 'n_samples n_features']) –

    Input positions at which to evaluate the energy function.

  • t (float) –

    Time parameter in [0, 1], where \(t=1\) is pure noise and \(t=0\) is data.

  • y (Float[ArrayLike, ''] or None, default: None ) –

    Optional force/conditioning variable. If None, uses equilibrium energy.

Returns:

  • Float[ArrayLike, ' n_samples']

    Energy function values at each input position.

Notes

The energy function is related to the score function by:

\[ \nabla_x E_\theta(x, t) = -s_\theta(x, t)\,\sigma(t) \]
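A usage sketch, assuming a trained 1D model; since the energy is defined only up to an additive constant, the profile is shifted so its minimum is zero:

>>> x_grid = jnp.linspace(0, 1, 101).reshape(-1, 1)
>>> E = model.energy(x_grid, t=0.05)   # shape (101,), see Returns above
>>> E = E - E.min()                    # fix the arbitrary additive constant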

Source code in src/fpsl/ddm/models.py
def energy(
    self,
    x: Float[ArrayLike, 'n_samples n_features'],
    t: float,
    y: None | Float[ArrayLike, ''] = None,
) -> Float[ArrayLike, ' n_samples']:
    r"""Compute the energy function at given positions and time.

    The energy function represents the negative log probability density
    up to a constant: $E_\theta(x, t) = -\ln p_t(x) + C$.

    Parameters
    ----------
    x : Float[ArrayLike, 'n_samples n_features']
        Input positions where to evaluate the energy function.
    t : float
        Time parameter in [0, 1], where $t=1$ is pure noise and $t=0$ is data.
    y : Float[ArrayLike, ''] or None, default=None
        Optional force/conditioning variable. If None, uses equilibrium energy.

    Returns
    -------
    Float[ArrayLike, ' n_samples']
        Energy function values at each input position.

    Notes
    -----
    The energy function is related to the score function by:
    $$
        \nabla_x E_\theta(x, t) = -s_\theta(x, t)\sigma(t)
    $$
    """
    # catch division by zero
    if isinstance(t, float) and self.sigma(t) == 0:
        return np.zeros_like(x)

    energy_times_minus_sigma = (
        jax.vmap(
            self._energy_eq,
            in_axes=(None, 0, 0),
        )(self.params, x, jnp.full((len(x), 1), t))
        if y is None
        else jax.vmap(
            self._energy,
            in_axes=(None, 0, 0, 0),
        )(self.params, x, jnp.full((len(x), 1), t), jnp.full((len(x), 1), y))
    )

    return -energy_times_minus_sigma / self.sigma(t) + jax.vmap(
        self._ln_diffusion_t,
    )(x, jnp.full((len(x), 1), t))

train(X, y, lrs, key=None, n_epochs=None, X_val=None, y_val=None, project='entropy-prod-diffusion', wandb_kwargs={})

Train the FPSL model on the provided dataset.

This method trains the score-function neural network using a combination of score-matching loss and energy regularization. Training uses a warmup cosine-decay learning-rate schedule and the AdamW optimizer.

Parameters:

  • X (Float[ArrayLike, 'n_samples n_features']) –

    Training data positions. Must be 2D array with shape (n_samples, n_features). Data should be in the periodic domain [0, 1].

  • y (Float[ArrayLike, ' n_features']) –

    Force/conditioning variables corresponding to each sample in X.

  • lrs (Float[ArrayLike, 2]) –

    Learning rate range as [min_lr, max_lr] for warmup cosine decay schedule.

  • key (JaxKey or None, default: None ) –

    Random key for reproducible training. If None, uses self.key.

  • n_epochs (int or None, default: None ) –

    Number of training epochs. If None, uses self.n_epochs.

  • X_val (Float[ArrayLike, 'n_val n_features'] or None, default: None ) –

    Validation data positions. If provided, validation loss will be computed.

  • y_val (Float[ArrayLike, ' n_features'] or None, default: None ) –

    Validation force variables. Required if X_val is provided.

  • project (str, default: 'entropy-prod-diffusion' ) –

    Weights & Biases project name for logging (if wandb_log=True).

  • wandb_kwargs (dict, default: {} ) –

    Additional keyword arguments passed to wandb.init().

Returns:

  • dict

    Dictionary containing the training history with keys:

      • 'train_loss': Array of training losses for each epoch
      • 'val_loss': Array of validation losses (if validation data is provided)

Raises:

  • ValueError

    If X is not a 2D array.

Notes

The training objective combines:

  1. Score matching loss with periodic boundary handling
  2. An energy regularization term controlled by gamma_energy_regularization

The model parameters are stored in self.params after training.
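A usage sketch with a validation split, using only the documented keyword arguments; X_train, y_train, X_val, and y_val are placeholder arrays:

>>> loss_hist = model.train(
...     X_train, y_train,
...     lrs=[1e-6, 1e-4],
...     X_val=X_val,
...     y_val=y_val,
...     n_epochs=200,
... )
>>> loss_hist['train_loss'][-1], loss_hist['val_loss'][-1]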

Source code in src/fpsl/ddm/models.py
def train(
    self,
    X: Float[ArrayLike, 'n_samples n_features'],
    y: Float[ArrayLike, ' n_features'],
    lrs: Float[ArrayLike, '2'],
    key: None | JaxKey = None,
    n_epochs: None | int = None,
    X_val: None | Float[ArrayLike, 'n_val n_features'] = None,
    y_val: None | Float[ArrayLike, ' n_features'] = None,
    project: str = 'entropy-prod-diffusion',
    wandb_kwargs: dict = {},
):
    """Train the FPSL model on the provided dataset.

    This method trains the score function neural network using a combination
    of score matching loss and energy regularization. The training uses
    warmup cosine decay learning rate scheduling and AdamW optimizer.

    Parameters
    ----------
    X : Float[ArrayLike, 'n_samples n_features']
        Training data positions. Must be 2D array with shape (n_samples, n_features).
        Data should be in the periodic domain [0, 1].
    y : Float[ArrayLike, ' n_features']
        Force/conditioning variables corresponding to each sample in X.
    lrs : Float[ArrayLike, '2']
        Learning rate range as [min_lr, max_lr] for warmup cosine decay schedule.
    key : JaxKey or None, default=None
        Random key for reproducible training. If None, uses self.key.
    n_epochs : int or None, default=None
        Number of training epochs. If None, uses self.n_epochs.
    X_val : Float[ArrayLike, 'n_val n_features'] or None, default=None
        Validation data positions. If provided, validation loss will be computed.
    y_val : Float[ArrayLike, ' n_features'] or None, default=None
        Validation force variables. Required if X_val is provided.
    project : str, default='entropy-prod-diffusion'
        Weights & Biases project name for logging (if wandb_log=True).
    wandb_kwargs : dict, default={}
        Additional keyword arguments passed to wandb.init().

    Returns
    -------
    dict
        Dictionary containing training history with keys:
        - 'train_loss': Array of training losses for each epoch
        - 'val_loss': Array of validation losses (if validation data provided)

    Raises
    ------
    ValueError
        If X is not a 2D array.

    Notes
    -----
    The training objective combines:
    1. Score matching loss with periodic boundary handling
    2. Energy regularization term controlled by `gamma_energy_regularization`

    The model parameters are stored in `self.params` after training.
    """
    if X.ndim == 1:
        raise ValueError('X must be 2D array.')

    # if self.wandb_log and dataset is None:
    #    raise ValueError('Please provide a dataset for logging.')

    if key is None:
        key = self.key

    if n_epochs is None:
        n_epochs = self.n_epochs

    self.dim: int = X.shape[-1]

    # start a new wandb run to track this script
    if self.wandb_log:
        wandb.init(
            project=project,
            config=self._get_config(
                lrs=lrs,
                key=key,
                n_epochs=n_epochs,
                X=X,
                y=y,
            )
            | wandb_kwargs,
        )

    # main logic
    loss_hist = self._train(
        X=X,
        lrs=lrs,
        key=key,
        n_epochs=n_epochs,
        y=y,
        X_val=X_val,
        y_val=y_val,
    )

    if self.wandb_log:
        wandb.finish()
    return loss_hist

evaluate(X, y=None, key=None)

Evaluate the model loss on held-out data.

Computes the same loss function used during training (score matching + energy regularization) on the provided data without updating model parameters.

Parameters:

  • X (Float[ArrayLike, 'n_samples n_features']) –

    Test data positions in the periodic domain [0, 1].

  • y (Float[ArrayLike, ' n_features'] or None, default: None ) –

    Force/conditioning variables for the test data. If None, assumes equilibrium evaluation.

  • key (JaxKey or None, default: None ) –

    Random key for stochastic evaluation. If None, uses self.key.

Returns:

  • float

    Evaluation loss value.

Notes

This method is useful for monitoring generalization performance on validation or test sets during or after training.
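A usage sketch comparing training and held-out loss after training; X_test and y_test are placeholder arrays:

>>> train_loss = model.evaluate(X, y)
>>> test_loss = model.evaluate(X_test, y_test)
>>> test_loss - train_loss   # a large gap suggests overfitting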

Source code in src/fpsl/ddm/models.py
def evaluate(
    self,
    X: Float[ArrayLike, 'n_samples n_features'],
    y: None | Float[ArrayLike, ' n_features'] = None,
    key: None | JaxKey = None,
) -> float:
    """Evaluate the model loss on held-out data.

    Computes the same loss function used during training (score matching
    + energy regularization) on the provided data without updating model
    parameters.

    Parameters
    ----------
    X : Float[ArrayLike, 'n_samples n_features']
        Test data positions in the periodic domain [0, 1].
    y : Float[ArrayLike, ' n_features'] or None, default=None
        Force/conditioning variables for the test data. If None, assumes
        equilibrium evaluation.
    key : JaxKey or None, default=None
        Random key for stochastic evaluation. If None, uses self.key.

    Returns
    -------
    float
        Evaluation loss value.

    Notes
    -----
    This method is useful for monitoring generalization performance on
    validation or test sets during or after training.
    """
    if key is None:
        key = self.key
    loss_fn = self._create_loss_fn()
    return float(loss_fn(params=self.params, key=key, X=X, y=y))

sample(key, n_samples, t_final=0, n_steps=None)

Generate samples from the learned probability distribution.

Uses reverse-time SDE integration to generate samples by starting from the prior distribution and integrating backwards through the diffusion process using the learned score function.

Parameters:

  • key (JaxKey) –

    Random key for reproducible sampling.

  • n_samples (int) –

    Number of samples to generate.

  • t_final (float, default: 0 ) –

    Final time for the reverse integration. \(t=0\) corresponds to the data distribution, \(t=1\) to pure noise.

  • n_steps (int or None, default: None ) –

    Number of integration steps for the reverse SDE. If None, uses self.n_sample_steps.

Returns:

  • Float[ArrayLike, 'n_samples n_dims']

    Generated samples from the learned distribution.

Notes

The sampling procedure follows the reverse-time SDE:

\[ \mathrm{d}x = [\beta(t) s_\theta(x, t)] \mathrm{d}t + \sqrt{\beta(t)}\mathrm{d}W \]

where \(s_\theta\) is the learned score function and \(\beta(t)\) is the noise schedule. For periodic domains, the samples are wrapped to \([0, 1]\) at each step.

Examples:

>>> # Generate 100 samples
>>> samples = model.sample(key, n_samples=100)
>>>
>>> # Generate with custom integration steps
>>> samples = model.sample(key, n_samples=50, n_steps=200)
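As a further sketch, the generated samples can be histogrammed to inspect the learned periodic density (the bin count is arbitrary):

>>> import jax.numpy as jnp
>>> samples = model.sample(key, n_samples=5000)
>>> hist, edges = jnp.histogram(samples[:, 0], bins=50, range=(0.0, 1.0), density=True)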
Source code in src/fpsl/ddm/models.py
def sample(
    self,
    key: JaxKey,
    n_samples: int,
    t_final: float = 0,
    n_steps: None | int = None,
) -> Float[ArrayLike, 'n_samples n_dims']:
    r"""Generate samples from the learned probability distribution.

    Uses reverse-time SDE integration to generate samples by starting
    from the prior distribution and integrating backwards through the
    diffusion process using the learned score function.

    Parameters
    ----------
    key : JaxKey
        Random key for reproducible sampling.
    n_samples : int
        Number of samples to generate.
    t_final : float, default=0
        Final time for the reverse integration. $t=0$ corresponds to the
        data distribution, $t=1$ to pure noise.
    n_steps : int or None, default=None
        Number of integration steps for the reverse SDE. If None, uses
        self.n_sample_steps.

    Returns
    -------
    Float[ArrayLike, 'n_samples n_dims']
        Generated samples from the learned distribution.

    Notes
    -----
    The sampling procedure follows the reverse-time SDE:

    $$
        \mathrm{d}x = [\beta(t) s_\theta(x, t)] \mathrm{d}t + \sqrt{\beta(t)}\mathrm{d}W
    $$

    where $s_\theta$ is the learned score function and $\beta(t)$ is the noise schedule.
    For periodic domains, the samples are wrapped to $[0, 1]$ at each step.

    Examples
    --------
    >>> # Generate 100 samples
    >>> samples = model.sample(key, n_samples=100)
    >>>
    >>> # Generate with custom integration steps
    >>> samples = model.sample(key, n_samples=50, n_steps=200)
    """
    x_init = self.prior_sample(key, (n_samples, self.dim))
    if n_steps is None:
        n_steps = self.n_sample_steps
    dt = (1 - t_final) / n_steps
    t_array = jnp.linspace(1, t_final, n_steps + 1)

    def body_fn(i, val):
        x, key = val
        key, subkey = jax.random.split(key)
        t_curr = t_array[i]
        eps = jax.random.normal(subkey, x.shape)

        score_times_minus_sigma = jax.vmap(
            self._score_eq,
            in_axes=(None, 0, 0),
        )(self.params, x, jnp.full((len(x), 1), t_curr))
        score = -score_times_minus_sigma / self.sigma(t_curr)

        x_new = (
            x
            + self.beta(t_curr) * score * dt
            + jnp.sqrt(self.beta(t_curr)) * eps * jnp.sqrt(dt)
        )
        if self.is_periodic:
            x_new = x_new % 1
        return (x_new, key)

    final_x, _ = jax.lax.fori_loop(
        0,
        n_steps + 1,
        body_fn,
        (x_init, key),
    )
    return final_x