diff --git a/multiarea_model/multiarea_model.py b/multiarea_model/multiarea_model.py index 0cd14df17838caf448196befdca6774f616b9f15..e6766d8060724b58ef63e1e16ecffec4917c14ef 100644 --- a/multiarea_model/multiarea_model.py +++ b/multiarea_model/multiarea_model.py @@ -206,6 +206,15 @@ class Model: ana_spec = keywords['ana_spec'] self.init_analysis(ana_spec) + def create(self): + self.simulation.create() + + def connect(self): + self.simulation.connect() + + def simulate(self): + self.simulation.simulate() + def __str__(self): s = "Multi-area network {} with custom parameters: \n".format(self.label) s += pprint.pformat(self.params, width=1) diff --git a/multiarea_model/simulation.py b/multiarea_model/simulation.py index 1615c86bcd3aad17bf4bd20fb337c487a0dd9fab..66f719c9b374d5d55123a5eb302eac5c4ee9c653 100644 --- a/multiarea_model/simulation.py +++ b/multiarea_model/simulation.py @@ -93,6 +93,8 @@ class Simulation: self.areas_recorded = self.params['recording_dict']['areas_recorded'] self.T = self.params['t_sim'] + self.prepare() + def __eq__(self, other): # Two simulations are equal if the simulation parameters and # the simulated networks are equal. @@ -185,6 +187,13 @@ class Simulation: status_dict.update({'label': label}) nest.SetStatus(self.voltmeter, status_dict) + def connect_areas(self): + """ + Create all areas with their populations and internal connections. + """ + for area in self.areas: + area.connect_area() + def create_areas(self): """ Create all areas with their populations and internal connections. @@ -272,40 +281,26 @@ class Simulation: source_area.name, cc_input[source_area.name]) - def simulate(self): + def create(self): """ - Create the network and execute simulation. - Record used memory and wallclock time. + Create the network. """ - t0 = time.time() - self.base_memory = self.memory() - self.prepare() - t1 = time.time() - self.time_prepare = t1 - t0 - print("Prepared simulation in {0:.2f} seconds.".format(self.time_prepare)) - self.create_recording_devices() self.create_areas() - t2 = time.time() - self.time_network_local = t2 - t1 - print("Created areas and internal connections in {0:.2f} seconds.".format( - self.time_network_local)) + def connect(self): + """ + Connect the network. + """ + self.connect_areas() self.cortico_cortical_input() - t3 = time.time() - self.network_memory = self.memory() - self.time_network_global = t3 - t2 - print("Created cortico-cortical connections in {0:.2f} seconds.".format( - self.time_network_global)) - self.save_network_gids() + def simulate(self): + """ + Simulate the network. + """ nest.Simulate(self.T) - t4 = time.time() - self.time_simulate = t4 - t3 - self.total_memory = self.memory() - print("Simulated network in {0:.2f} seconds.".format(self.time_simulate)) - self.logging() def memory(self): """ @@ -403,8 +398,6 @@ class Area: self.external_synapses[pop] = self.network.K[self.name][pop]['external']['external'] self.create_populations() - self.connect_devices() - self.connect_populations() print("Rank {}: created area {} with {} local nodes".format(nest.Rank(), self.name, self.num_local_nodes)) @@ -469,6 +462,13 @@ class Area: ) self.num_local_nodes += len(local_nodes_pop) + def connect_area(self): + """ + Create connections from devices and between populations. + """ + self.connect_devices() + self.connect_populations() + def connect_populations(self): """ Create connections between populations. diff --git a/run_example_downscaled.py b/run_example_downscaled.py index 6ccfe3c2767ca342d72ec7759bf36896cfe361c3..c875439494d54634e88a82710595143f224cfacc 100644 --- a/run_example_downscaled.py +++ b/run_example_downscaled.py @@ -1,7 +1,7 @@ import numpy as np import os -from multiarea_model import Model +import multiarea_model as model from config import base_path """ @@ -37,11 +37,10 @@ sim_params = {'t_sim': 2000., theory_params = {'dt': 0.1} -M = Model(network_params, simulation=True, +M = model.Model(network_params, simulation=True, sim_spec=sim_params, theory=True, theory_spec=theory_params) -p, r = M.theory.integrate_siegert() -print("Mean-field theory predicts an average " - "rate of {0:.3f} spikes/s across all populations.".format(np.mean(r[:, -1]))) -M.simulation.simulate() +M.create() +M.connect() +M.simulate()