In [1]:
Copied!
# Set XLA_PYTHON_CLIENT_PREALLOCATE to "false" before JAX is imported in the next cell.
# NOTE(review): per the JAX GPU memory allocation docs this disables upfront GPU memory preallocation — confirm this is the intent.
%env XLA_PYTHON_CLIENT_PREALLOCATE = false
%env XLA_PYTHON_CLIENT_PREALLOCATE = false
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
Change Noise Schedule
To change the noise schedule, define your own noise-schedule class that inherits from FPSL's `NoiseSchedule` base class and overrides `sigma` and `beta`.
In [2]:
Copied!
from dataclasses import dataclass
import jax
import jax.numpy as jnp
import fpsl
from fpsl import FPSL
# define custom noise schedule
@dataclass(kw_only=True)
class CustomNoiseSchedule(fpsl.ddm.noiseschedule.NoiseSchedule):
    """Noise schedule for FPSL whose exponent is quadratic in ``t``.

    ``sigma(t)`` equals ``sigma_min`` at ``t = 0`` and ``sigma_max`` at
    ``t = 1``.
    """

    sigma_min: float = 0.05
    sigma_max: float = 0.5

    @property
    def _noise_schedule(self) -> str:
        # This name is only used for logging.
        return 'custom'

    def gamma(self, t):
        # Left unimplemented on purpose: the method only exists for legacy
        # reasons, from when FPSL also worked on non-periodic systems.
        raise NotImplementedError

    def sigma(self, t):
        # The default noise schedule used in FPSL is
        #     sigma_min ** (1 - t) * sigma_max ** t
        # The custom schedule keeps that form but replaces t with t**2.
        t_squared = t**2
        return self.sigma_min ** (1 - t_squared) * self.sigma_max**t_squared

    def beta(self, t):
        # Instead of implementing the analytical derivative of sigma(t)**2,
        # let jax.grad (automatic differentiation) compute it; jnp.vectorize
        # applies the scalar derivative element-wise over `t`.
        def sigma_squared(tt):
            return self.sigma(tt) ** 2

        return jnp.vectorize(jax.grad(sigma_squared))(t)
from dataclasses import dataclass
import jax
import jax.numpy as jnp
import fpsl
from fpsl import FPSL
# define custom noise schedule
@dataclass(kw_only=True)
class CustomNoiseSchedule(fpsl.ddm.noiseschedule.NoiseSchedule):
    """Noise schedule for FPSL whose exponent is quadratic in ``t``.

    ``sigma(t)`` equals ``sigma_min`` at ``t = 0`` and ``sigma_max`` at
    ``t = 1``.
    """

    sigma_min: float = 0.05
    sigma_max: float = 0.5

    @property
    def _noise_schedule(self) -> str:
        # This name is only used for logging.
        return 'custom'

    def gamma(self, t):
        # Left unimplemented on purpose: the method only exists for legacy
        # reasons, from when FPSL also worked on non-periodic systems.
        raise NotImplementedError

    def sigma(self, t):
        # The default noise schedule used in FPSL is
        #     sigma_min ** (1 - t) * sigma_max ** t
        # The custom schedule keeps that form but replaces t with t**2.
        t_squared = t**2
        return self.sigma_min ** (1 - t_squared) * self.sigma_max**t_squared

    def beta(self, t):
        # Instead of implementing the analytical derivative of sigma(t)**2,
        # let jax.grad (automatic differentiation) compute it; jnp.vectorize
        # applies the scalar derivative element-wise over `t`.
        def sigma_squared(tt):
            return self.sigma(tt) ** 2

        return jnp.vectorize(jax.grad(sigma_squared))(t)
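Now we can define the custom FPSL class by inheriting from both the custom noise schedule and `FPSL`. The custom noise schedule is listed first so that its methods take precedence in Python's method resolution order: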
In [3]:
Copied!
@dataclass(kw_only=True)
class CustomFPSL(CustomNoiseSchedule, FPSL):
    """FPSL variant that uses the noise schedule from ``CustomNoiseSchedule``."""
# Creating an instance shows that it uses the new custom noise schedule:
# the cell's value is the `_noise_schedule` name of the instance.
CustomFPSL(
mlp_network=[32, 32, 32],
key=jax.random.PRNGKey(0),
)._noise_schedule
@dataclass(kw_only=True)
class CustomFPSL(CustomNoiseSchedule, FPSL):
    """FPSL variant that uses the noise schedule from ``CustomNoiseSchedule``."""
# Creating an instance shows that it uses the new custom noise schedule:
# the cell's value is the `_noise_schedule` name of the instance.
CustomFPSL(
mlp_network=[32, 32, 32],
key=jax.random.PRNGKey(0),
)._noise_schedule
Out[3]:
'custom'
In the same way, it is possible to change the force schedule, the prior sampling, the prior schedule, etc.