Source code for tvb.adapters.visualizers.pse

# -*- coding: utf-8 -*-
#
#
# TheVirtualBrain-Framework Package. This package holds all Data Management, and 
# Web-UI helpful to run brain-simulations. To use it, you also need to download
# TheVirtualBrain-Scientific Package (for simulators). See content of the
# documentation-folder for more details. See also http://www.thevirtualbrain.org
#
# (c) 2012-2023, Baycrest Centre for Geriatric Care ("Baycrest") and others
#
# This program is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software Foundation,
# either version 3 of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE.  See the GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along with this
# program.  If not, see <http://www.gnu.org/licenses/>.
#
#
#   CITATION:
# When using The Virtual Brain for scientific publications, please cite it as explained here:
# https://www.thevirtualbrain.org/tvb/zwei/neuroscience-publications
#
#

"""
.. moduleauthor:: Lia Domide <lia.domide@codemart.ro>
.. moduleauthor:: Ionel Ortelecan <ionel.ortelecan@codemart.ro>
.. moduleauthor:: Bogdan Neacsa <bogdan.neacsa@codemart.ro>

Data for Parameter Space Exploration view will be defined here.
The purpose of this entities is to be used in Jinja2 UI, or for populating visualizer.
"""

import json
import math
import numpy
from tvb.adapters.datatypes.db.mapped_value import DatatypeMeasureIndex
from tvb.basic.config.utils import EnhancedDictionary
from tvb.core.entities.storage import dao

KEY_GID = "Gid"
KEY_TOOLTIP = "tooltip"
LINE_SEPARATOR = "<br/>"


