From 67680d9f143dee8bba5f09958bb65c8609678850 Mon Sep 17 00:00:00 2001
From: Jannis <jannis.schuecker@iais.fraunhofer.de>
Date: Thu, 7 Jun 2018 15:00:50 +0200
Subject: [PATCH] work on 1D example of Figure2

---
 .../SchueckerSchmidt2017/Fig2_EE_example.py   | 103 +++++++++++-------
 1 file changed, 65 insertions(+), 38 deletions(-)

diff --git a/figures/SchueckerSchmidt2017/Fig2_EE_example.py b/figures/SchueckerSchmidt2017/Fig2_EE_example.py
index cc4be27..4f8b6ab 100644
--- a/figures/SchueckerSchmidt2017/Fig2_EE_example.py
+++ b/figures/SchueckerSchmidt2017/Fig2_EE_example.py
@@ -97,7 +97,8 @@ gs1.update(left=0.65, right=0.95, top=0.95, bottom=0.8, hspace=0.4, wspace=0.4)
 axes['B'] = pl.subplot(gs1[0, :])
 
 gs1 = gridspec.GridSpec(2, 1)
-gs1.update(left=0.65, right=0.95, top=0.7, bottom=0.075, hspace=0.4, wspace=0.4)
+gs1.update(left=0.65, right=0.95, top=0.7,
+           bottom=0.075, hspace=0.4, wspace=0.4)
 axes['D'] = pl.subplot(gs1[0, :])
 axes['G'] = pl.subplot(gs1[1, :])
 
@@ -109,8 +110,8 @@ for ax in axes.values():
 
 rate_exts_array = np.arange(150., 170.1, 1.)
 
