diff --git a/include/grenade/vx/grenade.h b/include/grenade/vx/grenade.h index d1ddc3f43667b78f2112b709986b8655868417ee..6c0591bd495e83672e847e59686fa1b9494b6183 100644 --- a/include/grenade/vx/grenade.h +++ b/include/grenade/vx/grenade.h @@ -11,6 +11,7 @@ #include "grenade/vx/network/network_graph.h" #include "grenade/vx/network/network_graph_builder.h" #include "grenade/vx/network/network_graph_statistics.h" +#include "grenade/vx/network/plasticity_rule_generator.h" #include "grenade/vx/network/population.h" #include "grenade/vx/network/projection.h" #include "grenade/vx/network/routing_builder.h" diff --git a/include/grenade/vx/network/plasticity_rule_generator.h b/include/grenade/vx/network/plasticity_rule_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..ad137286e1556dc6871c658104f36c3f578d9369 --- /dev/null +++ b/include/grenade/vx/network/plasticity_rule_generator.h @@ -0,0 +1,48 @@ +#pragma once +#include "grenade/vx/network/plasticity_rule.h" +#include "hate/visibility.h" +#include <optional> +#include <set> +#include <vector> + +#if defined(__GENPYBIND__) || defined(__GENPYBIND_GENERATED__) +#include <pybind11/stl.h> +#endif + + +namespace grenade::vx { +struct IODataMap; +} // namespace grenade::vx + +namespace grenade::vx GENPYBIND_TAG_GRENADE_VX { +namespace network { + +struct PlasticityRuleDescriptor; +struct NetworkGraph; + +struct GENPYBIND(visible) OnlyRecordingPlasticityRuleGenerator +{ + /** + * Observables, which can be recorded. + */ + enum class Observable + { + weights, + correlation_causal, + correlation_acausal + }; + + OnlyRecordingPlasticityRuleGenerator(std::set<Observable> const& observables) SYMBOL_VISIBLE; + + /** + * Generate plasticity rule which only executes given recording. + * Timing and projection information is left default/empty. + */ + PlasticityRule generate() const SYMBOL_VISIBLE; + +private: + std::set<Observable> m_observables; +}; + +} // namespace grenade::vx +} // namespace network diff --git a/include/grenade/vx/types.h b/include/grenade/vx/types.h index ac100acc7fe1cbe3e833f6ad218657adda4e224c..713f6ef5b7fb0289c3e1e5c0b4f908f43ce1310e 100644 --- a/include/grenade/vx/types.h +++ b/include/grenade/vx/types.h @@ -1,8 +1,9 @@ #pragma once +#include "grenade/vx/genpybind.h" #include "halco/common/geometry.h" #include "haldls/vx/v3/padi.h" -namespace grenade::vx { +namespace grenade::vx GENPYBIND_TAG_GRENADE_VX { /** * 5 bit wide unsigned activation value type. @@ -24,7 +25,7 @@ struct UInt32 : public halco::common::detail::BaseType<UInt32, uint32_t> * int8_t, which we don't currently have, since it relies on the interpretation of having the CADC * baseline in the middle of the value range. */ -struct Int8 : public halco::common::detail::BaseType<Int8, int8_t> +struct GENPYBIND(inline_base("*")) Int8 : public halco::common::detail::BaseType<Int8, int8_t> { constexpr explicit Int8(value_type const value = 0) : base_t(value) {} }; diff --git a/src/grenade/vx/network/plasticity_rule_generator.cpp b/src/grenade/vx/network/plasticity_rule_generator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b98e76ead486f318d2c36e0cf8b741fc501ef2e0 --- /dev/null +++ b/src/grenade/vx/network/plasticity_rule_generator.cpp @@ -0,0 +1,214 @@ +#include "grenade/vx/network/plasticity_rule_generator.h" + +#include "grenade/vx/io_data_map.h" +#include "grenade/vx/network/network_graph.h" +#include "halco/common/iter_all.h" +#include "halco/common/typed_array.h" +#include "halco/hicann-dls/vx/v3/ppu.h" +#include <sstream> + +namespace grenade::vx::network { + +OnlyRecordingPlasticityRuleGenerator::OnlyRecordingPlasticityRuleGenerator( + std::set<Observable> const& observables) : + m_observables(observables) +{} + +PlasticityRule OnlyRecordingPlasticityRuleGenerator::generate() const +{ + std::stringstream kernel; + + kernel << "#include \"grenade/vx/ppu/synapse_array_view_handle.h\"\n"; + kernel << "#include \"libnux/vx/location.h\"\n"; + kernel << "#include \"libnux/vx/correlation.h\"\n"; + kernel << "#include \"libnux/vx/vector_convert.h\"\n"; + kernel << "#include \"libnux/vx/time.h\"\n"; + kernel << "#include \"libnux/vx/helper.h\"\n"; + kernel << "#include \"hate/tuple.h\"\n"; + + kernel << "using namespace grenade::vx::ppu;\n"; + kernel << "using namespace libnux::vx;\n"; + + kernel << "extern uint64_t time_origin;\n"; + kernel << "extern volatile PPUOnDLS ppu;\n"; + + kernel << "template <size_t N>\n"; + kernel << "void PLASTICITY_RULE_KERNEL(std::array<SynapseArrayViewHandle, N>& synapses, " + "std::array<PPUOnDLS, N> synrams, Recording& recording)\n"; + kernel << "{\n"; + + kernel << " recording.time = now() - time_origin;\n"; + + if (m_observables.contains(Observable::weights)) { + kernel << " {\n"; + kernel << " hate::for_each([&](auto& " + "weights_per_synapse_view, auto& synapse, auto const& synram) {\n"; + kernel << " if (synram == ppu) {\n"; + kernel << " size_t row = 0;\n"; + kernel << " for (size_t i = 0; i < synapse.rows.size; ++i) {\n"; + kernel << " if (synapse.rows.test(i)) {\n"; + kernel << " auto const tmp = " + << "static_cast<VectorRowFracSat8>(synapse.get_weights(i));\n"; + kernel << " size_t column = 0;\n"; + kernel << " for (size_t j = 0; j < synapse.columns.size; ++j) {\n"; + kernel << " if (synapse.columns.test(j)) {\n"; + kernel << " weights_per_synapse_view[row][column] = tmp[j];\n"; + kernel << " column++;\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " do_not_optimize_away(weights_per_synapse_view[row]);\n"; + kernel << " do_not_optimize_away(weights_per_synapse_view);\n"; + kernel << " row++;\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " }, recording.weights, synapses, synrams);\n"; + kernel << " }\n"; + } + if (m_observables.contains(Observable::correlation_causal) && + !m_observables.contains(Observable::correlation_acausal)) { + kernel << " {\n"; + kernel << " hate::for_each(" + "[&](auto& correlation_causal_per_synapse_view, auto const& synapse, " + "auto const& synram) {\n"; + kernel << " if (synram == ppu) {\n"; + kernel << " size_t row = 0;\n"; + kernel << " for (size_t i = 0; i < synapse.rows.size; ++i) {\n"; + kernel << " if (synapse.rows.test(i)) {\n"; + kernel << " vector_row_t causal;\n"; + kernel << " get_causal_correlation(&(causal.even.data), " + " &(causal.odd.data), i);\n"; + kernel << " do_not_optimize_away(causal);\n"; + kernel << " reset_correlation(i);\n"; + kernel << " VectorRowFracSat8 const tmp = " + "causal.convert_contiguous();\n"; + kernel << " size_t column = 0;\n"; + kernel << " for (size_t j = 0; j < synapse.columns.size; ++j) {\n"; + kernel << " if (synapse.columns.test(j)) {\n"; + kernel << " correlation_causal_per_synapse_view[row][column] " + "= tmp[j];\n"; + kernel << " column++;\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " row++;\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " }, recording.correlation_causal, synapses, synrams);\n"; + kernel << " }\n"; + } + if (!m_observables.contains(Observable::correlation_causal) && + m_observables.contains(Observable::correlation_acausal)) { + kernel << " {\n"; + kernel << " hate::for_each(" + "[&](auto& correlation_acausal_per_synapse_view, auto const& synapse, " + "auto const& synram) {\n"; + kernel << " if (synram == ppu) {\n"; + kernel << " size_t row = 0;\n"; + kernel << " for (size_t i = 0; i < synapse.rows.size; ++i) {\n"; + kernel << " if (synapse.rows.test(i)) {\n"; + kernel << " vector_row_t acausal;\n"; + kernel << " get_acausal_correlation(&(acausal.even.data), " + " &(acausal.odd.data), i);\n"; + kernel << " do_not_optimize_away(acausal);\n"; + kernel << " reset_correlation(i);\n"; + kernel << " VectorRowFracSat8 const tmp = " + "acausal.convert_contiguous();\n"; + kernel << " size_t column = 0;\n"; + kernel << " for (size_t j = 0; j < synapse.columns.size; ++j) {\n"; + kernel << " if (synapse.columns.test(j)) {\n"; + kernel << " correlation_acausal_per_synapse_view[row][column] " + "= tmp[j];\n"; + kernel << " column++;\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " row++;\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " }, recording.correlation_acausal, synapses, synrams);\n"; + kernel << " }\n"; + } + if (m_observables.contains(Observable::correlation_causal) && + m_observables.contains(Observable::correlation_acausal)) { + kernel << " {\n"; + kernel << " hate::for_each(" + "[&](auto& correlation_causal_per_synapse_view, auto& " + "correlation_acausal_per_synapse_view, auto const& synapse, " + "auto const& synram) {\n"; + kernel << " if (synram == ppu) {\n"; + kernel << " size_t row = 0;\n"; + kernel << " for (size_t i = 0; i < synapse.rows.size; ++i) {\n"; + kernel << " if (synapse.rows.test(i)) {\n"; + kernel << " vector_row_t causal;\n"; + kernel << " vector_row_t acausal;\n"; + kernel << " get_correlation(&(causal.even.data), " + " &(causal.odd.data), &(acausal.even.data), " + "&(acausal.odd.data), i);\n"; + kernel << " do_not_optimize_away(causal);\n"; + kernel << " do_not_optimize_away(acausal);\n"; + kernel << " reset_correlation(i);\n"; + kernel << " VectorRowFracSat8 const tmp_causal = " + "causal.convert_contiguous();\n"; + kernel << " VectorRowFracSat8 const tmp_acausal = " + "acausal.convert_contiguous();\n"; + kernel << " size_t column = 0;\n"; + kernel << " for (size_t j = 0; j < synapse.columns.size; ++j) {\n"; + kernel << " if (synapse.columns.test(j)) {\n"; + kernel << " correlation_causal_per_synapse_view[row][column] " + "= tmp_causal[j];\n"; + kernel << " correlation_acausal_per_synapse_view[row][column] " + "= tmp_acausal[j];\n"; + kernel << " column++;\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " row++;\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " }\n"; + kernel << " }, recording.correlation_causal, recording.correlation_acausal, " + "synapses, synrams);\n"; + kernel << " }\n"; + } + + kernel << "}\n"; + + PlasticityRule rule; + PlasticityRule::TimedRecording recording; + for (auto const& observable : m_observables) { + switch (observable) { + case Observable::weights: { + recording.observables["weights"] = + PlasticityRule::TimedRecording::ObservablePerSynapse{ + PlasticityRule::TimedRecording::ObservablePerSynapse::Type::int8, + PlasticityRule::TimedRecording::ObservablePerSynapse::LayoutPerRow:: + packed_active_columns}; + break; + } + case Observable::correlation_causal: { + recording.observables["correlation_causal"] = + PlasticityRule::TimedRecording::ObservablePerSynapse{ + PlasticityRule::TimedRecording::ObservablePerSynapse::Type::int8, + PlasticityRule::TimedRecording::ObservablePerSynapse::LayoutPerRow:: + packed_active_columns}; + break; + } + case Observable::correlation_acausal: { + recording.observables["correlation_acausal"] = + PlasticityRule::TimedRecording::ObservablePerSynapse{ + PlasticityRule::TimedRecording::ObservablePerSynapse::Type::int8, + PlasticityRule::TimedRecording::ObservablePerSynapse::LayoutPerRow:: + packed_active_columns}; + break; + } + default: { + throw std::logic_error("Observable not implemented."); + } + } + } + rule.recording = recording; + rule.kernel = kernel.str(); + return rule; +} + +} // namespace grenade::vx::network diff --git a/src/grenade/vx/ppu_program_generator.cpp b/src/grenade/vx/ppu_program_generator.cpp index e1bc7ae3f0723c5642b010b4ff0bda7f297080d9..b20718c9ce6b7d81d3d458ba5851f58b81ce71d1 100644 --- a/src/grenade/vx/ppu_program_generator.cpp +++ b/src/grenade/vx/ppu_program_generator.cpp @@ -164,9 +164,11 @@ std::vector<std::string> PPUProgramGenerator::done() source << "#include \"libnux/scheduling/Service.hpp\"\n"; source << "#include \"libnux/vx/mailbox.h\"\n"; source << "#include \"libnux/vx/dls.h\"\n"; + source << "#include \"libnux/vx/time.h\"\n"; source << "extern volatile libnux::vx::PPUOnDLS ppu;\n"; source << "volatile uint32_t runtime;\n"; source << "volatile uint32_t scheduler_event_drop_count;\n"; + source << "uint64_t time_origin = 0;\n"; for (auto const& [i, _, __] : m_plasticity_rules) { source << "extern Timer timer_" << i << ";\n"; source << "volatile uint32_t timer_" << i << "_event_drop_count;\n"; @@ -200,6 +202,7 @@ std::vector<std::string> PPUProgramGenerator::done() source << "static_cast<void>(runtime);\n"; if (!m_plasticity_rules.empty()) { source << "auto current = get_time();\n"; + source << "time_origin = libnux::vx::now();\n"; source << "SchedulerSignallerTimer timer(current, current + runtime);\n"; for (auto const& [i, _, __] : m_plasticity_rules) { source << "timer_" << i << ".set_first_deadline(current + timer_" << i diff --git a/tests/hw/grenade/vx/network/test-plasticity_rule_generator.cpp b/tests/hw/grenade/vx/network/test-plasticity_rule_generator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ac8ae47ea7f802631893d236c43cb583704cb8c5 --- /dev/null +++ b/tests/hw/grenade/vx/network/test-plasticity_rule_generator.cpp @@ -0,0 +1,185 @@ +#include "grenade/vx/backend/connection.h" +#include "grenade/vx/execution_instance.h" +#include "grenade/vx/graph.h" +#include "grenade/vx/jit_graph_executor.h" +#include "grenade/vx/network/extract_output.h" +#include "grenade/vx/network/network.h" +#include "grenade/vx/network/network_builder.h" +#include "grenade/vx/network/network_graph.h" +#include "grenade/vx/network/network_graph_builder.h" +#include "grenade/vx/network/plasticity_rule_generator.h" +#include "grenade/vx/network/population.h" +#include "grenade/vx/network/projection.h" +#include "grenade/vx/network/routing_builder.h" +#include "grenade/vx/types.h" +#include "halco/hicann-dls/vx/v3/chip.h" +#include <gtest/gtest.h> +#include <log4cxx/logger.h> + +using namespace halco::common; +using namespace halco::hicann_dls::vx::v3; +using namespace stadls::vx::v3; +using namespace lola::vx::v3; +using namespace haldls::vx::v3; + +TEST(OnlyRecordingPlasticityRuleGenerator, weights) +{ + grenade::vx::coordinate::ExecutionInstance instance; + + grenade::vx::JITGraphExecutor::ChipConfigs chip_configs; + chip_configs[instance] = lola::vx::v3::Chip(); + + // build network + grenade::vx::network::NetworkBuilder network_builder; + + // population at beginning of row + grenade::vx::network::Population::Neurons neurons_1{AtomicNeuronOnDLS()}; + grenade::vx::network::Population::EnableRecordSpikes enable_record_spikes_1{false}; + grenade::vx::network::Population population_1{ + std::move(neurons_1), std::move(enable_record_spikes_1)}; + auto const population_descriptor_1 = network_builder.add(population_1); + + // population not at beginning of row + grenade::vx::network::Population::Neurons neurons_2{ + AtomicNeuronOnDLS(NeuronColumnOnDLS(1), NeuronRowOnDLS(0))}; + grenade::vx::network::Population::EnableRecordSpikes enable_record_spikes_2{false}; + grenade::vx::network::Population population_2{ + std::move(neurons_2), std::move(enable_record_spikes_2)}; + auto const population_descriptor_2 = network_builder.add(population_2); + + // population on bottom row + grenade::vx::network::Population::Neurons neurons_4{ + AtomicNeuronOnDLS(NeuronColumnOnDLS(1), NeuronRowOnDLS(1))}; + grenade::vx::network::Population::EnableRecordSpikes enable_record_spikes_4{false}; + grenade::vx::network::Population population_4{ + std::move(neurons_4), std::move(enable_record_spikes_4)}; + auto const population_descriptor_4 = network_builder.add(population_4); + + // population on both rows + grenade::vx::network::Population::Neurons neurons_5{ + AtomicNeuronOnDLS(NeuronColumnOnDLS(2), NeuronRowOnDLS(0)), + AtomicNeuronOnDLS(NeuronColumnOnDLS(2), NeuronRowOnDLS(1))}; + grenade::vx::network::Population::EnableRecordSpikes enable_record_spikes_5{false, false}; + grenade::vx::network::Population population_5{ + std::move(neurons_5), std::move(enable_record_spikes_5)}; + auto const population_descriptor_5 = network_builder.add(population_5); + + // projection at beginning of row + grenade::vx::network::Projection::Connections projection_connections_1; + for (size_t i = 0; i < population_1.neurons.size(); ++i) { + projection_connections_1.push_back( + {i, i, grenade::vx::network::Projection::Connection::Weight(1)}); + } + grenade::vx::network::Projection projection_1{ + grenade::vx::network::Projection::ReceptorType::excitatory, projection_connections_1, + population_descriptor_1, population_descriptor_1}; + auto const projection_descriptor_1 = network_builder.add(projection_1); + + // projection not at beginning of row + grenade::vx::network::Projection::Connections projection_connections_2; + for (size_t i = 0; i < population_2.neurons.size(); ++i) { + projection_connections_2.push_back( + {i, i, grenade::vx::network::Projection::Connection::Weight(2)}); + } + grenade::vx::network::Projection projection_2{ + grenade::vx::network::Projection::ReceptorType::excitatory, projection_connections_2, + population_descriptor_2, population_descriptor_2}; + auto const projection_descriptor_2 = network_builder.add(projection_2); + + // projection not at beginning of column + grenade::vx::network::Projection::Connections projection_connections_3; + for (size_t i = 0; i < population_1.neurons.size(); ++i) { + projection_connections_3.push_back( + {i, i, grenade::vx::network::Projection::Connection::Weight(3)}); + } + grenade::vx::network::Projection projection_3{ + grenade::vx::network::Projection::ReceptorType::excitatory, projection_connections_3, + population_descriptor_1, population_descriptor_1}; + auto const projection_descriptor_3 = network_builder.add(projection_3); + + // projection on bottom hemisphere + grenade::vx::network::Projection::Connections projection_connections_4; + for (size_t i = 0; i < population_4.neurons.size(); ++i) { + projection_connections_4.push_back( + {i, i, grenade::vx::network::Projection::Connection::Weight(4)}); + } + grenade::vx::network::Projection projection_4{ + grenade::vx::network::Projection::ReceptorType::excitatory, projection_connections_4, + population_descriptor_4, population_descriptor_4}; + auto const projection_descriptor_4 = network_builder.add(projection_4); + + // projection over two hemispheres + grenade::vx::network::Projection::Connections projection_connections_5; + for (size_t i = 0; i < population_5.neurons.size(); ++i) { + projection_connections_5.push_back( + {0, i, grenade::vx::network::Projection::Connection::Weight(5 + i)}); + } + grenade::vx::network::Projection projection_5{ + grenade::vx::network::Projection::ReceptorType::excitatory, projection_connections_5, + population_descriptor_5, population_descriptor_5}; + auto const projection_descriptor_5 = network_builder.add(projection_5); + + + grenade::vx::network::OnlyRecordingPlasticityRuleGenerator recording_generator( + {grenade::vx::network::OnlyRecordingPlasticityRuleGenerator::Observable::weights}); + + grenade::vx::network::PlasticityRule plasticity_rule = recording_generator.generate(); + plasticity_rule.timer = grenade::vx::network::PlasticityRule::Timer{ + grenade::vx::network::PlasticityRule::Timer::Value(0), + grenade::vx::network::PlasticityRule::Timer::Value( + Timer::Value::fpga_clock_cycles_per_us * 300), + 2}; + plasticity_rule.projections = std::vector{ + projection_descriptor_1, projection_descriptor_2, projection_descriptor_3, + projection_descriptor_4, projection_descriptor_5}; + + auto const plasticity_rule_descriptor = network_builder.add(plasticity_rule); + + auto const network = network_builder.done(); + + auto const routing_result = grenade::vx::network::build_routing(network); + auto const network_graph = grenade::vx::network::build_network_graph(network, routing_result); + + grenade::vx::IODataMap inputs; + inputs.runtime[instance].push_back(Timer::Value(Timer::Value::fpga_clock_cycles_per_us * 1000)); + + // Construct connection to HW + grenade::vx::backend::Connection connection; + std::map<DLSGlobal, grenade::vx::backend::Connection> connections; + connections.emplace(DLSGlobal(), std::move(connection)); + grenade::vx::JITGraphExecutor executor(std::move(connections)); + + // run graph with given inputs and return results + auto const result_map = + grenade::vx::run(executor, network_graph.get_graph(), inputs, chip_configs); + + auto const recorded_data = std::get<grenade::vx::network::PlasticityRule::TimedRecordingData>( + grenade::vx::network::extract_plasticity_rule_recording_data( + result_map, network_graph, plasticity_rule_descriptor)); + EXPECT_EQ(recorded_data.data_per_synapse.size(), 1); + EXPECT_EQ(recorded_data.data_array.size(), 0); + EXPECT_TRUE(recorded_data.data_per_synapse.contains("weights")); + EXPECT_EQ(recorded_data.data_per_synapse.at("weights").size(), 5 /* #projections */); + for (auto const& [descriptor, ws] : recorded_data.data_per_synapse.at("weights")) { + auto const& weights = + std::get<std::vector<grenade::vx::TimedDataSequence<std::vector<int8_t>>>>(ws); + EXPECT_EQ(weights.size(), inputs.batch_size()); + for (size_t i = 0; i < weights.size(); ++i) { + auto const& samples = weights.at(i); + for (auto const& sample : samples) { + if (descriptor == projection_descriptor_5) { + EXPECT_EQ(sample.data.size(), 2 /* #connections/projection */); + } else { + EXPECT_EQ(sample.data.size(), 1 /* #connections/projection */); + } + for (size_t i = 0; i < sample.data.size(); ++i) { + EXPECT_EQ( + static_cast<int>(sample.data.at(i)), network_graph.get_network() + ->projections.at(descriptor) + .connections.at(i) + .weight); + } + } + } + } +}