#!/usr/bin/env python3

import arbor
import pandas, seaborn
from math import sqrt

# Run with srun -n NJOBS python network_ring_mpi.py

# Construct a cell with the following morphology.
# The soma (at the root of the tree) is marked 's', and
# the end of each branch i is marked 'bi'.
#
#         b1
#        /
# s----b0
#        \
#         b2

def make_cable_cell(gid):
    """Build the cable cell used for every cell in the ring.

    The morphology matches the sketch above: a cylindrical soma, a single
    dendrite b0 attached to it, and two dendrites b1 and b2 that fork off the
    distal end of b0. The soma gets 'hh' dynamics and the dendrites 'pas';
    one 'expsyn' synapse is placed on branch 1 and a spike detector with a
    -10 mV threshold is placed at the root.

    Args:
        gid: Global id of the cell. It is not used (all cells are identical);
            it is accepted so the recipe's ``cell_description(gid)`` can pass
            it straight through.

    Returns:
        The assembled ``arbor.cable_cell``.
    """
    # (1) Build a segment tree.
    tree = arbor.segment_tree()

    # Soma (tag=1) with radius 6 μm, modelled as a cylinder of length 2*radius.
    soma = tree.append(arbor.mnpos, arbor.mpoint(-12, 0, 0, 6), arbor.mpoint(0, 0, 0, 6), tag=1)

    # Single dendrite b0 (tag=3) of length 50 μm and radius 2 μm attached to the soma.
    b0 = tree.append(soma, arbor.mpoint(0, 0, 0, 2), arbor.mpoint(50, 0, 0, 2), tag=3)

    # Two dendrites (tag=3) of length 50 μm attached to the distal end of b0,
    # leaving at +45° and -45°. Their segment ids are not needed afterwards.
    # b1: radius tapers from 2 to 0.5 μm over its length.
    tree.append(b0, arbor.mpoint(50, 0, 0, 2), arbor.mpoint(50+50/sqrt(2), 50/sqrt(2), 0, 0.5), tag=3)
    # b2: constant radius of 1 μm over its length.
    tree.append(b0, arbor.mpoint(50, 0, 0, 1), arbor.mpoint(50+50/sqrt(2), -50/sqrt(2), 0, 1), tag=3)

    # Associate region labels with the segment tags.
    labels = arbor.label_dict()
    labels['soma'] = '(tag 1)'
    labels['dend'] = '(tag 3)'

    # (2) Mark the synapse site at the midpoint (relative position 0.5) of
    # branch 1. NOTE: branches are delimited by forks, so the soma and b0
    # (joined without a fork) should together form branch 0, making branch 1
    # the dendrite ending at b1 — re-check this if the morphology changes.
    labels['synapse_site'] = '(location 1 0.5)'
    # Mark the root of the tree.
    labels['root'] = '(root)'

    # (3) Create a decor describing the cell's dynamics and attachments.
    decor = arbor.decor()

    # Put 'hh' dynamics on the soma and passive ('pas') properties on the dendrites.
    decor.paint('"soma"', 'hh')
    decor.paint('"dend"', 'pas')

    # (4) Attach a single synapse.
    decor.place('"synapse_site"', 'expsyn')

    # Attach a spike detector with a threshold of -10 mV at the root.
    decor.place('"root"', arbor.spike_detector(-10))

    return arbor.cable_cell(tree, labels, decor)

