# Source code for tvb.adapters.visualizers.phase_plane_interactive

# -*- coding: utf-8 -*-
#
#
# TheVirtualBrain-Framework Package. This package holds all Data Management, and
# Web-UI helpful to run brain-simulations. To use it, you also need to download
# TheVirtualBrain-Scientific Package (for simulators). See content of the
# documentation-folder for more details. See also http://www.thevirtualbrain.org
#
# (c) 2012-2023, Baycrest Centre for Geriatric Care ("Baycrest") and others
#
# This program is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software Foundation,
# either version 3 of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE.  See the GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along with this
# program.  If not, see <http://www.gnu.org/licenses/>.
#
#
#   CITATION:
# When using The Virtual Brain for scientific publications, please cite it as explained here:
# https://www.thevirtualbrain.org/tvb/zwei/neuroscience-publications
#
#
"""
.. moduleauthor:: Ionel Ortelecan <ionel.ortelecan@codemart.ro>
.. moduleauthor:: Mihai Andrei <mihai.andrei@codemart.ro>
"""
import numpy
import six
from tvb.basic.logger.builder import get_logger

# To plot nullclines we need a function that computes the zero-level contours of a
# scalar field. We use matplotlib's private (underscore-prefixed) contouring module.
# Newer matplotlib versions have changed this internal API, so switching to a
# supported contouring library would be more robust.

try:
    # older matplotlib: the legacy contour tracer lives in the private _cntr module
    from matplotlib import _cntr

    def nullcline(x, y, z):
        """
        Trace the zero-level contour of the scalar field ``z`` sampled on the mesh ``x``, ``y``.

        :return: the traced segments (a list of vertex arrays), or an empty numpy
            array when no zero-level contour exists.
        """
        c = _cntr.Cntr(x, y, z)
        # trace the zero contour
        res = c.trace(0.0)
        if not res:
            return numpy.array([])
        # trace() returns the vertex arrays followed by their matching path codes
        # (see docs for matplotlib.path.Path); only the vertices are needed here
        nseg = len(res) // 2
        return res[:nseg]

except ImportError:
    # newer matplotlib >= 2.2: private _contour module
    from matplotlib import _contour

    def nullcline(x, y, z):
        """
        Trace the zero-level contour of the scalar field ``z`` sampled on the mesh ``x``, ``y``.

        :return: the first element of ``create_contour``'s result, the traced
            segments (presumably a list of vertex arrays -- depends on the private
            matplotlib API, TODO confirm per version).
        """
        # positional args after z are presumably (mask, corner_mask, chunk size)
        # per matplotlib's private signature -- TODO confirm
        c = _contour.QuadContourGenerator(x, y, z, None, True, 0)
        segments = c.create_contour(0.0)
        # create_contour returns a sequence whose first entry holds the segments
        return segments[0]
# How much coarser the grid used to draw the vector-field arrows is
# (every GRID_SUBSAMPLE-th grid point gets an arrow).
GRID_SUBSAMPLE = 2
# Set the resolution of the phase-plane and sample trajectories.
NUMBEROFGRIDPOINTS = GRID_SUBSAMPLE * 42


class _PhaseSpace(object):
    """
    Dimensionality independent code shared by the phase-plane and phase-line helpers.
    """

    def __init__(self, model, integrator):
        self.log = get_logger(self.__class__.__module__)
        self.model = model
        self.integrator = integrator

    def _compute_trajectories(self, states, n_steps):
        """
        A vectorized method of computing a number of trajectories in parallel.

        :param states: starting points, shape (n_trajectories, nvar)
        :param n_steps: number of integration steps to take
        :return: array of shape (n_steps + 1, nvar, n_trajectories, number_of_modes);
            index 0 along the first axis holds the starting points.
        """
        scheme = self.integrator.scheme
        n_modes = self.model.number_of_modes
        trajs = numpy.zeros((n_steps + 1, self.model.nvar, len(states), n_modes))
        # reshape to what dfun expects: from (n, sv) to (sv, n, mode)
        states = numpy.tile(states.T[:, :, numpy.newaxis], n_modes)
        trajs[0, :] = states
        # trajectories are integrated without any network coupling
        no_coupling = numpy.zeros((self.model.nvar, states.shape[1], n_modes))
        # grow trajectories step by step; the last two args (0.0, 0.0) are
        # presumably local coupling and stimulus -- TODO confirm against scheme signature
        for step in range(n_steps):
            states = scheme(states, self.model.dfun, no_coupling, 0.0, 0.0)
            trajs[step + 1, :] = states
        if numpy.isnan(trajs).any():
            self.log.warning("NaN in trajectories")
        return trajs

    def get_axes_ranges(self, sv):
        """
        Return display bounds for the state variable ``sv``.

        :return: ``(min_val, max_val, lo, hi)`` where ``lo, hi`` is the model's
            state_variable_range for ``sv`` and the min/max values extend it by
            four times its width on each side.
        """
        lo, hi = self.model.state_variable_range[sv]
        default_range = hi - lo
        min_val = lo - 4.0 * default_range
        max_val = hi + 4.0 * default_range
        return min_val, max_val, lo, hi
