From aeee65386e115f9116cc2663b42c8eb50ef2cc4a Mon Sep 17 00:00:00 2001 From: fousekjan <jan.fousek@univ-amu.fr> Date: Mon, 7 Jun 2021 13:40:01 +0200 Subject: [PATCH] phase plane interactive jupyterlab --- phase_plane.py | 124 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 phase_plane.py diff --git a/phase_plane.py b/phase_plane.py new file mode 100644 index 0000000..d8cae2f --- /dev/null +++ b/phase_plane.py @@ -0,0 +1,124 @@ +from tvb.basic.neotraits.api import HasTraits, Attr, NArray, List +from ipywidgets import interact, FloatSlider, Dropdown +import numpy as np +import matplotlib.pylab as plt + + +def phase_plane_interactive(model, integrator): + + + NUMBEROFGRIDPOINTS = 42 + + def plot_phase_plane(**param_kwargs): + # defaults, to be changed + svx = param_kwargs.pop('svx') #x-axis: 1st state variable + svy = param_kwargs.pop('svy') #y-axis: 2nd state variable + + + mode = param_kwargs.pop('mode') + + + # set model params + for k, v in param_kwargs.items(): + setattr(model, k, np.r_[v]) + + # state vector + sv_mean = np.array([model.state_variable_range[key].mean() for key in model.state_variables]) + sv_mean = sv_mean.reshape((model.nvar, 1, 1)) + default_sv = sv_mean.repeat(model.number_of_modes, axis=2) + no_coupling = np.zeros((model.nvar, 1, model.number_of_modes)) + + + # mesh grid + xlo = model.state_variable_range[svx][0] + xhi = model.state_variable_range[svx][1] + ylo = model.state_variable_range[svy][0] + yhi = model.state_variable_range[svy][1] + + X = np.mgrid[xlo:xhi:(NUMBEROFGRIDPOINTS*1j)] + Y = np.mgrid[ylo:yhi:(NUMBEROFGRIDPOINTS*1j)] + + + # Calculate the vector field. + svx_ind = model.state_variables.index(svx) + svy_ind = model.state_variables.index(svy) + + + #Calculate the vector field discretely sampled at a grid of points + grid_point = default_sv.copy() + U = np.zeros((NUMBEROFGRIDPOINTS, NUMBEROFGRIDPOINTS, + model.number_of_modes)) + V = np.zeros((NUMBEROFGRIDPOINTS, NUMBEROFGRIDPOINTS, + model.number_of_modes)) + for ii in range(NUMBEROFGRIDPOINTS): + grid_point[svy_ind] = Y[ii] + for jj in range(NUMBEROFGRIDPOINTS): + #import pdb; pdb.set_trace() + grid_point[svx_ind] = X[jj] + + d = model.dfun(grid_point, no_coupling) + + for kk in range(model.number_of_modes): + U[ii, jj, kk] = d[svx_ind, 0, kk] + V[ii, jj, kk] = d[svy_ind, 0, kk] + + + # plot + fig, ax = plt.subplots() + ax.set( + xlabel = "State Variable " + svx, + ylabel = "State Variable " + svy, + title = model.__class__.__name__ + " mode " + str(mode) + ) + + if np.all(U[:, :, mode] + V[:, :, mode] == 0): + ax.set(title = model_name + " mode " + mode + ": NO MOTION IN THIS PLANE") + X, Y = np.meshgrid(X, Y) + pp_quivers = ax.scatter(X, Y, s=8, marker=".", c="k") + else: + pp_quivers = ax.quiver(X, Y, + U[:, :, mode], + V[:, :, mode], + width=0.001, headwidth=8) + + #Plot the nullclines + nullcline_x = ax.contour(X, Y, + U[:, :, mode], + [0], colors="r") + nullcline_y = ax.contour(X, Y, + V[:, :, mode], + [0], colors="g") + plt.show() + + # setup widgets + param_kwargs = {} + for param_name in type(model).declarative_attrs: + param_def = getattr(type(model), param_name) + if not isinstance(param_def, NArray) or not param_def.dtype == np.float : + continue + param_range = param_def.domain + if param_range is None: + continue + param_value = getattr(model, param_name)[0] + param_kwargs[param_name] = FloatSlider( + min=param_range.lo, max=param_range.hi, value=param_value) + param_kwargs['svx'] = Dropdown( + #options=[(v,i) for i, v in enumerate(model.state_variables)], + options = model.state_variables, + value=model.state_variables[0], + description='X axis' + ) + param_kwargs['svy'] = Dropdown( + #options=[(v,i) for i, v in enumerate(model.state_variables)], + options = model.state_variables, + value=model.state_variables[1], + description='Y axis' + ) + param_kwargs['mode'] = Dropdown( + options=list(range(model.number_of_modes)), + value=0, + description='Mode' + ) + + w = interact(plot_phase_plane, **param_kwargs) + return w -- GitLab