# Source code for pymc.ode.ode

#   Copyright 2024 - present The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import logging

import numpy as np
import pytensor
import pytensor.tensor as pt
import scipy

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.type import TensorType

from pymc.exceptions import ShapeError
from pymc.ode import utils

_log = logging.getLogger(__name__)
floatX = pytensor.config.floatX


class DifferentialEquation(Op):
    r"""Specify an ordinary differential equation.

    Due to the nature of the model (as well as included solvers), the process
    of ODE solution may perform slowly. A faster alternative library based on
    PyMC--sunode--has implemented Adams' method and BDF (backward
    differentation formula). More information about sunode is available at:
    https://github.com/aseyboldt/sunode.

    .. math::
        \dfrac{dy}{dt} = f(y,t,p) \quad y(t_0) = y_0

    Parameters
    ----------
    func : callable
        Function specifying the differential equation. Must take arguments
        y (n_states,), t (scalar), p (n_theta,)
    times : array
        Array of times at which to evaluate the solution of the differential
        equation.
    n_states : int
        Dimension of the differential equation. For scalar differential
        equations, n_states=1. For vector valued differential equations,
        n_states = number of differential equations in the system.
    n_theta : int
        Number of parameters in the differential equation.
    t0 : float
        Time corresponding to the initial condition

    Examples
    --------
    .. code-block:: python

        def odefunc(y, t, p):
            # Logistic differential equation
            return p[0] * y[0] * (1 - y[0])


        times = np.arange(0.5, 5, 0.5)

        ode_model = DifferentialEquation(
            func=odefunc, times=times, n_states=1, n_theta=1, t0=0
        )
    """

    _itypes = [
        TensorType(floatX, (False,)),  # y0 as 1D floatX vector
        TensorType(floatX, (False,)),  # theta as 1D floatX vector
    ]
    _otypes = [
        TensorType(floatX, (False, False)),  # model states as floatX of shape (T, S)
        TensorType(
            floatX, (False, False, False)
        ),  # sensitivities as floatX of shape (T, S, len(y0) + len(theta))
    ]
    __props__ = ("func", "times", "n_states", "n_theta", "t0")

    def __init__(self, func, times, *, n_states, n_theta, t0=0):
        if not callable(func):
            raise ValueError("Argument func must be callable.")
        if n_states < 1:
            raise ValueError("Argument n_states must be at least 1.")
        if n_theta <= 0:
            raise ValueError("Argument n_theta must be positive.")

        # Public
        self.func = func
        self.t0 = t0
        self.times = tuple(times)
        self.n_times = len(times)
        self.n_states = n_states
        self.n_theta = n_theta
        # Total number of differentiable inputs: initial conditions + parameters.
        self.n_p = n_states + n_theta

        # Private
        # t0 is prepended so odeint integrates from the initial condition;
        # the t0 row is stripped from the solution in _simulate.
        self._augmented_times = np.insert(times, 0, t0).astype(floatX)
        self._augmented_func = utils.augment_system(func, self.n_states, self.n_theta)
        self._sens_ic = utils.make_sens_ic(self.n_states, self.n_theta, floatX)

        # Cache symbolic sensitivities by the hash of inputs
        self._apply_nodes = {}
        self._output_sensitivities = {}

    def _system(self, Y, t, p):
        r"""Solve both ODE and sensitivities.

        This function will be passed to odeint.

        Parameters
        ----------
        Y : array
            augmented state vector (n_states + n_states + n_theta)
        t : float
            current time
        p : array
            parameter vector (y0, theta)
        """
        # First n_states entries of Y are the ODE states; the rest is the
        # raveled sensitivity matrix.
        dydt, ddt_dydp = self._augmented_func(Y[: self.n_states], t, p, Y[self.n_states :])
        derivatives = np.concatenate([dydt, ddt_dydp])
        return derivatives

    def _simulate(self, y0, theta):
        """Integrate states and sensitivities with scipy and return both.

        Returns
        -------
        y : ndarray, shape (n_times, n_states)
            ODE solution at ``times`` (t0 row excluded).
        sens : ndarray, shape (n_times, n_states, n_p)
            Sensitivities of each state w.r.t. (y0, theta) at each time.
        """
        # Initial condition comprised of state initial conditions and raveled sensitivity matrix
        s0 = np.concatenate([y0, self._sens_ic])

        # perform the integration
        sol = scipy.integrate.odeint(
            func=self._system, y0=s0, t=self._augmented_times, args=(np.concatenate([y0, theta]),)
        ).astype(floatX)
        # The solution
        y = sol[1:, : self.n_states]

        # The sensitivities, reshaped to be a sequence of matrices
        sens = sol[1:, self.n_states :].reshape(self.n_times, self.n_states, self.n_p)

        return y, sens

    def make_node(self, y0, theta):
        """Create the Apply node, caching the sensitivity output for ``grad``."""
        inputs = (y0, theta)
        _log.debug(f"make_node for inputs {hash(inputs)}")
        states = self._otypes[0]()
        sens = self._otypes[1]()

        # store symbolic output in dictionary such that it can be accessed in the grad method
        self._output_sensitivities[hash(inputs)] = sens
        return Apply(self, inputs, (states, sens))

    def __call__(self, y0, theta, return_sens=False, **kwargs):
        """Build the symbolic ODE solution for ``y0`` and ``theta``.

        Parameters
        ----------
        y0 : array-like or TensorVariable
            Initial condition, length ``n_states``.
        theta : array-like or TensorVariable
            Parameter vector, length ``n_theta``.
        return_sens : bool
            If True, also return the sensitivity tensor.

        Raises
        ------
        ShapeError
            If a list/tuple input has the wrong length.
        ValueError
            If an input tensor has an incompatible type.
        """
        # Length checks are only possible for concrete list/tuple inputs;
        # symbolic tensors are validated by type below.
        if isinstance(y0, list | tuple) and not len(y0) == self.n_states:
            raise ShapeError("Length of y0 is wrong.", actual=(len(y0),), expected=(self.n_states,))
        if isinstance(theta, list | tuple) and not len(theta) == self.n_theta:
            raise ShapeError(
                "Length of theta is wrong.", actual=(len(theta),), expected=(self.n_theta,)
            )

        # convert inputs to tensors (and check their types)
        y0 = pt.cast(pt.as_tensor_variable(y0), floatX)
        theta = pt.cast(pt.as_tensor_variable(theta), floatX)
        inputs = [y0, theta]
        for i, (input_val, itype) in enumerate(zip(inputs, self._itypes)):
            if not itype.is_super(input_val.type):
                raise ValueError(
                    f"Input {i} of type {input_val.type} does not have the expected type of {itype}"
                )

        # use default implementation to prepare symbolic outputs (via make_node)
        states, sens = super().__call__(y0, theta, **kwargs)

        if return_sens:
            return states, sens
        return states

    def perform(self, node, inputs_storage, output_storage):
        """Numerically integrate the ODE; fill both output storages."""
        y0, theta = inputs_storage[0], inputs_storage[1]
        # simulate states and sensitivities in one forward pass
        output_storage[0][0], output_storage[1][0] = self._simulate(y0, theta)

    def infer_shape(self, fgraph, node, input_shapes):
        """Return static output shapes: (T, S) states and (T, S, P) sensitivities."""
        s_y0, s_theta = input_shapes
        output_shapes = [(self.n_times, self.n_states), (self.n_times, self.n_states, self.n_p)]
        return output_shapes

    def grad(self, inputs, output_grads):
        """Backpropagate through the ODE solution using the sensitivities.

        NOTE(review): only ``output_grads[0]`` (gradient w.r.t. the states) is
        consumed; the gradient w.r.t. the sensitivity output is ignored.
        """
        _log.debug(f"grad w.r.t. inputs {hash(tuple(inputs))}")

        # fetch symbolic sensitivity output node from cache
        ihash = hash(tuple(inputs))
        if ihash in self._output_sensitivities:
            sens = self._output_sensitivities[ihash]
        else:
            _log.debug("No cached sensitivities found!")
            _, sens = self.__call__(*inputs, return_sens=True)
        ograds = output_grads[0]

        # for each parameter, multiply sensitivities with the output gradient and sum the result
        # sens is (n_times, n_states, n_p)
        # ograds is (n_times, n_states)
        grads = [pt.sum(sens[:, :, p] * ograds) for p in range(self.n_p)]

        # return separate gradient tensors for y0 and theta inputs
        result = pt.stack(grads[: self.n_states]), pt.stack(grads[self.n_states :])
        return result