class PhasePlane(_PhaseSpace):
    """
    Responsible with computing phase space slices and trajectories.
    A collection of math-y utilities; it is view independent (holds no state related to views).
    """

    @staticmethod
    def _create_mesh_jitter():
        """
        Create a random perturbation for the phase-plane mesh.

        :return: two arrays of shape (NUMBEROFGRIDPOINTS, NUMBEROFGRIDPOINTS) drawn from a
            normal distribution with mean 0 and standard deviation 1 / (4 * NUMBEROFGRIDPOINTS).
        """
        shape = NUMBEROFGRIDPOINTS, NUMBEROFGRIDPOINTS
        d = 1.0 / (4 * NUMBEROFGRIDPOINTS)
        return numpy.random.normal(0, d, shape), numpy.random.normal(0, d, shape)

    @staticmethod
    def _get_mesh_grid(x_range, y_range, noise=None):
        """
        Generate the phase-plane gridding based on the given state-variable ranges.

        :param x_range: (low, high) bounds of the x axis
        :param y_range: (low, high) bounds of the y axis
        :param noise: optional pair of arrays (e.g. from ``_create_mesh_jitter``) added to
            the x and y coordinates after scaling by each axis' extent
        :return: ``(x, y)`` coordinate arrays of shape (NUMBEROFGRIDPOINTS, NUMBEROFGRIDPOINTS)
        """
        xlo, xhi = x_range
        ylo, yhi = y_range
        # a complex step makes mgrid produce that many points, end points included
        xg = numpy.mgrid[xlo:xhi:(NUMBEROFGRIDPOINTS * 1j)]
        yg = numpy.mgrid[ylo:yhi:(NUMBEROFGRIDPOINTS * 1j)]
        xgr, ygr = numpy.meshgrid(xg, yg)
        # explicit None check: a bare truth test is ambiguous when noise is an ndarray
        if noise is not None:
            # add noise scaled by the axis extents
            xgr += noise[0] * (xhi - xlo)
            ygr += noise[1] * (yhi - ylo)
        return xgr, ygr

    def _calc_phase_plane(self, state, svx_ind, svy_ind, xg, yg):
        """
        Compute a 2d axis aligned rectangle of the vector field, returning ``u, v``.

        The slice passes through the ``state`` point and varies along the axes given by
        ``svx_ind`` and ``svy_ind``. ``xg`` and ``yg`` are the mesh coordinates in the
        varying directions (as computed by ``_get_mesh_grid``); any mesh shape is accepted.
        Vectorized: it evaluates all grid points at once as if they were connectivity nodes.

        :param state: the slice point; assumed shaped (nvar, 1, number_of_modes) -- TODO confirm
            for callers other than this module
        :return: ``(u, v)``, the dfun components of the x and y state variables, each of
            shape ``xg.shape + (number_of_modes,)``
        """
        n_modes = self.model.number_of_modes
        grid_shape = numpy.shape(xg)
        n_points = numpy.size(xg)
        x_flat = numpy.ravel(xg)
        y_flat = numpy.ravel(yg)
        # replicate the slice point once per grid point: (nvar, n_points, modes)
        state_variables = numpy.tile(state, (n_points, 1))
        for mode_idx in range(n_modes):
            state_variables[svx_ind, :, mode_idx] = x_flat
            state_variables[svy_ind, :, mode_idx] = y_flat
        no_coupling = numpy.zeros((self.model.nvar, state_variables.shape[1], n_modes))
        d_grid = self.model.dfun(state_variables, no_coupling)
        # subset of the state variables to be displayed
        flat_uv_grid = d_grid[[svx_ind, svy_ind], :, :]
        u, v = flat_uv_grid.reshape((2,) + grid_shape + (n_modes,))
        if numpy.isnan(u).any() or numpy.isnan(v).any():
            self.log.error("NaN")
        return u, v
