#include <limits>
#include <stdexcept>
#include <string>

#include <arbor/util/any.hpp>
#include <arbor/morph/region.hpp>
#include <arbor/morph/locset.hpp>

#include "error.hpp"
#include "s_expr.hpp"
#include "morph_parse.hpp"

namespace pyarb {

// Empty tag type: the value produced when a nil token is evaluated.
// It is accepted wherever a region or a locset argument is expected, and is
// converted to the corresponding nil region/locset when the argument is cast.
struct nil_tag {};

// Test whether a type-erased value whose stored type is described by `info`
// can be passed where an argument of type T is expected.
// The generic rule is an exact type match; specializations relax this rule
// for individual types.
template <typename T>
bool match(const std::type_info& info) {
    const std::type_info& expected = typeid(T);
    return expected == info;
}

// A real argument is satisfied by either a real or an integer value.
template <>
bool match<double>(const std::type_info& info) {
    if (info == typeid(int)) {
        return true;
    }
    return info == typeid(double);
}

// A region argument is satisfied by either a region or the nil tag.
template <>
bool match<arb::region>(const std::type_info& info) {
    const bool is_nil = info == typeid(nil_tag);
    const bool is_region = info == typeid(arb::region);
    return is_region || is_nil;
}

// A locset argument is satisfied by either a locset or the nil tag.
template <>
bool match<arb::locset>(const std::type_info& info) {
    const bool is_nil = info == typeid(nil_tag);
    const bool is_locset = info == typeid(arb::locset);
    return is_locset || is_nil;
}

// Extract a value of type T from a type-erased argument.
// The argument is taken by value, so the stored object can be moved out
// rather than copied. The generic case expects the stored type to be T;
// specializations handle the permitted conversions.
template <typename T>
T eval_cast(arb::util::any arg) {
    T& held = arb::util::any_cast<T&>(arg);
    return std::move(held);
}

// Extract a real value: a stored integer is converted to double, anything
// else is read as a double.
template <>
double eval_cast<double>(arb::util::any arg) {
    const bool holds_int = arg.type()==typeid(int);
    return holds_int
        ? static_cast<double>(arb::util::any_cast<int>(arg))
        : arb::util::any_cast<double>(arg);
}

// Extract a region: any argument that does not hold a region (in practice
// the nil tag) is replaced by the nil region.
template <>
arb::region eval_cast<arb::region>(arb::util::any arg) {
    if (arg.type()!=typeid(arb::region)) {
        return arb::reg::nil();
    }
    return arb::util::any_cast<arb::region>(arg);
}

// Extract a locset: any argument that does not hold a locset (in practice
// the nil tag) is replaced by the nil locset.
template <>
arb::locset eval_cast<arb::locset>(arb::util::any arg) {
    if (arg.type()!=typeid(arb::locset)) {
        return arb::ls::nil();
    }
    return arb::util::any_cast<arb::locset>(arg);
}

template <typename... Args>
struct call_eval {
    using ftype = std::function<arb::util::any(Args...)>;
    ftype f;
    call_eval(ftype f): f(std::move(f)) {}

    template<std::size_t... I>
    arb::util::any expand_args_then_eval(std::vector<arb::util::any> args, std::index_sequence<I...>) {
        return f(eval_cast<Args>(std::move(args[I]))...);
    }

    arb::util::any operator()(std::vector<arb::util::any> args) {
        return expand_args_then_eval(std::move(args), std::make_index_sequence<sizeof...(Args)>());
    }
};

template <typename... Args>
struct call_match {
    template <std::size_t I, typename T, typename Q, typename... Rest>
    bool match_args_impl(const std::vector<arb::util::any>& args) const {
        return match<T>(args[I].type()) && match_args_impl<I+1, Q, Rest...>(args);
    }

    template <std::size_t I, typename T>
    bool match_args_impl(const std::vector<arb::util::any>& args) const {
        return match<T>(args[I].type());
    }

    template <std::size_t I>
    bool match_args_impl(const std::vector<arb::util::any>& args) const {
        return true;
    }

    bool operator()(const std::vector<arb::util::any>& args) const {
        const auto nargs_in = args.size();
        const auto nargs_ex = sizeof...(Args);
        return nargs_in==nargs_ex? match_args_impl<0, Args...>(args): false;
    }
};

template <typename T>
struct fold_eval {
    using fold_fn = std::function<T(T, T)>;
    fold_fn f;

    using anyvec = std::vector<arb::util::any>;
    using iterator = anyvec::iterator;

    fold_eval(fold_fn f): f(std::move(f)) {}