[docs]class PSEModel(object): def __init__(self, operation): self.operation = operation self.datatype_measure = self.determine_operation_result() self.metrics = dict() if self.datatype_measure: self.metrics = json.loads(self.datatype_measure.metrics) self.range1_key = None self.range1_value = None self.range2_key = None self.range2_value = 0 self._extract_range_values() def _extract_range_values(self): ranges = json.loads(self.operation.range_values) range_keys = list(ranges.keys()) self.range1_key = range_keys[0] self.range1_value = ranges[self.range1_key] if len(range_keys) > 1: self.range2_key = range_keys[1] self.range2_value = ranges[self.range2_key] def __is_float(self, value): is_float = True try: float(value) except ValueError: is_float = False return is_float
[docs] def is_range1_float(self): return self.__is_float(self.range1_value)
[docs] def is_range2_float(self): if not self.range2_key: return False return self.__is_float(self.range2_value)
def _determine_type_and_label(self, value): if self.__is_float(value): return value label = dao.get_datatype_by_gid(value).display_name return label
[docs] def get_range1_label(self): return self._determine_type_and_label(self.range1_value)
[docs] def get_range2_label(self): if not self.range2_key: return '' return self._determine_type_and_label(self.range2_value)
[docs] def prepare_node_info(self): """ Build a dictionary with all the required information to be displayed for a given node. """ KEY_NODE_TYPE = "dataType" KEY_OPERATION_ID = "operationId" node_info = dict() if self.operation.has_finished and self.datatype_measure is not None: ts = dao.get_datatype_by_gid(self.datatype_measure.fk_source_gid) node_info[KEY_GID] = ts.gid node_info[KEY_NODE_TYPE] = ts.type node_info[KEY_OPERATION_ID] = ts.fk_from_operation node_info[KEY_TOOLTIP] = self._prepare_node_text(ts) else: tooltip = "No result available. Operation is in status: %s" % self.operation.status.split('-')[1] node_info[KEY_TOOLTIP] = tooltip return node_info
@staticmethod def _prepare_node_text(ts): """ Prepare text to display as tooltip for a PSE circle :return: str """ return str("Operation id: " + str(ts.fk_from_operation) + LINE_SEPARATOR + "Datatype gid: " + str(ts.gid) + LINE_SEPARATOR + "Datatype type: " + str(ts.type) + LINE_SEPARATOR + "Datatype subject: " + str(ts.subject) + LINE_SEPARATOR + "Datatype invalid: " + str(ts.invalid))
[docs] def determine_operation_result(self): datatype_measure = None if self.operation.has_finished: datatypes = dao.get_results_for_operation(self.operation.id) if len(datatypes) > 0: datatype = datatypes[0] if datatype.type == "DatatypeMeasureIndex": # Load proper entity class from DB. measures = dao.get_generic_entity(DatatypeMeasureIndex, datatype.gid, 'gid') else: measures = dao.get_generic_entity(DatatypeMeasureIndex, datatype.gid, 'fk_source_gid') if len(measures) > 0: datatype_measure = measures[0] return datatype_measure
[docs]class PSEGroupModel(object): def __init__(self, datatype_group_gid): self.datatype_group = dao.get_datatype_group_by_gid(datatype_group_gid) if self.datatype_group is None: raise Exception("Selected DataTypeGroup is no longer present in the database. " "It might have been remove or the specified id is not the correct one.") self.operation_group = dao.get_operationgroup_by_id(self.datatype_group.fk_operation_group) self.operations = dao.get_operations_in_group(self.operation_group.id) self.pse_model_list = self.parse_pse_data_for_display() self.all_metrics = dict() self._prepare_ranges_data() @property def range1_values(self): if self.pse_model_list[0].is_range1_float(): return self.range1_orig_values return list(range(len(self.range1_orig_values))) @property def range2_values(self): if self.pse_model_list[0].is_range2_float(): return self.range2_orig_values return list(range(len(self.range2_orig_values)))
[docs] def parse_pse_data_for_display(self): pse_model_list = [] for operation in self.operations: pse_model = PSEModel(operation) pse_model_list.append(pse_model) return pse_model_list
[docs] def get_range1_key(self): return self.pse_model_list[0].range1_key
[docs] def get_range2_key(self): return self.pse_model_list[0].range2_key
def _prepare_ranges_data(self): value_to_label1 = dict() value_to_label2 = dict() for pse_model in self.pse_model_list: label1 = pse_model.get_range1_label() value_to_label1.update({pse_model.range1_value: label1}) label2 = pse_model.get_range2_label() value_to_label2.update({pse_model.range2_value: label2}) value_to_label1 = dict(sorted(value_to_label1.items())) value_to_label2 = dict(sorted(value_to_label2.items())) self.range1_orig_values = list(value_to_label1.keys()) self.range1_labels = list(value_to_label1.values()) self.range2_orig_values = list(value_to_label2.keys()) self.range2_labels = list(value_to_label2.values())
[docs] def get_all_node_info(self): all_node_info = dict() for pse_model in self.pse_model_list: node_info = pse_model.prepare_node_info() range1_val = self.range1_values[self.range1_orig_values.index(pse_model.range1_value)] range2_val = self.range2_values[self.range2_orig_values.index(pse_model.range2_value)] if not range1_val in all_node_info: all_node_info[range1_val] = {} all_node_info[range1_val][range2_val] = node_info return all_node_info
[docs] def get_all_metrics(self): if len(self.all_metrics) == 0: for pse_model in self.pse_model_list: if pse_model.datatype_measure: ts = dao.get_datatype_by_gid(pse_model.datatype_measure.fk_source_gid) self.all_metrics.update({ts.gid: pse_model.metrics}) return self.all_metrics
[docs] def get_available_metric_keys(self): return list(self.pse_model_list[0].metrics)
[docs]class DiscretePSE(EnhancedDictionary): def __init__(self, datatype_group_gid, color_metric, size_metric, back_page): super(DiscretePSE, self).__init__() self.datatype_group_gid = datatype_group_gid self.min_color = float('inf') self.max_color = - float('inf') self.min_shape_size = float('inf') self.max_shape_size = - float('inf') self.has_started_ops = False self.status = 'started' self.title_x, self.title_y = "", "" self.values_x, self.labels_x, self.values_y, self.labels_y = [], [], [], [] self.series_array, self.available_metrics = [], [] self.d3_data = {} self.color_metric = color_metric self.size_metric = size_metric self.pse_back_page = back_page
[docs] def prepare_individual_jsons(self): """ Apply JSON.dumps on all attributes which can not be passes as they are towards UI. """ self.labels_x = json.dumps(self.labels_x) self.labels_y = json.dumps(self.labels_y) self.values_x = json.dumps(self.values_x) self.values_y = json.dumps(self.values_y) self.d3_data = json.dumps(self.d3_data) self.has_started_ops = json.dumps(self.has_started_ops)
[docs]class PSEDiscreteGroupModel(PSEGroupModel): def __init__(self, datatype_group_gid, color_metric, size_metric, back_page): super(PSEDiscreteGroupModel, self).__init__(datatype_group_gid) self.color_metric = color_metric self.size_metric = size_metric self.determine_default_metrics() self.pse_context = DiscretePSE(datatype_group_gid, self.color_metric, self.size_metric, back_page) self.fill_pse_context() self.prepare_display_data()
[docs] def determine_default_metrics(self): if self.color_metric is None and self.size_metric is None: metrics = self.get_available_metric_keys() if len(metrics) > 0: self.color_metric = metrics[0] self.size_metric = metrics[0] if len(metrics) > 1: self.size_metric = metrics[1]
[docs] def fill_pse_context(self): sizes_array = [metrics[self.size_metric] for metrics in self.get_all_metrics().values()] colors_array = [metrics[self.color_metric] for metrics in self.get_all_metrics().values()] if len(sizes_array) > 0: self.pse_context.min_shape_size = numpy.min(sizes_array) self.pse_context.max_shape_size = numpy.max(sizes_array) self.pse_context.min_color = numpy.min(colors_array) self.pse_context.max_color = numpy.max(colors_array) self.pse_context.available_metrics = self.get_available_metric_keys() self.pse_context.title_x = self.get_range1_key() self.pse_context.title_y = self.get_range2_key() self.pse_context.values_x = self.range1_values self.pse_context.values_y = self.range2_values self.pse_context.labels_x = self.range1_labels self.pse_context.labels_y = self.range2_labels
def _prepare_coords_json(self, val1, val2): return '{"x":' + str(val1) + ', "y":' + str(val2) + '}'
[docs] def prepare_display_data(self): d3_data = self.get_all_node_info() series_array = [] for val1 in d3_data.keys(): for val2 in d3_data[val1].keys(): current = d3_data[val1][val2] coords = self._prepare_coords_json(val1, val2) datatype_gid = None if KEY_GID in current: # This means the operation was finished datatype_gid = current[KEY_GID] color_weight, shape_type_1 = self.__get_color_weight(datatype_gid) if (shape_type_1 is not None) and (datatype_gid is not None): d3_data[val1][val2][KEY_TOOLTIP] += LINE_SEPARATOR + " Color metric has NaN values" current['color_weight'] = color_weight radius, shape_type_2 = self.__get_node_size(datatype_gid, len(self.range1_values), len(self.range2_values)) if (shape_type_2 is not None) and (datatype_gid is not None): d3_data[val1][val2][KEY_TOOLTIP] += LINE_SEPARATOR + " Size metric has NaN values" symbol = shape_type_1 or shape_type_2 series_array.append(self.__get_node_json(symbol, radius, coords)) self.pse_context.d3_data = d3_data self.pse_context.series_array = self.__build_series_json(series_array)
@staticmethod def __build_series_json(list_of_series): """ Given a list with all the data points, build the final FLOT JSON. """ final_json = "[" for i, value in enumerate(list_of_series): if i: final_json += "," final_json += value final_json += "]" return final_json @staticmethod def __get_node_json(symbol, radius, coords): """ For each data point entry, build the FLOT specific JSON. """ series = '{"points": {' if symbol is not None: series += '"symbol": "' + symbol + '", ' series += '"radius": ' + str(radius) + '}, ' series += '"coords": ' + coords + '}' return series def __get_node_size(self, datatype_gid, range1_length, range2_length): """ Computes the size of the shape used for representing the dataType with GID given. """ min_size, max_size = self.__get_boundaries(range1_length, range2_length) if datatype_gid is None: return max_size / 2.0, "cross" node_info = self.get_all_metrics()[datatype_gid] if self.size_metric in node_info: valid_metric = True try: if math.isnan(float(node_info[self.size_metric])) or math.isinf(float(node_info[self.size_metric])): valid_metric = False except ValueError: valid_metric = False if valid_metric: shape_weight = node_info[self.size_metric] values_range = self.pse_context.max_shape_size - self.pse_context.min_shape_size shape_range = max_size - min_size if values_range != 0: return min_size + (shape_weight - self.pse_context.min_shape_size) / float( values_range) * shape_range, None else: return min_size, None else: return max_size / 2.0, "cross" return max_size / 2.0, None def __get_color_weight(self, datatype_gid): """ Returns the color weight of the shape used for representing the dataType which id equal to 'datatype_gid'. :param: datatype_gid - It should exists into the 'datatype_indexes' dictionary. """ if datatype_gid is None: return 0, "cross" node_info = self.get_all_metrics()[datatype_gid] valid_metric = True if self.color_metric in node_info: try: if math.isnan(float(node_info[self.color_metric])) or math.isinf(float(node_info[self.color_metric])): valid_metric = False except ValueError: valid_metric = False if valid_metric: return node_info[self.color_metric], None else: return 0, "cross" return 0, None @staticmethod def __get_boundaries(range1_length, range2_length): """ Returns the MIN and the max values of the interval from which may be set the size of a certain shape. """ # the arrays 'intervals' and 'values' should have the same size intervals = [0, 3, 5, 8, 10, 20, 30, 40, 50, 60, 90, 110, 120] values = [(10, 50), (10, 40), (10, 33), (8, 25), (5, 15), (4, 10), (3, 8), (2, 6), (2, 5), (1, 4), (1, 3), (1, 2), (1, 2)] max_length = max([range1_length, range2_length]) if max_length <= intervals[0]: return values[0] elif max_length >= intervals[-1]: return values[-1] else: for i, interval in enumerate(intervals): if max_length <= interval: return values[i - 1]