diff --git a/figures/Schmidt2018_dyn/Snakefile b/figures/Schmidt2018_dyn/Snakefile index 8600bcacfb01d405a50058ec6007e131dbca9a70..494de7b8cfe8e8767bb06047a7935232670bdff0 100644 --- a/figures/Schmidt2018_dyn/Snakefile +++ b/figures/Schmidt2018_dyn/Snakefile @@ -49,7 +49,9 @@ ORIGINAL_SIMULATIONS = {'all': ['533d73357fbe99f6178029e6054b571b485f40f6', 'f18158895a5d682db5002489d12d27d7a974146f', '08a3a1a88c19193b0af9d9d8f7a52344d1b17498', '5bdd72887b191ec22a5abcc04ca4a488ea216e32', - '99c0024eacc275d13f719afd59357f7d12f02b77']} + '99c0024eacc275d13f719afd59357f7d12f02b77'], + 'Fig8': '99c0024eacc275d13f719afd59357f7d12f02b77', + 'Fig9': '99c0024eacc275d13f719afd59357f7d12f02b77'} if LOAD_ORIGINAL_DATA: @@ -65,7 +67,8 @@ rule all: 'Fig5_ground_state.eps', 'Fig6_comparison_exp_spiking_data.eps', 'Fig7_temporal_hierarchy.eps', - 'Fig8_interactions.eps' + 'Fig8_interactions.eps', + 'Fig9_laminar_interactions.eps' include: './Snakefile_preprocessing' @@ -178,3 +181,11 @@ rule Fig8_interactions: shell: 'python3 Fig8_interactions.py' +rule Fig9_laminar_interactions: + input: + expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'granger_causality', 'granger_causality_{area}_{pop}.json'), + simulation=SIMULATIONS['Fig9'], area=area_list, pop=population_list) + output: + 'Fig9_laminar_interactions.eps' + shell: + 'python3 Fig9_laminar_interactions.py' diff --git a/figures/Schmidt2018_dyn/Snakefile_preprocessing b/figures/Schmidt2018_dyn/Snakefile_preprocessing index e2f17eb33e1f4883981d6e0a57b6163dd87a8b2d..8caabbfe20906d1499359fd6dd2c892b3094e610 100644 --- a/figures/Schmidt2018_dyn/Snakefile_preprocessing +++ b/figures/Schmidt2018_dyn/Snakefile_preprocessing @@ -34,14 +34,14 @@ rule power_spectrum: shell: 'python3 compute_power_spectrum.py {} {{wildcards.simulation}} {{wildcards.area}} {{wildcards.method}}'.format(DATA_DIR) -rule stationary_rate: +rule pop_rates: input: expand(os.path.join(DATA_DIR, '{{simulation}}', 'recordings', 'spikes_{area}_{pop}.npy'), area=area_list, pop=population_list) output: os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'pop_rates.json') shell: - 'python3 compute_stationary_rate.py {{wildcards.simulation}}'.format(DATA_DIR) + 'python3 compute_pop_rates.py {{wildcards.simulation}}'.format(DATA_DIR) rule process_chu2014_data: input: @@ -125,3 +125,13 @@ rule bold_signal: shell: 'python3 compute_bold_signal.py {} {{wildcards.simulation}} {{wildcards.area}}'.format(DATA_DIR) +rule granger_causality: + input: + os.path.join(DATA_DIR, '{simulation}', 'Analysis', + 'rate_time_series_full', 'rate_time_series_full_{area}_{pop}.npy') + output: + os.path.join(DATA_DIR, '{simulation}', 'Analysis', + 'granger_causality', 'granger_causality_{area}_{pop}.json') + shell: + 'python3 compute_granger_causality.py {} {{wildcards.simulation}} {{wildcards.area}} {{wildcards.pop}}'.format(DATA_DIR) + diff --git a/figures/Schmidt2018_dyn/compute_granger_causality.py b/figures/Schmidt2018_dyn/compute_granger_causality.py new file mode 100644 index 0000000000000000000000000000000000000000..d84734e183c616dbf9419367ab50bc8638f6fc10 --- /dev/null +++ b/figures/Schmidt2018_dyn/compute_granger_causality.py @@ -0,0 +1,148 @@ +import correlation_toolbox.helper as ch +import json +import numpy as np +import os +import sys + +from multiarea_model.multiarea_model import MultiAreaModel +from multiarea_model.multiarea_helpers import create_mask +from scipy.stats import levene +from statsmodels.tsa.vector_ar.var_model import VAR + +data_path = sys.argv[1] +label = sys.argv[2] +area = sys.argv[3] +pop = sys.argv[4] +target_pair = (area, pop) + +load_path = os.path.join(data_path, + label, + 'Analysis', + 'rate_time_series_full') +save_path = os.path.join(data_path, + label, + 'Analysis', + 'granger_causality') +try: + os.mkdir(save_path) +except FileExistsError: + pass + +with open(os.path.join(data_path, label, 'custom_params_{}'.format(label)), 'r') as f: + sim_params = json.load(f) +T = sim_params['T'] + +""" +Create MultiAreaModel instance to have access to data structures +""" +connection_params = {'g': -11., + 'cc_weights_factor': sim_params['cc_weights_factor'], + 'cc_weights_I_factor': sim_params['cc_weights_I_factor'], + 'K_stable': '../SchueckerSchmidt2017/K_prime_original.npy'} +network_params = {'connection_params': connection_params} +M = MultiAreaModel(network_params) +# We exclude external input from the analysis +K = M.K_matrix[:, :-1] + + +def indices_to_population(structure, indices): + complete = [] + for area in M.area_list: + for pop in structure[area]: + complete.append(area + '-' + pop) + + complete = np.array(complete) + return complete[indices] + + +rate_time_series = {} +for source_area in M.area_list: + rate_time_series[source_area] = {} + for source_pop in M.structure[source_area]: + fn = os.path.join(load_path, + 'rate_time_series_full_{}_{}.npy'.format(source_area, source_pop)) + dat = np.load(fn) + rate_time_series[source_area][source_pop] = dat +fn = os.path.join(load_path, + 'rate_time_series_full_Parameters.json') +with open(fn, 'r') as f: + rate_time_series['Parameters'] = json.load(f) + +tmin, tmax = (500., T) +imax = int(tmax - rate_time_series['Parameters']['t_min']) +imin = int(tmin - rate_time_series['Parameters']['t_min']) + + +# Order of vector auto-regressive model + +# As potentially Granger-causal populations, we only consider source +# population with an indegree > 1 +mask = create_mask(M.structure, target_pops=[pop], target_areas=[area], external=False)[:, :-1] +pairs = indices_to_population(M.structure, np.where(K[mask] > 1.)) + +# Build a list of the time series of all source pairs onto the target pair +all_rates = [ch.centralize(rate_time_series[area][pop][imin:imax], units=True)] +target_index = 0 +source_pairs = [target_pair] +for pair in pairs: + source_area = pair.split('-')[0] + source_pop = pair.split('-')[1] + if (source_area, source_pop) != target_pair: + all_rates.append(ch.centralize(rate_time_series[source_area][source_pop][imin:imax], + units=True)) + source_pairs.append((source_area, source_pop)) + +# Fit VAR with all rates +dat = np.vstack(all_rates) +dat = dat.transpose() +model = VAR(dat) +# Order of auto-regressive regression model +selected_order = 25 +res = model.fit(selected_order) +Sigma_matrix = np.cov(res.resid.transpose()) +# Residual variance of the target population in the VAR incl. all time +# series +variance = Sigma_matrix[target_index][target_index] + +dim = res.resid[:, 0].size +k = dat.shape[1] * selected_order + +# Now we loop through all source pairs, compute the reduced VAR +# (neglecting the time series of that source pair) and then compute +# the conditional Granger causality based on this result +# causality, significance, res = [], [], [] +gc = {area: {} for area in M.area_list} +for source_index, source_pair in enumerate(source_pairs): + if source_pair != target_pair: + print(source_pair) + source_area = source_pair[0] + source_pop = source_pair[1] + # Fit marginal VAR + dat_reduced = np.vstack(all_rates[:source_index] + all_rates[source_index+1:]) + source_pairs_reduced = source_pairs[:source_index] + source_pairs[source_index+1:] + dat_reduced = dat_reduced.transpose() + model_reduced = VAR(dat_reduced) + res_reduced = model_reduced.fit(selected_order) + + Sigma_matrix_reduced = np.cov(res_reduced.resid.transpose()) + target_index_reduced = source_pairs_reduced.index(target_pair) + # Compute the conditional Granger causality as the log-ratio of the residual variances + variance_reduced = Sigma_matrix_reduced[target_index_reduced][target_index_reduced] + cause = np.log(variance_reduced / variance) + + k_reduced = dat_reduced.shape[1] * selected_order + + # Test if the residual variances are significantly different + test = levene(np.sqrt((dim - 1.)/(dim - k)) * res.resid[:, target_index], + np.sqrt((dim - 1.)/(dim - k_reduced)) * res_reduced.resid[:, target_index_reduced]) + + # causality.append(cause) + # significance.append(test[1]) + # res.append(res_reduced) + + gc[source_area][source_pop] = (cause, test[1]) + +fn = os.path.join(save_path, + 'granger_causality_{}_{}.json'.format(area, pop)) +with open(fn, 'w') as f: + json.dump(gc, f) diff --git a/figures/Schmidt2018_dyn/compute_pop_rates.py b/figures/Schmidt2018_dyn/compute_pop_rates.py new file mode 100644 index 0000000000000000000000000000000000000000..784f9a00e4345eab71b63ca2c998f0d4ede7457b --- /dev/null +++ b/figures/Schmidt2018_dyn/compute_pop_rates.py @@ -0,0 +1,50 @@ +import json +import numpy as np +import os + +from multiarea_model.analysis_helpers import pop_rate +from multiarea_model import MultiAreaModel +import sys + +data_path = sys.argv[1] +label = sys.argv[2] + +load_path = os.path.join(data_path, + label, + 'recordings') +save_path = os.path.join(data_path, + label, + 'Analysis') + +with open(os.path.join(data_path, label, 'custom_params_{}'.format(label)), 'r') as f: + sim_params = json.load(f) +T = sim_params['T'] + +""" +Create MultiAreaModel instance to have access to data structures +""" +M = MultiAreaModel({}) + +spike_data = {} +pop_rates = {} +for area in ['V1', 'V2', 'FEF']: + pop_rates[area] = {} + rate_list = [] + N = [] + for pop in M.structure[area]: + fp = '-'.join((label, + 'spikes', # assumes that the default label for spike files was used + area, + pop)) + fn = '{}/{}.npy'.format(load_path, fp) + dat = np.load(fn) + print(area, pop) + pop_rates[area][pop] = pop_rate(dat, 500., T, M.N[area][pop]) + rate_list.append(pop_rates[area][pop]) + N.append(M.N[area][pop]) + pop_rates[area]['total'] = np.average(rate_list, weights=N) + +fn = os.path.join(save_path, + 'pop_rates.json') +with open(fn, 'w') as f: + json.dump(pop_rates, f) diff --git a/figures/Schmidt2018_dyn/compute_rate_time_series.py b/figures/Schmidt2018_dyn/compute_rate_time_series.py index ec3a776fb0041cf6382211b30088490cf6e9ca1e..4a2ebfcc1d82a335955d4b72aa5a175ade5ab54e 100644 --- a/figures/Schmidt2018_dyn/compute_rate_time_series.py +++ b/figures/Schmidt2018_dyn/compute_rate_time_series.py @@ -31,6 +31,10 @@ save_path = os.path.join(data_path, label, 'Analysis', 'rate_time_series_{}'.format(method)) +try: + os.mkdir(save_path) +except FileExistsError: + pass with open(os.path.join(data_path, label, 'custom_params_{}'.format(label)), 'r') as f: sim_params = json.load(f) @@ -90,10 +94,6 @@ for pop in M.structure[area]: method, area, pop)) - try: - os.mkdir(save_path) - except FileExistsError: - pass np.save('{}/{}.npy'.format(save_path, fp), time_series) time_series_list = np.array(time_series_list)