-
bc10b556
M2E_visualize_resting_state.py 16.92 KiB
import json
import numpy as np
import os
import sys
sys.path.append('./figures/Schmidt2018_dyn')
# from helpers import original_data_path, population_labels
from helpers import population_labels
from multiarea_model import MultiAreaModel
from multiarea_model import Analysis
from plotcolors import myred, myblue
import matplotlib.pyplot as pl
from matplotlib import gridspec
# from matplotlib import rc_file
# rc_file('plotstyle.rc')
icolor = myred
ecolor = myblue
from M2E_LOAD_DATA import load_and_create_data
def set_boxplot_props(d):
for i in range(len(d['boxes'])):
if i % 2 == 0:
d['boxes'][i].set_facecolor(icolor)
d['boxes'][i].set_color(icolor)
else:
d['boxes'][i].set_facecolor(ecolor)
d['boxes'][i].set_color(ecolor)
pl.setp(d['whiskers'], color='k')
pl.setp(d['fliers'], color='k', markerfacecolor='k', marker='+')
pl.setp(d['medians'], color='none')
pl.setp(d['caps'], color='k')
pl.setp(d['means'], marker='x', color='k',
markerfacecolor='k', markeredgecolor='k', markersize=3.)
def plot_resting_state(M, data_path, raster_areas=['V1', 'V2', 'FEF']):
"""
Analysis class.
An instance of the analysis class for the given network and simulation.
Can be created as a member class of a multiarea_model instance or standalone.
Parameters
----------
network : MultiAreaModel
An instance of the multiarea_model class that specifies
the network to be analyzed.
simulation : Simulation
An instance of the simulation class that specifies
the simulation to be analyzed.
data_list : list of strings {'spikes', vm'}, optional
Specifies which type of data is to load. Defaults to ['spikes'].
load_areas : list of strings with area names, optional
Specifies the areas for which data is to be loaded.
Default value is None and leads to loading of data for all
simulated areas.
"""
# Instantiate an analysis class and load spike data
A = Analysis(network=M,
simulation=M.simulation,
data_list=['spikes'],
load_areas=None)
# load data
load_and_create_data(M, A)
t_sim = M.simulation.params["t_sim"]
"""
Figure layout
"""
nrows = 4
ncols = 4
# width = 7.0866
width = 10
panel_wh_ratio = 0.7 * (1. + np.sqrt(5)) / 2. # golden ratio
height = width / panel_wh_ratio * float(nrows) / ncols
pl.rcParams['figure.figsize'] = (width, height)
fig = pl.figure()
axes = {}
gs1 = gridspec.GridSpec(1, 3)
gs1.update(left=0.06, right=0.72, top=0.95, wspace=0.4, bottom=0.35)
# axes['A'] = pl.subplot(gs1[:-1, :1])
# axes['B'] = pl.subplot(gs1[:-1, 1:2])
# axes['C'] = pl.subplot(gs1[:-1, 2:])
axes['A'] = pl.subplot(gs1[:1, :1])
axes['B'] = pl.subplot(gs1[:1, 1:2])
axes['C'] = pl.subplot(gs1[:1, 2:])
gs2 = gridspec.GridSpec(3, 1)
gs2.update(left=0.78, right=0.95, top=0.95, bottom=0.35)
axes['D'] = pl.subplot(gs2[:1, :1])
axes['E'] = pl.subplot(gs2[1:2, :1])
axes['F'] = pl.subplot(gs2[2:3, :1])
gs3 = gridspec.GridSpec(1, 1)
gs3.update(left=0.1, right=0.95, top=0.3, bottom=0.075)
axes['G'] = pl.subplot(gs3[:1, :1])
# areas = ['V1', 'V2', 'FEF']
area_list = ['V1', 'V2', 'VP', 'V3', 'V3A', 'MT', 'V4t', 'V4', 'VOT', 'MSTd',
'PIP', 'PO', 'DP', 'MIP', 'MDP', 'VIP', 'LIP', 'PITv', 'PITd',
'MSTl', 'CITv', 'CITd', 'FEF', 'TF', 'AITv', 'FST', '7a', 'STPp',
'STPa', '46', 'AITd', 'TH']
if len(raster_areas) !=3:
raise Exception("Error! Please give 3 areas to display as raster plots.")
for area in raster_areas:
if area not in area_list:
raise Exception("Error! Given raster areas are either not from complete_area_list, please input correct areas to diaply the raster plots.")
areas = raster_areas
labels = ['A', 'B', 'C']
for area, label in zip(areas, labels):
label_pos = [-0.2, 1.01]
# pl.text(label_pos[0], label_pos[1], r'\bfseries{}' + label + ': ' + area,
# fontdict={'fontsize': 10, 'weight': 'bold',
# 'horizontalalignment': 'left', 'verticalalignment':
# 'bottom'}, transform=axes[label].transAxes)
pl.text(label_pos[0], label_pos[1], label + ': ' + area,
fontdict={'fontsize': 10, 'weight': 'bold',
'horizontalalignment': 'left', 'verticalalignment':
'bottom'}, transform=axes[label].transAxes)
label = 'G'
label_pos = [-0.1, 0.92]
# pl.text(label_pos[0], label_pos[1], r'\bfseries{}' + label,
# fontdict={'fontsize': 10, 'weight': 'bold',
# 'horizontalalignment': 'left', 'verticalalignment':
# 'bottom'}, transform=axes[label].transAxes)
pl.text(label_pos[0], label_pos[1], label,
fontdict={'fontsize': 10, 'weight': 'bold',
'horizontalalignment': 'left', 'verticalalignment':
'bottom'}, transform=axes[label].transAxes)
labels = ['E', 'D', 'F']
for label in labels:
label_pos = [-0.2, 1.05]
# pl.text(label_pos[0], label_pos[1], r'\bfseries{}' + label,
# fontdict={'fontsize': 10, 'weight': 'bold',
# 'horizontalalignment': 'left', 'verticalalignment':
# 'bottom'}, transform=axes[label].transAxes)
pl.text(label_pos[0], label_pos[1], label,
fontdict={'fontsize': 10, 'weight': 'bold',
'horizontalalignment': 'left', 'verticalalignment':
'bottom'}, transform=axes[label].transAxes)
labels = ['A', 'B', 'C', 'D', 'E', 'F']
for label in labels:
axes[label].spines['right'].set_color('none')
axes[label].spines['top'].set_color('none')
axes[label].yaxis.set_ticks_position("left")
axes[label].xaxis.set_ticks_position("bottom")
for label in ['A', 'B', 'C']:
axes[label].yaxis.set_ticks_position('none')
# """
# Load data
# """
# LOAD_ORIGINAL_DATA = True
# if LOAD_ORIGINAL_DATA:
# # use T=10500 simulation for spike raster plots
# label_spikes = '3afaec94d650c637ef8419611c3f80b3cb3ff539'
# # and T=100500 simulation for all other panels
# label = '99c0024eacc275d13f719afd59357f7d12f02b77'
# data_path = original_data_path
# else:
# from network_simulations import init_models
# from config import data_path
# models = init_models('Fig5')
# label_spikes = models[0].simulation.label
# label = models[1].simulation.label
# """
# Create MultiAreaModel instance to have access to data structures
# """
# M = MultiAreaModel({})
# spike data
# spike_data = {}
# for area in areas:
# spike_data[area] = {}
# for pop in M.structure[area]:
# spike_data[area][pop] = np.load(os.path.join(data_path,
# label_spikes,
# 'recordings',
# '{}-spikes-{}-{}.npy'.format(label_spikes,
# area, pop)))
spike_data = A.spike_data
label_spikes = M.simulation.label
label = M.simulation.label
# stationary firing rates
fn = os.path.join(data_path, label, 'Analysis', 'pop_rates.json')
with open(fn, 'r') as f:
pop_rates = json.load(f)
# time series of firing rates
rate_time_series = {}
for area in areas:
# fn = os.path.join(data_path, label,
# 'Analysis',
# 'rate_time_series_full',
# 'rate_time_series_full_{}.npy'.format(area))
fn = os.path.join(data_path, label,
'Analysis',
'rate_time_series-{}.npy'.format(area))
rate_time_series[area] = np.load(fn)
# time series of firing rates convolved with a kernel
# rate_time_series_auto_kernel = {}
# for area in areas:
# fn = os.path.join(data_path, label,
# 'Analysis',
# 'rate_time_series_auto_kernel',
# 'rate_time_series_auto_kernel_{}.npy'.format(area))
# rate_time_series_auto_kernel[area] = np.load(fn)
# local variance revised (LvR)
fn = os.path.join(data_path, label, 'Analysis', 'pop_LvR.json')
with open(fn, 'r') as f:
pop_LvR = json.load(f)
# correlation coefficients
# fn = os.path.join(data_path, label, 'Analysis', 'corrcoeff.json')
fn = os.path.join(data_path, label, 'Analysis', 'synchrony.json')
with open(fn, 'r') as f:
corrcoeff = json.load(f)
"""
Plotting
"""
# print("Raster plots")
# t_min = 3000.
# t_max = 3500.
t_min = t_sim - 500
t_max = t_sim
icolor = myred
ecolor = myblue
# frac_neurons = 0.03
frac_neurons = 1
for i, area in enumerate(areas):
ax = axes[labels[i]]
if area in spike_data:
n_pops = len(spike_data[area])
# Determine number of neurons that will be plotted for this area (for
# vertical offset)
offset = 0
n_to_plot = {}
for pop in M.structure[area]:
n_to_plot[pop] = int(M.N[area][pop] * frac_neurons)
offset = offset + n_to_plot[pop]
y_max = offset + 1
prev_pop = ''
yticks = []
yticklocs = []
for jj, pop in enumerate(M.structure[area]):
if pop[0:-1] != prev_pop:
prev_pop = pop[0:-1]
yticks.append('L' + population_labels[jj][0:-1])
yticklocs.append(offset - 0.5 * n_to_plot[pop])
ind = np.where(np.logical_and(
spike_data[area][pop][:, 1] <= t_max, spike_data[area][pop][:, 1] >= t_min))
pop_data = spike_data[area][pop][ind]
pop_neurons = np.unique(pop_data[:, 0])
neurons_to_ = np.arange(np.min(spike_data[area][pop][:, 0]), np.min(
spike_data[area][pop][:, 0]) + n_to_plot[pop], 1)
if pop.find('E') > (-1):
pcolor = ecolor
else:
pcolor = icolor
for kk in range(n_to_plot[pop]):
spike_times = pop_data[pop_data[:, 0] == neurons_to_[kk], 1]
_ = ax.plot(spike_times, np.zeros(len(spike_times)) +
offset - kk, '.', color=pcolor, markersize=1)
offset = offset - n_to_plot[pop]
y_min = offset
ax.set_xlim([t_min, t_max])
ax.set_ylim([y_min, y_max])
ax.set_yticklabels(yticks)
ax.set_yticks(yticklocs)
ax.set_xlabel('Time (s)', labelpad=-0.1)
ax.set_xticks([t_min, t_min + 250., t_max])
# ax.set_xticklabels([r'$3.$', r'$3.25$', r'$3.5$'])
l = t_min/1000
m = (t_min + t_max)/2000
r = t_max/1000
ax.set_xticklabels([f'{l:.2f}', f'{m:.2f}', f'{r:.2f}'])
# print("plotting Population rates")
rates = np.zeros((len(M.area_list), 8))
for i, area in enumerate(M.area_list):
for j, pop in enumerate(M.structure[area][::-1]):
rate = pop_rates[area][pop][0]
if rate == 0.0:
rate = 1e-5
if area == 'TH' and j > 3: # To account for missing layer 4 in TH
rates[i][j + 2] = rate
else:
rates[i][j] = rate
rates = np.transpose(rates)
masked_rates = np.ma.masked_where(rates < 1e-4, rates)
ax = axes['D']
d = ax.boxplot(np.transpose(rates), vert=False,
patch_artist=True, whis=1.5, showmeans=True)
set_boxplot_props(d)
ax.plot(np.mean(rates, axis=1), np.arange(
1., len(M.structure['V1']) + 1., 1.), 'x', color='k', markersize=3)
ax.set_yticklabels(population_labels[::-1], size=8)
ax.set_yticks(np.arange(1., len(M.structure['V1']) + 1., 1.))
ax.set_ylim((0., len(M.structure['V1']) + .5))
x_max = 220.
ax.set_xlim((-1., x_max))
ax.set_xlabel(r'Rate (spikes/s)', labelpad=-0.1)
ax.set_xticks([0., 50., 100.])
# print("plotting Synchrony")
syn = np.zeros((len(M.area_list), 8))
for i, area in enumerate(M.area_list):
for j, pop in enumerate(M.structure[area][::-1]):
value = corrcoeff[area][pop]
if value == 0.0:
value = 1e-5
if area == 'TH' and j > 3: # To account for missing layer 4 in TH
syn[i][j + 2] = value
else:
syn[i][j] = value
syn = np.transpose(syn)
masked_syn = np.ma.masked_where(syn < 1e-4, syn)
ax = axes['E']
d = ax.boxplot(np.transpose(syn), vert=False,
patch_artist=True, whis=1.5, showmeans=True)
set_boxplot_props(d)
ax.plot(np.mean(syn, axis=1), np.arange(
1., len(M.structure['V1']) + 1., 1.), 'x', color='k', markersize=3)
ax.set_yticklabels(population_labels[::-1], size=8)
ax.set_yticks(np.arange(1., len(M.structure['V1']) + 1., 1.))
ax.set_ylim((0., len(M.structure['V1']) + .5))
# ax.set_xticks(np.arange(0.0, 0.601, 0.2))
ax.set_xticks(np.arange(0.0, 10.0, 2))
ax.set_xlabel('Correlation coefficient', labelpad=-0.1)
# print("plotting Irregularity")
LvR = np.zeros((len(M.area_list), 8))
for i, area in enumerate(M.area_list):
for j, pop in enumerate(M.structure[area][::-1]):
value = pop_LvR[area][pop]
if value == 0.0:
value = 1e-5
if area == 'TH' and j > 3: # To account for missing layer 4 in TH
LvR[i][j + 2] = value
else:
LvR[i][j] = value
LvR = np.transpose(LvR)
masked_LvR = np.ma.masked_where(LvR < 1e-4, LvR)
ax = axes['F']
d = ax.boxplot(np.transpose(LvR), vert=False,
patch_artist=True, whis=1.5, showmeans=True)
set_boxplot_props(d)
ax.plot(np.mean(LvR, axis=1), np.arange(
1., len(M.structure['V1']) + 1., 1.), 'x', color='k', markersize=3)
ax.set_yticklabels(population_labels[::-1], size=8)
ax.set_yticks(np.arange(1., len(M.structure['V1']) + 1., 1.))
ax.set_ylim((0., len(M.structure['V1']) + .5))
x_max = 2.9
ax.set_xlim((0., x_max))
ax.set_xlabel('Irregularity', labelpad=-0.1)
ax.set_xticks([0., 1., 2.])
axes['G'].spines['right'].set_color('none')
axes['G'].spines['left'].set_color('none')
axes['G'].spines['top'].set_color('none')
axes['G'].spines['bottom'].set_color('none')
axes['G'].yaxis.set_ticks_position("none")
axes['G'].xaxis.set_ticks_position("none")
axes['G'].set_xticks([])
axes['G'].set_yticks([])
# print("Plotting rate time series")
pos = axes['G'].get_position()
ax = []
h = pos.y1 - pos.y0
w = pos.x1 - pos.x0
ax.append(pl.axes([pos.x0, pos.y0, w, 0.28 * h]))
ax.append(pl.axes([pos.x0, pos.y0 + 0.33 * h, w, 0.28 * h]))
ax.append(pl.axes([pos.x0, pos.y0 + 0.67 * h, w, 0.28 * h]))
colors = ['0.5', '0.3', '0.0']
# t_min = 500.
# t_max = 10500.
t_min = 500.
t_max = t_sim
# time = np.arange(500., t_max)
time = np.arange(500, t_max)
for i, area in enumerate(areas[::-1]):
ax[i].spines['right'].set_color('none')
ax[i].spines['top'].set_color('none')
ax[i].yaxis.set_ticks_position("left")
ax[i].xaxis.set_ticks_position("none")
binned_spikes = rate_time_series[area][np.where(
np.logical_and(time >= t_min, time < t_max))]
ax[i].plot(time, binned_spikes, color=colors[0], label=area)
# rate = rate_time_series_auto_kernel[area]
rate = rate_time_series[area]
ax[i].plot(time, rate, color=colors[2], label=area)
ax[i].set_xlim((t_min, t_max))
ax[i].text(0.8, 0.7, area, transform=ax[i].transAxes)
if i > 0:
ax[i].spines['bottom'].set_color('none')
ax[i].set_xticks([])
ax[i].set_yticks([0., 30.])
else:
# ax[i].set_xticks([1000., 5000., 10000.])
ax[i].set_xticks([t_min, (t_min+t_max)/2, t_max])
l = t_min/1000
m = (t_min + t_max)/2000
r = t_max/1000
# ax[i].set_xticklabels([r'$1.$', r'$5.$', r'$10.$'])
ax[i].set_xticklabels([f'{l:.2f}', f'{m:.2f}', f'{r:.2f}'])
ax[i].set_yticks([0., 5.])
if i == 1:
ax[i].set_ylabel(r'Rate (spikes/s)')
ax[0].set_xlabel('Time (s)', labelpad=-0.05)
fig.subplots_adjust(left=0.05, right=0.95, top=0.95,
bottom=0.075, wspace=1., hspace=.5)
# pl.savefig('Fig5_ground_state.eps')