From 7074a6abd50332714164ec4f52bae5f4162b40e9 Mon Sep 17 00:00:00 2001
From: Maximilian Schmidt <max.schmidt@fz-juelich.de>
Date: Sat, 6 Apr 2019 11:46:43 +0900
Subject: [PATCH] Fix message about stabilization

Remove misleading message from MultiAreaModel Class
Change NotImplementedError to TypeError
Adapt test
Improve documentation in README
---
 README.md                          |  2 +-
 multiarea_model/multiarea_model.py | 17 +++++++----------
 tests/test_stabilization.py        | 29 ++++++++++++++++++++++-------
 3 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 2444167..e24f1e7 100644
--- a/README.md
+++ b/README.md
@@ -76,7 +76,7 @@ later release of NEST, version 2.14.0 .
 This class can be initialized by `MultiAreaModel` or as standalone and
 takes simulation parameters as input. It provides two main features:
 - predict the stable fixed point of the system using mean-field theory and characterize them (for instance by computing the gain matrix).
-- via the script `stabilize.py`, one can execute the stabilization method described in [2] on a network instance.
+- via the script `stabilize.py`, one can execute the stabilization method described in [2] on a network instance. Please see `figures/SchueckerSchmidt2017/stabilization.py` for an example of running the stabilization.
 
 `Analysis`
 
diff --git a/multiarea_model/multiarea_model.py b/multiarea_model/multiarea_model.py
index 055395b..235246e 100644
--- a/multiarea_model/multiarea_model.py
+++ b/multiarea_model/multiarea_model.py
@@ -132,18 +132,15 @@ class MultiAreaModel:
         ind, inda, out, outa = load_degree_data(tmp_data_fn)
         # If K_stable is specified in the params, load the stabilized matrix
         # TODO: Extend this by calling the stabilization method
-        if not self.params['connection_params']['K_stable']:
+        if self.params['connection_params']['K_stable'] is None:
             self.K = ind
         else:
-            if self.params['connection_params']['K_stable'] is True:
-                raise NotImplementedError('Stabilization procedure has '
-                                          'to be integrated.')
-            elif isinstance(self.params['connection_params']['K_stable'], np.ndarray):
-                raise NotImplementedError("Not supported. Please store the "
-                                          "matrix in a file and define the path to the file as "
-                                          "the parameter value.")
-            else:  # Assume that the parameter defines a filename containing the matrix
-                K_stable = np.load(self.params['connection_params']['K_stable'])
+            if not isinstance(self.params['connection_params']['K_stable'], str):
+                raise TypeError("Not supported. Please store the "
+                                "matrix in a binary numpy file and define "
+                                "the path to the file as the parameter value.")
+            # Assume that the parameter defines a filename containing the matrix
+            K_stable = np.load(self.params['connection_params']['K_stable'])
             ext = {area: {pop: ind[area][pop]['external'] for pop in
                           self.structure['V1']} for area in self.area_list}
             self.K = matrix_to_dict(
diff --git a/tests/test_stabilization.py b/tests/test_stabilization.py
index ddedbb9..5e67f83 100644
--- a/tests/test_stabilization.py
+++ b/tests/test_stabilization.py
@@ -1,15 +1,30 @@
-from multiarea_model import MultiAreaModel
+import numpy as np
 import pytest
 
+from multiarea_model import MultiAreaModel
+
 
-def test_meanfield():
+def test_stabilization():
     """
-    Test stabilization procedure. Since this algorithm is not
-    implemented yet, we here test if this properly raises a
-    NotImplementedError.
+    Test stabilization procedure. The stabilized matrix is expected to
+    be stored in a file and the parameter in the dictionary specifies
+    the corresponding name. We here check if the MultiAreaModel class
+    properly throws a TypeError when we try to directly specify the
+    matrix.
     """
 
-    network_params = {'connection_params': {'K_stable': True}}
+    # Create random matrix for indegrees
+    K_stable = np.random.rand(254, 254)
+    np.save('K_stable_test.npy', K_stable)
+
+    # Trying to directly specify the matrix should throw a TypeError.
+    network_params = {'connection_params': {'K_stable': K_stable}}
     theory_params = {}
-    with pytest.raises(NotImplementedError):
+    with pytest.raises(TypeError):
         MultiAreaModel(network_params, theory=True, theory_spec=theory_params)
+
+    # Specifying the file name leads to the correct indegrees being loaded.
+    network_params = {'connection_params': {'K_stable': 'K_stable_test.npy'}}
+    theory_params = {}
+    M = MultiAreaModel(network_params, theory=True, theory_spec=theory_params)
+    assert(np.all(K_stable == M.K_matrix[:, :-1]))
-- 
GitLab