class PhasePlaneD3(PhasePlane):
    """
    Provides data for a d3 client
    """

    def __init__(self, model, integrator):
        PhasePlane.__init__(self, model, integrator)
        self.mode = 0
        self.svx_ind = 0  # x-axis: 1st state variable
        if self.model.nvar > 1:
            self.svy_ind = 1  # y-axis: 2nd state variable
        else:
            self.svy_ind = 0
        # axis ranges stay None until update_axis is called
        self.x_range = None
        self.y_range = None
        # Set up a vector containing the default state-variable values:
        # the mean of each variable's range, shaped (nvar, 1, number_of_modes)
        svr = self.model.state_variable_range
        sv_mean = numpy.array([svr[key].mean() for key in self.model.state_variables])
        sv_mean = sv_mean.reshape((self.model.nvar, 1, 1))
        self.default_sv = sv_mean.repeat(self.model.number_of_modes, axis=2)
        self.update_integrator_clamping()
        # mesh jitter is disabled; assign self._create_mesh_jitter() to perturb the mesh
        self._jitter = None

    def update_integrator_clamping(self):
        """
        Clamp every state variable that is not on the x or y axis to its value in
        ``default_sv``; clear the clamping when all variables are displayed on an axis.
        """
        clamped_sv_indices = [i for i in range(self.model.nvar)
                              if i not in [self.svx_ind, self.svy_ind]]
        if clamped_sv_indices:
            self.integrator.clamped_state_variable_indices = numpy.array(clamped_sv_indices)
            self.integrator.clamped_state_variable_values = self.default_sv[
                self.integrator.clamped_state_variable_indices]
        else:
            self.integrator.clamped_state_variable_indices = None
            self.integrator.clamped_state_variable_values = None

    def update_axis(self, mode, svx, svy, x_range, y_range, state_vars):
        """
        Select the displayed mode, axis state variables, axis ranges and slice point.

        :param mode: index of the mode to display (converted with ``int``)
        :param svx: name of the x-axis state variable
        :param svy: name of the y-axis state variable
        :param x_range: (low, high) bounds of the x axis
        :param y_range: (low, high) bounds of the y axis
        :param state_vars: mapping of state variable name to its value in the slice point
        :raises ValueError: if a name is not in ``model.state_variables``
        """
        self.mode = int(mode)
        self.svx_ind = self.model.state_variables.index(svx)
        self.svy_ind = self.model.state_variables.index(svy)
        self.x_range = x_range
        self.y_range = y_range
        for name, val in six.iteritems(state_vars):
            k = self.model.state_variables.index(name)
            self.default_sv[k] = val
        # the clamped values come from default_sv, so refresh them after the update
        self.update_integrator_clamping()

    def compute_phase_plane(self):
        """
        Compute the vector field and the nullclines of the current slice.

        ``x_range`` and ``y_range`` are None until ``update_axis`` sets them, so it must be
        called first.

        :return: A json representation of the phase plane: a dict with ``'plane'`` -- a
            list of ``[x, y, u, v]`` rows on a mesh subsampled by GRID_SUBSAMPLE -- and
            ``'nullclines'`` -- a list of ``{'path', 'nullcline_index'}`` dicts, with index 0
            for the u == 0 contours and 1 for the v == 0 contours.
        """
        x, y = self._get_mesh_grid(self.x_range, self.y_range, noise=self._jitter)
        u, v = self._calc_phase_plane(self.default_sv, self.svx_ind, self.svy_ind, x, y)
        u = u[..., self.mode]  # project on active mode
        v = v[..., self.mode]

        xnull = [{'path': segment.tolist(), 'nullcline_index': 0} for segment in nullcline(x, y, u)]
        ynull = [{'path': segment.tolist(), 'nullcline_index': 1} for segment in nullcline(x, y, v)]

        # a coarser mesh for the arrows
        sub = slice(None, None, GRID_SUBSAMPLE)
        d = numpy.dstack((x[sub, sub], y[sub, sub], u[sub, sub], v[sub, sub]))
        # let numpy infer the row count: correct even when NUMBEROFGRIDPOINTS
        # is not a multiple of GRID_SUBSAMPLE
        d = d.reshape((-1, 4)).tolist()
        return {'plane': d, 'nullclines': xnull + ynull}

    def _state_dict_to_array(self, state):
        """
        Convert a ``{state_var_name: value}`` dict to a 1d array ordered like
        ``model.state_variables``; variables missing from the dict are 0.

        :raises ValueError: if a name is not in ``model.state_variables``
        """
        arr = numpy.zeros(len(self.model.state_variables))
        for svn, v in six.iteritems(state):
            svn_idx = self.model.state_variables.index(svn)
            arr[svn_idx] = v
        return arr

    def trajectories(self, starting_points, n_steps=512):
        """
        :param starting_points: A list of starting points represented as dicts of
            state_var_name to value
        :param n_steps: number of integration steps per trajectory
        :return: a tuple of trajectories and signals. Trajectories is a list (one per
            starting point) of ``[x, y]`` points on the current axes; signals holds, for the
            last trajectory, one list of ``(time, value)`` pairs for the x variable and one
            for the y variable.
        """
        starting_points = numpy.array([self._state_dict_to_array(s) for s in starting_points])
        # axes: point_on_traj_idx, sv_idx, traj_idx, mode
        traj = self._compute_trajectories(starting_points, n_steps)
        # reshape to traj_idx, point, sv_idx, mode
        traj = traj.transpose(2, 0, 1, 3)
        # project on the plane defined by the current axis state vars: traj_idx, point, (x, y)
        trajectory = traj[:, :, [self.svx_ind, self.svy_ind], self.mode]
        # signals for the last trajectory, with plain Python floats like the other outputs
        signal_x = (numpy.arange(n_steps + 1) * self.integrator.dt).tolist()
        signals = [list(zip(signal_x, traj[-1, :, i, self.mode].tolist()))
                   for i in [self.svx_ind, self.svy_ind]]
        return trajectory.tolist(), signals