    T fold_impl(iterator left, iterator right) {
        if (std::distance(left,right)==1u) {
            return eval_cast<T>(std::move(*left));
        }
        return f(eval_cast<T>(std::move(*left)), fold_impl(left+1, right));
    }

    arb::util::any operator()(anyvec args) {
        return fold_impl(args.begin(), args.end());
    }
};

template <typename T>
struct fold_match {
    using anyvec = std::vector<arb::util::any>;
    bool operator()(const anyvec& args) const {
        if (args.size()<2u) return false;
        for (auto& a: args) {
            if (!match<T>(a.type())) return false;
        }
        return true;
    }
};

struct evaluator {
    using any_vec = std::vector<arb::util::any>;
    using eval_fn = std::function<arb::util::any(any_vec)>;
    using args_fn = std::function<bool(const any_vec&)>;

    eval_fn eval;
    args_fn match_args;
    const char* message;

    evaluator(eval_fn f, args_fn a, const char* m):
        eval(std::move(f)),
        match_args(std::move(a)),
        message(m)
    {}
};

// Helper that builds an evaluator for a function that takes exactly the
// arguments Args...: evaluation goes through call_eval and argument
// checking through call_match.
// An instance converts implicitly to the evaluator that it built.
template <typename... Args>
struct make_call {
    evaluator state;

    // fn:          callable that can be invoked with Args...
    // description: text listed in error messages for this candidate.
    template <typename F>
    make_call(F&& fn, const char* description="call"):
        state(call_eval<Args...>(std::forward<F>(fn)),
              call_match<Args...>(),
              description)
    {}

    operator evaluator() const { return state; }
};

// Helper that builds an evaluator for a binary operation on T that is
// applied to two or more arguments: evaluation goes through fold_eval and
// argument checking through fold_match.
// An instance converts implicitly to the evaluator that it built.
template <typename T>
struct make_fold {
    evaluator state;

    // fn:          callable that can be invoked as T(T, T).
    // description: text listed in error messages for this candidate.
    template <typename F>
    make_fold(F&& fn, const char* description="fold"):
        state(fold_eval<T>(std::forward<F>(fn)),
              fold_match<T>(),
              description)
    {}