-network_params = {'K': 210.,
-                  'W': 19.6}
+network_params = {'K': 420.,
+                  'W': 10.}
 
 for label in ['A', 'B', 'C', 'D', 'E', 'F', 'G']:
     pl.text(-0.17, 1.05, r'\bfseries{}' + label,
@@ -143,7 +144,7 @@ ax_inset.tick_params(axis='x', labelsize=4, pad=1)
 ax_inset.tick_params(axis='y', labelsize=4, pad=1)
 
 
-x = np.arange(0, 70., 1.)
+x = np.arange(0, 150., 1.)
 
 ax = axes['C']
 ax.xaxis.set_ticks_position('none')
@@ -158,13 +159,15 @@ y = np.fromiter([net.Phi(x[j])[0] for j in range(len(x))], dtype=np.float)
 ax.plot(x, y, '0.3')
 
 x_long = np.arange(0, 100000., 500.)
-y_long = np.fromiter([net.Phi(x_long[j])[0] for j in range(len(x_long))], dtype=np.float)
+y_long = np.fromiter([net.Phi(x_long[j])[0]
+                      for j in range(len(x_long))], dtype=np.float)
 ax_inset.plot(x_long, y_long, '0.3')
 
 # Normal network with rate_ext = 160. without refractory period
 input_params = {'rate_ext': 160.}
 network_params2 = copy.deepcopy(network_params)
-network_params2.update({'neuron_params': {'single_neuron_dict': {'t_ref': 0.}}})
+network_params2.update(
+    {'neuron_params': {'single_neuron_dict': {'t_ref': 0.}}})
 net = network1D(network_params2)
 y = np.fromiter([net.Phi(x[j])[0] for j in range(len(x))], dtype=np.float)
 ax.plot(x, y, color=myred)
@@ -174,15 +177,9 @@ input_params = {'rate_ext': 160.}
 network_params.update({'input_params': input_params})
 net = network1D(network_params)
 NP = net.params['neuron_params']['single_neuron_dict']
-y = []
-
-for j in range(len(x)):
-    mu, sigma = net.theory.mu_sigma(x[j])
-    if mu[0] > NP['V_th']:
-        T = NP['tau_m'] * np.log(mu[0]/(mu[0]-NP['V_th']))
-        y.append(1/T)
-    else:
-        y.append(0.)
+y = np.fromiter([net.Phi_noisefree(x[j])
+                 for j in range(len(x))], dtype=np.float)
+
 ax.plot(x, y, color=myblue)
 
 ax.plot(x, x, '--', color='k')
@@ -199,8 +196,8 @@ markers = ['d', '+', '.']
 for i, rate_ext in enumerate([150., 160., 170.]):
     input_params = {'rate_ext': rate_ext}
     network_params.update({'input_params': input_params})
-    net = network1D(network_params)
-    y = np.fromiter([net.Phi(x[j])[0] for j in range(len(x))], dtype=np.float)
+    net2 = network1D(network_params)
+    y = np.fromiter([net2.Phi(x[j])[0] for j in range(len(x))], dtype=np.float)
     ax.plot(x, y, colors[i])
     # Plot fixed points
     ind = np.where(np.abs(y - x) < 0.2)
@@ -264,44 +261,74 @@ ax.set_ylabel(r'Rate $\nu\quad(1/\mathrm{s})$')
 Panel F: Flux in the bistable case
 """
 ax = axes['F']
+a = pl.axes([0.27, .1, 0.12, .06])
+a.set_xticks([0, 0.03])
+a.set_xlim([0., 0.03])
+a.set_yticks([-0.02, 0.02])
+a.set_ylim([-0.02, 0.02])
+a.tick_params(axis='x', labelsize=4, pad=2)
+a.tick_params(axis='y', labelsize=4, pad=1)
 
-network_params_base = {'K': 210.,
-                       'W': 19.6,
+
+network_params_base = {'K': 420.,
+                       'W': 10.,
                        'input_params': {'rate_ext': 160.}}
-network_params_inc = {'K': 210.,
-                      'W': 19.6,
+network_params_inc = {'K': 420.,
+                      'W': 10.,
                       'input_params': {'rate_ext': 161.}}
-network_params_stab = {'K': 210.,
-                       'W': 19.6,
+network_params_stab = {'K': 420.,
+                       'W': 10.,
                        'input_params': {'rate_ext': 161.}}
 
 # Normal network with rate_ext = 160.
 net = network1D(network_params_base)
 y = np.fromiter([net.Phi(x[j])[0] for j in range(len(x))], dtype=np.float)
-ax.plot(y, y - x, color='k')
-
+ax.plot(x, y - x, color='k')
+a.plot(x, y - x, color='k')
 ax.hlines(0., 0., 70., linestyles='dashed')
-
+a.hlines(0., 0., 70., linestyles='dashed')
+fp_base = net.fsolve(([18.]))['rates'][0][0]
 
 # Normal network with rate_ext = 160.
 net = network1D(network_params_inc)
 y = np.fromiter([net.Phi(x[j])[0] for j in range(len(x))], dtype=np.float)
-ax.plot(y, y - x, color=myblue)
-
-fp = net.fsolve(([18.]))['rates'][0][0]
+ax.plot(x, y - x, color=myblue)
+a.plot(x, y - x, color=myblue, lw=4.)
+fp_inc = net.fsolve(([18.]))['rates'][0][0]
 
 # Normal network with rate_ext = 160.
-deltaK = -1. * network_params['K'] * (161. - 160.) / fp
-network_params_stab.update({'K': network_params['K'] + deltaK})
+deltaK = -1. * network_params['K'] * (161. - 160.) / fp_base
+print(network_params['K'] + deltaK)
+network_params_stab.update({'K_stable': network_params['K'] + deltaK})
+print(network_params_stab)
 net = network1D(network_params_stab)
 y = np.fromiter([net.Phi(x[j])[0] for j in range(len(x))], dtype=np.float)
-ax.plot(y, y - x, color=myred)
+ax.plot(x, y - x, color=myred)
+a.plot(x, y - x, color=myred)
+fp_s = net.fsolve(([18.]))['rates'][0][0]
 
 ax.hlines(0., 0., 70., linestyles='dashed')
 
 
 ax.set_xlabel(r'Rate $\nu\quad(1/\mathrm{s})$')
 ax.set_ylabel(r'Flux $\dot\nu\quad(1/\mathrm{s})$', labelpad=0)
+ax.set_xlim([0, 50])
+ax.set_ylim([-10, 7])
+
+ylim = 7.
+y0 = ylim
+x0 = 0.
+height = 0.5
+rect_base = pl.Rectangle((x0, y0 - 1.1 * height), width=fp_base,
+                         height=height, fill=True, color='black')
+rect_2 = pl.Rectangle((x0, y0 - 2.7 * height),
+                      width=fp_inc, height=height, fill=True, color=myblue)
+rect_s = pl.Rectangle((x0, y0 - 4.3 * height),
+                      width=fp_s, height=height, fill=True, color=myred)
+ax.add_patch(rect_base)
+ax.add_patch(rect_2)
+ax.add_patch(rect_s)
+
 
 """
 2D network
@@ -321,14 +348,14 @@ ax = axes['D']
 
 rates_init = np.arange(0., 200., 10.)
 
-network_params_base = {'K': 210.,
-                       'W': 19.6,
+network_params_base = {'K': 420.,
+                       'W': 10.,
                        'input_params': {'rate_ext': 160.}}
-network_params_inc = {'K': 210.,
-                      'W': 19.6,
+network_params_inc = {'K': 420.,
+                      'W': 10.,
                       'input_params': {'rate_ext': 161.}}
-network_params_stab = {'K': 210.,
-                       'W': 19.6,
+network_params_stab = {'K': 420.,
+                       'W': 10.,
                        'input_params': {'rate_ext': 161.}}
 
 colors = ['k', myblue, myred]
-- 
GitLab