# (5) Create a recipe that generates a network of connected cells.
class ring_recipe (arbor.recipe):
    """Recipe for a unidirectional ring of ``ncells`` identical cable cells.

    Cell ``gid`` receives one connection from cell ``gid - 1`` (cell 0 is fed
    by the last cell), and cell 0 additionally receives a single external
    event at t = 1 ms that starts activity travelling around the ring.
    """

    def __init__(self, ncells):
        # The C++ base constructor has to run first so the underlying C++
        # object is fully initialized. It is called directly rather than via
        # super(), as pybind11 recommends for C++-backed base classes.
        arbor.recipe.__init__(self)
        self.ncells = ncells
        # Cable-cell properties with the default mechanism catalogue registered.
        self.props = arbor.neuron_cable_properties()
        self.cat = arbor.default_catalogue()
        self.props.register(self.cat)

    # (6) Total number of cells in the model (required by arbor.recipe).
    def num_cells(self):
        return self.ncells

    # (7) Every cell is built by the same factory function.
    def cell_description(self, gid):
        return make_cable_cell(gid)

    # Must agree with the kind of object returned by cell_description.
    def cell_kind(self, gid):
        return arbor.cell_kind.cable

    # (8) Ring topology: each cell has exactly one incoming connection.
    def connections_on(self, gid):
        # Predecessor in the ring; Python's modulo maps gid 0 to ncells - 1.
        predecessor = (gid - 1) % self.ncells
        weight = 0.01
        delay = 5
        source = (predecessor, 0)  # source 0 (the spike detector) of the predecessor
        target = (gid, 0)          # target 0 (the synapse) of this cell
        return [arbor.connection(source, target, weight, delay)]

    # One synapse per cell.
    def num_targets(self, gid):
        return 1

    # One spike detector per cell.
    def num_sources(self, gid):
        return 1

    # (9) Kick-start the ring with a single event delivered to cell 0.
    def event_generators(self, gid):
        if gid != 0:
            return []
        schedule = arbor.explicit_schedule([1])
        return [arbor.event_generator((0, 0), 0.1, schedule)]

    # (10) One membrane-voltage probe at the root of every cell.
    def probes(self, gid):
        return [arbor.cable_probe_membrane_voltage('"root"')]

    # The same cable-cell properties are returned for any cell kind queried.
    def global_properties(self, kind):
        return self.props

# (11) Instantiate the recipe describing a ring of ncells cells.
ncells = 50
recipe = ring_recipe(ncells)

# (12) Create an MPI communicator, and use it to create a hardware context.
# MPI is initialized before the communicator is requested.
arbor.mpi_init()
comm = arbor.mpi_comm()
# Printed on every rank (there is no rank check), useful to inspect the MPI setup.
print(comm)
context = arbor.context(mpi=comm)
print(context)

# (13) Create a default domain decomposition (which cells each rank simulates)
# and build the simulation from recipe, decomposition and context.
decomp = arbor.partition_load_balance(recipe, context)
sim = arbor.simulation(recipe, decomp, context)

# (14) Record spikes, using the `all` spike-recording policy.
sim.record(arbor.spike_recording.all)

# (15) Attach a sampler to the voltage probe (probe index 0) of EVERY cell,
# taking one sample every ms. The sampling period is coarser than in
# network_ring.py to reduce the amount of data.
handles = [sim.sample((gid, 0), arbor.regular_schedule(1)) for gid in range(ncells)]

# (16) Run the simulation for ncells * 5 ms, i.e. scaled with the ring size at
# 5 ms per cell (the connection delay). Since every hop takes at least that
# delay and the first event arrives at t = 1 ms, activity cannot complete a
# full trip back to cell 0 within this window.
sim.run(ncells*5)
print('Simulation finished')

# (17) Store the recorded voltages over time in one CSV file per MPI rank.
# A rank that collected no samples writes no file.
print("Storing results ...")
df_list = []
for gid, handle in enumerate(handles):
    # Query the sampler once; the result may be empty (presumably for cells
    # not simulated on this rank — confirm against the decomposition).
    recorded = sim.samples(handle)
    if recorded:
        # Only the first (data, metadata) pair is used; data columns are (t, U).
        samples, meta = recorded[0]
        df_list.append(pandas.DataFrame({'t/ms': samples[:, 0], 'U/mV': samples[:, 1], 'Cell': f"cell {gid}"}))

if df_list:
    df = pandas.concat(df_list)
    df.to_csv(f"result_mpi_{context.rank}.csv", float_format='%g')