Skip to content
Snippets Groups Projects
Select Git revision
  • a2443eb3c6511dad2016c1d67853eec0ef1f40b3
  • master default protected
  • tut_ring_allen
  • docs_furo
  • docs_reorder_cable_cell
  • docs_graphviz
  • docs_rtd_dev
  • ebrains_mirror
  • doc_recat
  • docs_spike_source
  • docs_sim_sample_clar
  • docs_pip_warn
  • github_template_updates
  • docs_fix_link
  • cv_default_and_doc_clarification
  • docs_add_numpy_req
  • readme_zenodo_05
  • install_python_fix
  • install_require_numpy
  • typofix_propetries
  • docs_recipe_lookup
  • v0.10.0
  • v0.10.1
  • v0.10.0-rc5
  • v0.10.0-rc4
  • v0.10.0-rc3
  • v0.10.0-rc2
  • v0.10.0-rc
  • v0.9.0
  • v0.9.0-rc
  • v0.8.1
  • v0.8
  • v0.8-rc
  • v0.7
  • v0.6
  • v0.5.2
  • v0.5.1
  • v0.5
  • v0.4
  • v0.3
  • v0.2.2
41 results

util.hpp

Blame
  • user avatar
    Sam Yates authored
    * Add file cmake/CompilerOptions.cmake for setting up compiler
      specific options
    * Disable 'missing-braces' warning on Clang
    * Avoid defect in g++ 4.9.2 standard library that omits move
      constructor for `std::ifstream`
    * Remove signed/unsigned warning in test_optional.cpp
    034b17bb
    History
    util.hpp 3.95 KiB
    #include <algorithm>
    #include <chrono>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <fstream>
    #include <iomanip>
    #include <iostream>
    #include <limits>
    #include <string>
    #include <vector>
    
    #include <json/src/json.hpp>
    #include <util.hpp>
    
    // helpful code for running tests
    // a bit messy: refactor when it gets heavier and obvious patterns emerge...
    
    namespace testing{
    
    // Timing helpers for test code: call tic() to record a start point and
    // toc(start) to get the seconds elapsed since then.
    //
    // steady_clock is used because it is monotonic. system_clock is a wall
    // clock that may be adjusted while a test is running, which would make
    // measured durations wrong (or even negative).
    using time_point    = std::chrono::time_point<std::chrono::steady_clock>;
    using duration_type = std::chrono::duration<double>;
    
    // Current time, for use as the start argument of toc().
    static inline
    time_point tic()
    {
        return std::chrono::steady_clock::now();
    }
    
    // Seconds (as a double) elapsed since `start`, a value returned by tic().
    static inline
    double toc(time_point start)
    {
        return duration_type(tic() - start).count();
    }
    
    
    // Write equal-length data series to a whitespace-separated text file, one
    // series per column: output line i holds values[0][i], values[1][i], ...,
    // each value preceded by a single space.
    //
    // fname:  path of the output file (created or truncated)
    // values: the series to write; all must have the same length
    //
    // Nothing happens if `values` is empty. If the series differ in length, an
    // error is printed to std::cerr and the file is not touched. The lengths are
    // validated *before* opening so that bad input cannot truncate an existing
    // file. If the file cannot be opened, an error is printed to std::cerr.
    [[gnu::unused]] static
    void write_vis_file(const std::string& fname, const std::vector<std::vector<double>>& values)
    {
        if(values.empty()) return;
    
        const auto n = values.front().size();
        for(const auto& v : values) {
            if(v.size() != n) {
                std::cerr << "all output arrays must have the same length\n";
                return;
            }
        }
    
        std::ofstream fid(fname);
        if(!fid.is_open()) {
            std::cerr << "error : unable to open file " << fname << "\n";
            return;
        }
    
        // transpose on output: each series becomes a column
        for(std::size_t i=0; i<n; ++i) {
            for(const auto& column : values) {
                fid << " " << column[i];
            }
            fid << "\n";
        }
    }
    
    // Read a JSON document (reference spike data) from the file `input_name`.
    //
    // Returns the parsed document. On failure a message is printed to std::cerr
    // and a default-constructed (empty) json value is returned instead:
    //  - if the file cannot be opened, or
    //  - if streaming the file into the json value throws for any reason.
    [[gnu::unused]] static
    nlohmann::json
    load_spike_data(const std::string& input_name)
    {
        std::ifstream input(input_name);
        const bool opened = input.is_open();
        if(!opened) {
            std::cerr << "error : unable to open file " << input_name
                      << " : run the validation generation script first\n";
            return {};
        }
    
        nlohmann::json parsed;
        try {
            input >> parsed;
        }
        catch (...) {
            std::cerr << "error : incorrectly formatted json file " << input_name << "\n";
            return {};
        }
    
        return parsed;
    }
    
    // Find the times at which a uniformly sampled trace crosses `threshold`
    // going upward.
    //
    // v:         the samples; sample k corresponds to time k*dt
    // threshold: a crossing between samples k-1 and k is recorded when
    //            v[k-1] < threshold && v[k] >= threshold
    // dt:        the sample spacing
    //
    // Each crossing time is refined by linear interpolation between the two
    // bracketing samples. Traces with fewer than two samples have no crossings.
    template <typename T>
    std::vector<T> find_spikes(std::vector<T> const& v, T threshold, T dt)
    {
        std::vector<T> crossings;
    
        for(unsigned k = 1; k < v.size(); ++k) {
            const auto& before = v[k-1];
            const auto& after  = v[k];
    
            const bool rising_through = before < threshold && after >= threshold;
            if(!rising_through) {
                continue;
            }
    
            // fraction of the interval between samples k-1 and k at which the
            // straight line through them reaches the threshold
            auto frac = (threshold - before) / (after - before);
            crossings.push_back((k - 1 + frac) * dt);
        }
    
        return crossings;
    }
    
    // Summary statistics of the per-spike differences between a spike train and
    // a baseline spike train.
    //
    // A default-constructed object has every statistic set to NaN and is
    // "invalid"; compare_spikes() returns it in that state when the two trains
    // cannot be compared (they contain different numbers of spikes).
    struct spike_comparison {
        double min = std::numeric_limits<double>::quiet_NaN();   // smallest |spike - baseline|
        double max = std::numeric_limits<double>::quiet_NaN();   // largest |spike - baseline|
        double mean = std::numeric_limits<double>::quiet_NaN();  // mean of |spike - baseline|
        double rms = std::numeric_limits<double>::quiet_NaN();   // root-mean-square of the differences
        std::vector<double> diff; // per-spike relative difference |spike - baseline| / |baseline|
    
        // check whether initialized (i.e. has valid results)
        bool is_valid() const {
            return !std::isnan(min);
        }
    
        // Largest per-spike relative difference.
        // NaN for an invalid comparison; 0 for a valid comparison with no spikes
        // (std::max_element on an empty range returns end(), which must not be
        // dereferenced).
        double max_relative_error() const {
            if(!is_valid()) {
                return std::numeric_limits<double>::quiet_NaN();
            }
            if(diff.empty()) {
                return 0.;
            }
    
            return *std::max_element(diff.begin(), diff.end());
        }
    };
    
    // Write a one-line summary of a comparison (statistics of an invalid
    // comparison are printed as nan).
    [[gnu::unused]] static
    std::ostream&
    operator<< (std::ostream& o, spike_comparison const& spikes)
    {
        // use snprintf because C++ is just awful for formatting output;
        // snprintf truncates rather than overflowing the buffer
        char buffer[512];
        std::snprintf(
            buffer, sizeof(buffer),
            "min,max = %10.8f,%10.8f | mean,rms = %10.8f,%10.8f | max_rel = %10.8f",
            spikes.min, spikes.max, spikes.mean, spikes.rms,
            spikes.max_relative_error()
        );
        return o << buffer;
    }
    
    // Compare a spike train against a baseline, spike by spike (matched by index).
    //
    // Returns:
    //  - an invalid (all-NaN) spike_comparison if the trains have different
    //    numbers of spikes;
    //  - all statistics 0 (and no relative differences) if both are empty;
    //  - otherwise min/max/mean/rms of the absolute differences, plus the
    //    per-spike relative differences in `diff`.
    // Differences are computed in double precision, so T may be float as well
    // as double.
    template <typename T>
    spike_comparison compare_spikes(
        std::vector<T> const& spikes,
        std::vector<T> const& baseline)
    {
        spike_comparison c;
    
        // return default initialized (all NaN) if number of spikes differs
        if(spikes.size() != baseline.size()) {
            return c;
        }
    
        const auto n = spikes.size();
    
        // two empty trains agree exactly; avoids 0/0 in mean/rms below
        if(n == 0) {
            c.min = c.max = c.mean = c.rms = 0.;
            return c;
        }
    
        c.min  = std::numeric_limits<double>::max();
        c.max  = 0.;
        c.mean = 0.;
        c.rms  = 0.;
        c.diff.reserve(n);
    
        for(std::size_t i=0; i<n; ++i) {
            // work in double so std::min/std::max see matching argument types
            // whatever T is
            const double spike = static_cast<double>(spikes[i]);
            const double base  = static_cast<double>(baseline[i]);
            const double error = std::fabs(spike - base);
    
            c.min = std::min(c.min, error);
            c.max = std::max(c.max, error);
            c.mean += error;
            c.rms += error*error;
    
            // relative difference (non-negative); an exact match against a zero
            // baseline gives 0 rather than the NaN that 0/0 would produce
            const double scale = std::fabs(base);
            double relative;
            if(scale == 0.) {
                relative = error == 0. ? 0. : std::numeric_limits<double>::infinity();
            }
            else {
                relative = error / scale;
            }
            c.diff.push_back(relative);
        }
    
        c.mean /= n;
        c.rms = std::sqrt(c.rms/n);
    
        return c;
    }
    
    } // namespace testing