models
FPSL(*, sigma_min=0.05, sigma_max=0.5, is_periodic=True, mlp_network, key, n_sample_steps=100, n_epochs=100, batch_size=128, wandb_log=False, gamma_energy_regulariztion=1e-05, fourier_features=1, warmup_steps=5, box_size=1.0, symmetric=False, diffusion=lambda x: 1.0, pbc_bins=0)
dataclass
¶
Bases: LinearForceSchedule, LinearPriorSchedule, UniformPrior, ExponetialVarianceNoiseSchedule, DefaultDataClass
Fokker-Planck Score Learning (FPSL) model for periodic data.
An energy-based denoising diffusion model for learning probability distributions on the periodic domain \([0, 1]\). The model inherits from several schedule and prior classes to provide a complete diffusion-modeling framework with force scheduling.
This implementation uses JAX for efficient computation and supports both symmetric and asymmetric periodic MLPs for score function approximation.
Parameters:
- mlp_network (tuple[int]) – Architecture of the MLP network as a tuple specifying the number of units in each hidden layer.
- key (JaxKey) – JAX random key for reproducible random number generation.
- n_sample_steps (int, default: 100) – Number of integration steps for the sampling process.
- n_epochs (int, default: 100) – Number of training epochs.
- batch_size (int, default: 128) – Batch size for training.
- wandb_log (bool, default: False) – Whether to log training metrics to Weights & Biases.
- gamma_energy_regulariztion (float, default: 1e-5) – Regularization coefficient for the energy term in the loss function.
- fourier_features (int, default: 1) – Number of Fourier features to use in the network.
- warmup_steps (int, default: 5) – Number of warmup steps for learning rate scheduling.
- box_size (float, default: 1.0) – Size of the periodic box domain. Currently not used to scale the input data.
- symmetric (bool, default: False) – Whether to use a symmetric (cos-only) periodic MLP architecture instead of the full periodic (sin + cos) MLP architecture.
- diffusion (Callable[[Float[ArrayLike, ' n_features']], Float[ArrayLike, '']], default: lambda x: 1.0) – Position-dependent diffusion function; defaults to constant diffusion. See the sketch after this parameter list.
- pbc_bins (int, default: 0) – Number of bins for periodic boundary condition corrections. If 0, no PBC corrections are applied.
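The diffusion argument lets you encode position-dependent dynamics. A minimal sketch, assuming only the signature above (a feature vector mapped to a scalar); the cosine profile is purely illustrative, not part of the library:

>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> from fpsl import FPSL
>>>
>>> def cosine_diffusion(x):
...     # smooth, strictly positive, periodic diffusion profile
...     return 1.0 + 0.5 * jnp.mean(jnp.cos(2 * jnp.pi * x))
>>>
>>> model = FPSL(
...     mlp_network=(64, 64),
...     key=jr.PRNGKey(0),
...     diffusion=cosine_diffusion,
... )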
Attributes:
- params (Params) – Trained model parameters (available after training).
- dim (int) – Dimensionality of the data (set during training).
- score_model (ScorePeriodicMLP or ScoreSymmetricPeriodicMLP) – The neural network used for score function approximation.
Methods:
- train – Train the model on provided data (\(X\in[0, 1]\)) with forces.
- sample – Generate samples from the learned distribution.
- evaluate – Evaluate the model loss on held-out data.
- score – Compute the score function at given positions and time.
- energy – Compute the energy function at given positions and time.
Notes
The model implements the Fokker-Planck score learning approach for diffusion models on periodic domains. It combines:
- Linear force scheduling for non-equilibrium dynamics
- Linear prior scheduling for interpolation between prior and data
- Exponential variance noise scheduling
- Uniform prior distribution on [0, 1]
The training objective includes both score matching and energy regularization terms, with support for periodic boundary conditions.
Examples:
>>> import jax.random as jr
>>> from fpsl import FPSL
>>>
>>> # Create model
>>> key = jr.PRNGKey(42)
>>> model = FPSL(
... mlp_network=(64, 64, 64),
... key=key,
... n_epochs=50,
... batch_size=64,
... )
>>>
>>> # Train on data
>>> key_X, key_y = jr.split(key)
>>> X = jr.uniform(key_X, (1000, 1))  # periodic data
>>> y = jr.normal(key_y, (1000, 1))  # force data
>>> lrs = [1e-6, 1e-4]  # learning rate range
>>> loss_hist = model.train(X, y, lrs)
score_model cached property ¶
Create and cache the neural network.
score(x, t, y=None)
¶
Compute the diffusion score function at given positions and time.
The score function represents the gradient of the log probability density with respect to the input coordinates: \(\nabla_x \ln p_t(x)\).
Parameters:
- x (Float[ArrayLike, 'n_samples n_features']) – Input positions at which to evaluate the score function.
- t (float) – Time parameter in [0, 1], where \(t=1\) is pure noise and \(t=0\) is data.
- y (Float[ArrayLike, ''] or None, default: None) – Optional force/conditioning variable. If None, uses the equilibrium score.
Returns:
- Float[ArrayLike, 'n_samples n_features'] – Score function values at each input position.
Notes
The score function is computed from the learned energy as
$$ s_\theta(x, t) = -\frac{\nabla_x E_\theta(x, t)}{\sigma(t)}, $$
consistent with the relation between score and energy given in the notes of energy.
Source code in src/fpsl/ddm/models.py
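A minimal usage sketch, assuming a trained 1D model as in the class-level example above (the grid and time value are illustrative):

>>> import jax.numpy as jnp
>>>
>>> x_grid = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)  # (n_samples, n_features)
>>> scores = model.score(x_grid, t=0.1)
>>> scores.shape
(100, 1)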
energy(x, t, y=None)
¶
Compute the energy function at given positions and time.
The energy function represents the negative log probability density up to a constant: \(E_\theta(x, t) = -\ln p_t(x) + C\).
Parameters:
- x (Float[ArrayLike, 'n_samples n_features']) – Input positions at which to evaluate the energy function.
- t (float) – Time parameter in [0, 1], where \(t=1\) is pure noise and \(t=0\) is data.
- y (Float[ArrayLike, ''] or None, default: None) – Optional force/conditioning variable. If None, uses the equilibrium energy.
Returns:
- Float[ArrayLike, ' n_samples'] – Energy function values at each input position.
Notes
The energy function is related to the score function by: $$ \nabla_x E_\theta(x, t) = -s_\theta(x, t)\sigma(t) $$
Source code in src/fpsl/ddm/models.py
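A minimal usage sketch, assuming a trained 1D model as in the class-level example. Since the energy is defined only up to a constant \(C\), the profile is shifted so its minimum is zero:

>>> import jax.numpy as jnp
>>>
>>> x_grid = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)
>>> energies = model.energy(x_grid, t=0.0)  # shape (100,)
>>> energies = energies - energies.min()  # fix the arbitrary constant C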
train(X, y, lrs, key=None, n_epochs=None, X_val=None, y_val=None, project='entropy-prod-diffusion', wandb_kwargs={})
¶
Train the FPSL model on the provided dataset.
This method trains the score-function neural network using a combination of score matching loss and energy regularization. Training uses a warmup cosine-decay learning-rate schedule and the AdamW optimizer.
Parameters:
- X (Float[ArrayLike, 'n_samples n_features']) – Training data positions. Must be a 2D array of shape (n_samples, n_features), with data in the periodic domain [0, 1].
- y (Float[ArrayLike, ' n_features']) – Force/conditioning variables corresponding to each sample in X.
- lrs (Float[ArrayLike, 2]) – Learning rate range as [min_lr, max_lr] for the warmup cosine-decay schedule.
- key (JaxKey or None, default: None) – Random key for reproducible training. If None, uses self.key.
- n_epochs (int or None, default: None) – Number of training epochs. If None, uses self.n_epochs.
- X_val (Float[ArrayLike, 'n_val n_features'] or None, default: None) – Validation data positions. If provided, validation loss will be computed.
- y_val (Float[ArrayLike, ' n_features'] or None, default: None) – Validation force variables. Required if X_val is provided.
- project (str, default: 'entropy-prod-diffusion') – Weights & Biases project name for logging (if wandb_log=True).
- wandb_kwargs (dict, default: {}) – Additional keyword arguments passed to wandb.init().
Returns:
- dict – Dictionary containing the training history with keys:
  - 'train_loss': array of training losses for each epoch
  - 'val_loss': array of validation losses (if validation data provided)
Raises:
- ValueError – If X is not a 2D array.
Notes
The training objective combines:
1. Score matching loss with periodic boundary handling
2. Energy regularization term controlled by gamma_energy_regulariztion
The model parameters are stored in self.params after training.
Source code in src/fpsl/ddm/models.py
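A minimal end-to-end sketch with a held-out validation split; the synthetic data and network size are illustrative:

>>> import jax.random as jr
>>> from fpsl import FPSL
>>>
>>> key = jr.PRNGKey(0)
>>> key_X, key_y, key_train = jr.split(key, 3)
>>> X = jr.uniform(key_X, (1000, 1))  # positions in [0, 1]
>>> y = jr.normal(key_y, (1000, 1))  # forces
>>>
>>> model = FPSL(mlp_network=(64, 64), key=key, n_epochs=50)
>>> history = model.train(
...     X[:800], y[:800],
...     lrs=[1e-6, 1e-4],  # [min_lr, max_lr]
...     key=key_train,
...     X_val=X[800:], y_val=y[800:],
... )
>>> train_loss = history['train_loss']
>>> val_loss = history['val_loss']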
evaluate(X, y=None, key=None)
¶
Evaluate the model loss on held-out data.
Computes the same loss function used during training (score matching + energy regularization) on the provided data without updating model parameters.
Parameters:
- X (Float[ArrayLike, 'n_samples n_features']) – Test data positions in the periodic domain [0, 1].
- y (Float[ArrayLike, ' n_features'] or None, default: None) – Force/conditioning variables for the test data. If None, assumes equilibrium evaluation.
- key (JaxKey or None, default: None) – Random key for stochastic evaluation. If None, uses self.key.
Returns:
- float – Evaluation loss value.
Notes
This method is useful for monitoring generalization performance on validation or test sets during or after training.
Source code in src/fpsl/ddm/models.py
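A short usage sketch, reusing the trained model and held-out split from the training sketch above:

>>> test_loss = model.evaluate(X[800:], y[800:])
>>> print(f'held-out loss: {test_loss:.4f}')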
sample(key, n_samples, t_final=0, n_steps=None)
¶
Generate samples from the learned probability distribution.
Uses reverse-time SDE integration to generate samples by starting from the prior distribution and integrating backwards through the diffusion process using the learned score function.
Parameters:
- key (JaxKey) – Random key for reproducible sampling.
- n_samples (int) – Number of samples to generate.
- t_final (float, default: 0) – Final time for the reverse integration. \(t=0\) corresponds to the data distribution, \(t=1\) to pure noise.
- n_steps (int or None, default: None) – Number of integration steps for the reverse SDE. If None, uses self.n_sample_steps.
Returns:
- Float[ArrayLike, 'n_samples n_dims'] – Generated samples from the learned distribution.
Notes
The sampling procedure follows the reverse-time SDE
$$ \mathrm{d}x = -\beta(t)\, s_\theta(x, t)\, \mathrm{d}t + \sqrt{\beta(t)}\, \mathrm{d}\bar{W}_t, $$
where \(s_\theta\) is the learned score function and \(\beta(t)\) is the noise schedule. For periodic domains, the samples are wrapped to \([0, 1]\) at each step.
Examples:
>>> # Generate 100 samples
>>> samples = model.sample(key, n_samples=100)
>>>
>>> # Generate with custom integration steps
>>> samples = model.sample(key, n_samples=50, n_steps=200)
Source code in src/fpsl/ddm/models.py