class PhaseLineD3(_PhaseSpace):
    """
    Provides phase-line data for a d3 client (used for models with a single state variable).
    """

    def __init__(self, model, integrator):
        _PhaseSpace.__init__(self, model, integrator)
        self.mode = 0

    def _grid(self):
        """Return NUMBEROFGRIDPOINTS evenly spaced values over the first state variable's range."""
        svr = self.model.state_variable_range
        xlo, xhi = svr[self.model.state_variables[0]]
        return numpy.linspace(xlo, xhi, NUMBEROFGRIDPOINTS)

    def compute_phase_plane(self):
        """
        Evaluate dfun along the grid of the first state variable.

        :return: a dict with ``'signal'`` -- a list of ``[x, dfun(x)]`` pairs for the active
            mode -- and ``'zeroes'`` -- the grid x values just before each sign change of
            dfun (an approximation of its zeros, not exact).
        """
        xg = self._grid()
        n_modes = self.model.number_of_modes
        # dfun modifies state in place, so it must get a fresh array; tile always copies.
        # The state is expanded to all modes so it matches the coupling array's shape.
        state = numpy.tile(xg.reshape((1, -1, 1)), (1, 1, n_modes))
        no_coupling = numpy.zeros((self.model.nvar, state.shape[1], n_modes))
        u = self.model.dfun(state, no_coupling)
        u = u[0, :, self.mode]
        d = numpy.vstack((xg, u)).T
        if numpy.isnan(d).any():
            self.log.error("NaN")
        # find zeroes. This method is not exact: it reports the left grid point of each
        # sign change. Pairs involving NaN are skipped so they don't register as crossings.
        finite_pairs = numpy.isfinite(u[:-1]) & numpy.isfinite(u[1:])
        crossing_idx = numpy.nonzero((numpy.diff(numpy.sign(u)) != 0) & finite_pairs)[0]
        zero_crossings = xg[crossing_idx]
        return {'signal': d.tolist(), 'zeroes': zero_crossings.tolist()}

    def update_axis(self, mode, svx, x_range):
        """
        Select the displayed mode and the range of the state variable ``svx``.

        Note: ``x_range`` is written into the model's ``state_variable_range`` entry in
        place, so the model itself is modified; ``_grid`` reads it back from there.
        """
        self.mode = int(mode)
        svr = self.model.state_variable_range
        svr[svx][:] = x_range
def phase_space_d3(model, integrator):
    """
    Build the d3 data provider matching the model's dimensionality.

    :param model: the model to analyse; its ``nvar`` selects the provider
    :param integrator: integrator passed through to the provider
    :return: A phase line (``PhaseLineD3``) for a model with a single state variable,
        otherwise a phase plane (``PhasePlaneD3``)
    """
    if model.nvar == 1:
        return PhaseLineD3(model, integrator)
    return PhasePlaneD3(model, integrator)