    operator evaluator() const { return state; }
};

// Table of every function that can appear at the head of an s-expression.
// The key is the function name as written in the expression. A multimap is
// used because a name may be overloaded: it can have several candidates that
// differ in the number or the types of their arguments (e.g. 'join' on
// regions and on locsets, 'distal_interval' with and without an extent).
// The description attached to each candidate is listed in the error message
// when a call matches none of the candidates for its name.
std::unordered_multimap<std::string, evaluator> eval_map {
    // Functions that return regions
    {"nil",     make_call<>(arb::reg::nil,
                            "'nil' with 0 arguments")},
    {"all",     make_call<>(arb::reg::all,
                            "'all' with 0 arguments")},
    {"tag",     make_call<int>(arb::reg::tagged,
                            "'tag' with 1 argument: (tag_id:integer)")},
    {"branch",  make_call<int>(arb::reg::branch,
                            "'branch' with 1 argument: (branch_id:integer)")},
    {"cable",   make_call<int, double, double>(arb::reg::cable,
                            "'cable' with 3 arguments: (branch_id:integer prox:real dist:real)")},
    {"region",  make_call<std::string>(arb::reg::named,
                            "'region' with 1 argument: (name:string)")},
    {"distal_interval",  make_call<arb::locset, double>(arb::reg::distal_interval,
                            "'distal_interval' with 2 arguments: (start:locset extent:real)")},
    // With no extent given, the largest finite double is used as the extent.
    {"distal_interval", make_call<arb::locset>(
                            [](arb::locset ls){return arb::reg::distal_interval(std::move(ls), std::numeric_limits<double>::max());},
                            "'distal_interval' with 1 argument: (start:locset)")},
    {"proximal_interval",make_call<arb::locset, double>(arb::reg::proximal_interval,
                            "'proximal_interval' with 2 arguments: (start:locset extent:real)")},
    // With no extent given, the largest finite double is used as the extent.
    {"proximal_interval", make_call<arb::locset>(
                            [](arb::locset ls){return arb::reg::proximal_interval(std::move(ls), std::numeric_limits<double>::max());},
                            "'proximal_interval' with 1 argument: (start:locset)")},
    {"complete", make_call<arb::region>(arb::reg::complete,
                            "'complete' with 1 argument: (reg:region)")},
    {"radius_lt",make_call<arb::region, double>(arb::reg::radius_lt,
                            "'radius_lt' with 2 arguments: (reg:region radius:real)")},
    {"radius_le",make_call<arb::region, double>(arb::reg::radius_le,
                            "'radius_le' with 2 arguments: (reg:region radius:real)")},
    {"radius_gt",make_call<arb::region, double>(arb::reg::radius_gt,
                            "'radius_gt' with 2 arguments: (reg:region radius:real)")},
    {"radius_ge",make_call<arb::region, double>(arb::reg::radius_ge,
                            "'radius_ge' with 2 arguments: (reg:region radius:real)")},
    {"z_dist_from_root_lt",make_call<double>(arb::reg::z_dist_from_root_lt,
                            "'z_dist_from_root_lt' with 1 argument: (distance:real)")},
    {"z_dist_from_root_le",make_call<double>(arb::reg::z_dist_from_root_le,
                            "'z_dist_from_root_le' with 1 argument: (distance:real)")},
    {"z_dist_from_root_gt",make_call<double>(arb::reg::z_dist_from_root_gt,
                            "'z_dist_from_root_gt' with 1 argument: (distance:real)")},
    {"z_dist_from_root_ge",make_call<double>(arb::reg::z_dist_from_root_ge,
                            "'z_dist_from_root_ge' with 1 argument: (distance:real)")},
    // The casts select the region overload of the overloaded arb functions.
    {"join",    make_fold<arb::region>(static_cast<arb::region(*)(arb::region, arb::region)>(arb::join),
                            "'join' with at least 2 arguments: (region region [...region])")},
    {"intersect",make_fold<arb::region>(static_cast<arb::region(*)(arb::region, arb::region)>(arb::intersect),
                            "'intersect' with at least 2 arguments: (region region [...region])")},

    // Functions that return locsets
    {"root",    make_call<>(arb::ls::root,
                            "'root' with 0 arguments")},

    {"location", make_call<int, double>([](int bid, double pos){return arb::ls::location(arb::msize_t(bid), pos);},
                            "'location' with 2 arguments: (branch_id:integer position:real)")},
    {"terminal", make_call<>(arb::ls::terminal,
                            "'terminal' with 0 arguments")},
    {"distal",  make_call<arb::region>(arb::ls::most_distal,
                            "'distal' with 1 argument: (reg:region)")},
    {"proximal",make_call<arb::region>(arb::ls::most_proximal,
                            "'proximal' with 1 argument: (reg:region)")},
    {"uniform",make_call<arb::region, int, int, int>(arb::ls::uniform,
                            "'uniform' with 4 arguments: (reg:region, first:int, last:int, seed:int)")},
    {"on_branches",make_call<double>(arb::ls::on_branches,
                            "'on_branches' with 1 argument: (pos:double)")},
    {"locset",  make_call<std::string>(arb::ls::named,
                            "'locset' with 1 argument: (name:string)")},
    {"restrict",  make_call<arb::locset, arb::region>(arb::ls::restrict,
                            "'restrict' with 2 arguments: (ls:locset, reg:region)")},
    // The casts select the locset overload of the overloaded arb functions.
    {"join",    make_fold<arb::locset>(static_cast<arb::locset(*)(arb::locset, arb::locset)>(arb::join),
                            "'join' with at least 2 arguments: (locset locset [...locset])")},
    {"sum",     make_fold<arb::locset>(static_cast<arb::locset(*)(arb::locset, arb::locset)>(arb::sum),
                            "'sum' with at least 2 arguments: (locset locset [...locset])")},
};

// Forward declaration: eval_args() evaluates each argument with eval(), and
// eval() in turn evaluates argument lists with eval_args().
parse_hopefully<arb::util::any> eval(const s_expr& e);

// Evaluate every element of the list e, in order, and collect the results.
// Returns the vector of evaluated arguments, or the error of the first
// element that failed to evaluate; elements after a failure are not
// evaluated.
parse_hopefully<std::vector<arb::util::any>> eval_args(const s_expr& e) {
    using any_vec = std::vector<arb::util::any>;

    if (!e) return {any_vec{}}; // empty argument list

    any_vec args;
    // Walk the list one node at a time, until the terminating empty node.
    for (const s_expr* node = &e; *node; node = &node->tail()) {
        auto value = eval(node->head());
        if (!value) {
            return std::move(value.error());
        }
        args.push_back(std::move(*value));
    }
    return args;
}

// Generate a string description of a function evaluation, for use in error
// messages.
// Example output:
//  'foo' with 1 argument: (real)
//  'bar' with 0 arguments
//  'cat' with 3 arguments: (locset region integer)
// Where 'foo', 'bar' and 'cat' are the name of the function, and the
// types (integer, real, string, region, locset) are inferred from the
// arguments. A nil argument is shown as "()", and an argument of any other
// type as "unknown".
std::string eval_description(const char* name, const std::vector<arb::util::any>& args) {
    auto type_string = [](const std::type_info& t) -> const char* {
        if (t==typeid(int))         return "integer";
        if (t==typeid(double))      return "real";
        // String atoms are evaluated to std::string.
        if (t==typeid(std::string)) return "string";
        if (t==typeid(arb::region)) return "region";
        if (t==typeid(arb::locset)) return "locset";
        if (t==typeid(nil_tag))     return "()";
        return "unknown";
    };

    const auto nargs = args.size();
    // The suffix completes the word "argument": plural for every count but
    // one, and followed by a colon whenever a list of types follows.
    std::string msg =
        util::pprintf("'{}' with {} argument{}",
                      name, nargs,
                      nargs==0?"s": nargs==1u?":": "s:");
    if (nargs) {
        msg += " (";
        bool first = true;
        for (auto& a: args) {
            // Types are separated by a single space.
            msg += util::pprintf("{}{}", first?"":" ", type_string(a.type()));
            first = false;
        }
        msg += ")";
    }
    return msg;
}

// Evaluate an s-expression.
// On success the result is wrapped in util::any, where the result is one of:
//      int         : an integer atom
//      double      : a real atom
//      std::string : a string atom
//      nil_tag     : a nil atom
//      arb::region : a region
//      arb::locset : a locset
//
// If invalid input is detected, the returned parse_hopefully contains
// a parse_error_state with an error string and location. This includes
// numeric literals that are out of the range of int or double.
//
// If there was an unexpected/fatal error, an exception will be thrown.
parse_hopefully<arb::util::any> eval(const s_expr& e) {
    if (e.is_atom()) {
        auto& t = e.atom();
        switch (t.kind) {
            case tok::integer:
                // std::stoi throws std::out_of_range if the value does not
                // fit in an int: report that as invalid input.
                try {
                    return {std::stoi(t.spelling)};
                }
                catch (const std::out_of_range&) {
                    return parse_error_state{
                        util::pprintf("Integer out of range: {}", e), location(e)};
                }
            case tok::real:
                // std::stod throws std::out_of_range if the value can not be
                // represented as a double: report that as invalid input.
                try {
                    return {std::stod(t.spelling)};
                }
                catch (const std::out_of_range&) {
                    return parse_error_state{
                        util::pprintf("Real number out of range: {}", e), location(e)};
                }
            case tok::nil:
                return {nil_tag()};
            case tok::string:
                return {std::string(t.spelling)};
            case tok::error:
                // The spelling of an error token is used as the error message.
                return parse_error_state{e.atom().spelling, location(e)};
            default:
                return parse_error_state{
                    util::pprintf("Unexpected term: {}", e), location(e)};
        }
    }
    if (e.head().is_atom()) {
        // This must be a function evaluation, where head is the function name, and
        // tail is a list of arguments.

        // Evaluate the arguments, and return error state if an error occurred.
        auto args = eval_args(e.tail());
        if (!args) {
            return args.error();
        }

        // Find all candidate functions that match the name of the function.
        auto& name = e.head().atom().spelling;
        auto matches = eval_map.equal_range(name);

        // Search for a candidate that matches the argument list.
        for (auto i=matches.first; i!=matches.second; ++i) {
            if (i->second.match_args(*args)) { // found a match: evaluate and return.
                // The arguments are not used again on this path, so they
                // can be handed over to the evaluator without a copy.
                return i->second.eval(std::move(*args));
            }
        }

        // Unable to find a match: try to return a helpful error message.
        const auto nc = std::distance(matches.first, matches.second);
        auto msg = util::pprintf("No matches for {}", eval_description(name.c_str(), *args));
        msg += util::pprintf("\n  There are {} potential candidates{}", nc, nc?":":".");
        int count = 0;
        for (auto i=matches.first; i!=matches.second; ++i) {
            msg += util::pprintf("\n  Candidate {}  {}", ++count, i->second.message);
        }
        return parse_error_state{std::move(msg), location(e)};
    }

    // Neither an atom nor a list that starts with an atom.
    return parse_error_state{
        util::pprintf("Unable to evaluate '{}': expression must be either an integer, a real, or an expression of the form (op <args>)", e),
            location(e)};
}

} // namespace pyarb