From 939cb91a36c442d28c158c655be5e7a004325965 Mon Sep 17 00:00:00 2001
From: thorstenhater <24411438+thorstenhater@users.noreply.github.com>
Date: Fri, 20 Aug 2021 10:18:38 +0200
Subject: [PATCH] Fix ambiguous Region/Locset expressions (#1629)

- Add a set of test cases to check the behaviour
- Remove the function `(nil)` from the DSL
- Remove `nil` and `()` as literals
- Add functions `(region-nil)` and `(locset-nil)`

Fixes an issue where `join` (and likely `intersect`) would not work because
`nil` values could not be coerced to `region` or `locset`. This occurred while loading
NML files.

While investigating, it was found that certain calls, e.g. `(join () ())`, are ambiguous and cannot
be resolved without changes that are out of proportion to the gain. The cause is that
`()` can be resolved as either a region or a locset, and at the same time `join` maps a list of
locsets or regions to a locset or region; consequently `(join () ())` can be both.

For future reference, in `label_parse.cpp::eval` we look for the first match in the `eval_map`
of a function name that can be successfully evaluated. However, this might not be the best match
(and the ordering depends on the internals of `eval_map`).  We _could_ check all successful
evaluations, but as eval is recursive, this idea has some obvious issues.
---
 arbor/morph/locset.cpp    |  2 +-
 arbor/morph/region.cpp    |  2 +-
 arborio/label_parse.cpp   |  9 ++---
 arborio/parse_helpers.hpp | 21 +++-------
 doc/concepts/labels.rst   |  8 +++-
 doc/scripts/gen-labels.py |  2 +-
 test/unit/test_s_expr.cpp | 82 +++++++++++++++++++++++++++++++++++++++
 7 files changed, 100 insertions(+), 26 deletions(-)

diff --git a/arbor/morph/locset.cpp b/arbor/morph/locset.cpp
index 8b15a295..d69616ac 100644
--- a/arbor/morph/locset.cpp
+++ b/arbor/morph/locset.cpp
@@ -41,7 +41,7 @@ mlocation_list thingify_(const nil_& x, const mprovider&) {
 }
 
 std::ostream& operator<<(std::ostream& o, const nil_& x) {
-    return o << "nil";
+    return o << "(locset-nil)";
 }
 
 // An explicit location.
diff --git a/arbor/morph/region.cpp b/arbor/morph/region.cpp
index 539b4683..b36c994d 100644
--- a/arbor/morph/region.cpp
+++ b/arbor/morph/region.cpp
@@ -41,7 +41,7 @@ mextent thingify_(const nil_& x, const mprovider&) {
 }
 
 std::ostream& operator<<(std::ostream& o, const nil_& x) {
-    return o << "nil";
+    return o << "(region-nil)";
 }
 
 
diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp
index 119342d3..df45c15f 100644
--- a/arborio/label_parse.cpp
+++ b/arborio/label_parse.cpp
@@ -22,8 +22,8 @@ namespace {
 
 std::unordered_multimap<std::string, evaluator> eval_map {
     // Functions that return regions
-    {"nil", make_call<>(arb::reg::nil,
-                "'nil' with 0 arguments")},
+    {"region-nil", make_call<>(arb::reg::nil,
+                "'region-nil' with 0 arguments")},
     {"all", make_call<>(arb::reg::all,
                 "'all' with 0 arguments")},
     {"tag", make_call<int>(arb::reg::tagged,
@@ -74,9 +74,10 @@ std::unordered_multimap<std::string, evaluator> eval_map {
                       "'intersect' with at least 2 arguments: (region region [...region])")},
 
     // Functions that return locsets
+    {"locset-nil", make_call<>(arb::ls::nil,
+                "'locset-nil' with 0 arguments")},
     {"root", make_call<>(arb::ls::root,
                  "'root' with 0 arguments")},
-
     {"location", make_call<int, double>([](int bid, double pos){return arb::ls::location(arb::msize_t(bid), pos);},
                      "'location' with 2 arguments: (branch_id:integer position:real)")},
     {"terminal", make_call<>(arb::ls::terminal,
@@ -138,7 +139,6 @@ std::string eval_description(const char* name, const std::vector<std::any>& args
         if (t==typeid(double))      return "real";
         if (t==typeid(arb::region)) return "region";
         if (t==typeid(arb::locset)) return "locset";
-        if (t==typeid(nil_tag))     return "()";
         return "unknown";
     };
 
@@ -185,7 +185,6 @@ parse_label_hopefully<std::any> eval(const s_expr& e) {
         // Find all candidate functions that match the name of the function.
         auto& name = e.head().atom().spelling;
         auto matches = eval_map.equal_range(name);
-
         // Search for a candidate that matches the argument list.
         for (auto i=matches.first; i!=matches.second; ++i) {
             if (i->second.match_args(*args)) { // found a match: evaluate and return.
diff --git a/arborio/parse_helpers.hpp b/arborio/parse_helpers.hpp
index 655a65df..7a611ae1 100644
--- a/arborio/parse_helpers.hpp
+++ b/arborio/parse_helpers.hpp
@@ -3,6 +3,7 @@
 #include <any>
 #include <string>
 #include <sstream>
+#include <iostream>
 
 #include <arbor/assert.hpp>
 #include <arbor/arbexcept.hpp>
@@ -13,13 +14,15 @@
 namespace arborio {
 using namespace arb;
 
-struct nil_tag {};
-
 // Check typeinfo against expected types
 template <typename T>
 bool match(const std::type_info& info) { return info == typeid(T); }
 template <> inline
 bool match<double>(const std::type_info& info) { return info == typeid(double) || info == typeid(int); }
+template <> inline
+bool match<arb::locset>(const std::type_info& info) { return info == typeid(arb::locset); }
+template <> inline
+bool match<arb::region>(const std::type_info& info) { return info == typeid(arb::region); }
 
 // Convert a value wrapped in a std::any to target type.
 template <typename T>
@@ -32,18 +35,6 @@ double eval_cast<double>(std::any arg) {
     return std::any_cast<double>(arg);
 }
 
-template <> inline
-arb::region eval_cast<arb::region>(std::any arg) {
-    if (arg.type()==typeid(arb::region)) return std::any_cast<arb::region>(arg);
-    return arb::reg::nil();
-}
-
-template <> inline
-arb::locset eval_cast<arb::locset>(std::any arg) {
-    if (arg.type()==typeid(arb::locset)) return std::any_cast<arb::locset>(arg);
-    return arb::ls::nil();
-}
-
 // Test whether a list of arguments passed as a std::vector<std::any> can be converted
 // to the types in Args.
 //
@@ -256,8 +247,6 @@ util::expected<std::any, E> eval_atom(const s_expr& e) {
             return {std::stoi(t.spelling)};
         case tok::real:
             return {std::stod(t.spelling)};
-        case tok::nil:
-            return {nil_tag()};
         case tok::string:
             return {std::string(t.spelling)};
         case tok::symbol:
diff --git a/doc/concepts/labels.rst b/doc/concepts/labels.rst
index 11d418a9..4c1ee631 100644
--- a/doc/concepts/labels.rst
+++ b/doc/concepts/labels.rst
@@ -202,6 +202,10 @@ dendritic tree where the radius first is less than or equal to 0.2 μm.
 Locset expressions
 ~~~~~~~~~~~~~~~~~~
 
+.. label:: (locset-nil)
+
+    The empty locset.
+
 .. figure:: ../gen-images/label_branch.svg
   :width: 800
   :align: center
@@ -351,7 +355,7 @@ Locset expressions
 Region expressions
 ~~~~~~~~~~~~~~~~~~
 
-.. label:: (nil)
+.. label:: (region-nil)
 
     An empty region.
 
@@ -363,7 +367,7 @@ Region expressions
       :width: 600
       :align: center
 
-      The trivial region definitions ``(nil)`` (left) and ``(all)`` (right).
+      The trivial region definitions ``(region-nil)`` (left) and ``(all)`` (right).
 
 .. label:: (tag tag_id:integer)
 
diff --git a/doc/scripts/gen-labels.py b/doc/scripts/gen-labels.py
index 34d7db2d..bcb7c710 100644
--- a/doc/scripts/gen-labels.py
+++ b/doc/scripts/gen-labels.py
@@ -163,7 +163,7 @@ fn = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__), "../c
 swc_morph = arbor.load_swc_arbor(fn)
 
 regions  = {
-            'empty': '(nil)',
+            'empty': '(region-nil)',
             'all': '(all)',
             'tag1': '(tag 1)',
             'tag2': '(tag 2)',
diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp
index 01dd22da..477fc8e9 100644
--- a/test/unit/test_s_expr.cpp
+++ b/test/unit/test_s_expr.cpp
@@ -251,6 +251,7 @@ TEST(regloc, round_tripping) {
         "(cable 2 0.1 0.4)",
         "(region \"foo\")",
         "(all)",
+        "(region-nil)",
         "(tag 42)",
         "(distal-interval (location 3 0))",
         "(distal-interval (location 3 0) 3.2)",
@@ -273,6 +274,7 @@ TEST(regloc, round_tripping) {
     auto locset_literals = {
         "(root)",
         "(locset \"cat man do\")",
+        "(locset-nil)",
         "(location 3 0.2)",
         "(terminal)",
         "(distal (tag 2))",
@@ -307,6 +309,86 @@ TEST(regloc, comments) {
               round_trip_region(multi_line));
 }
 
+TEST(regloc, reg_nil) {
+    auto check = [](const std::string& s) {
+        auto res = parse_region_expression(s);
+        if (!res.has_value()) throw res.error();
+        return true;
+    };
+
+    std::vector<std::string>
+        args{"(nil)",
+             "()",
+             "nil",
+             "(join () (segment 1)",
+             "(intersect (segment 1) nil"};
+    for (const auto& arg: args) {
+        EXPECT_THROW(check(arg), arborio::label_parse_error);
+    }
+}
+
+TEST(regloc, loc_nil) {
+    auto check = [](const std::string& s) {
+        auto res = parse_locset_expression(s);
+        if (!res.has_value()) throw res.error();
+        return true;
+    };
+
+    std::vector<std::string>
+        args{"(nil)",
+             "()",
+             "nil",
+             "(join () (root)",
+             "(intersect (terminal) nil"};
+    for (const auto& arg: args) {
+        EXPECT_THROW(check(arg), arborio::label_parse_error);
+    }
+}
+
+TEST(regloc, reg_fold_expressions) {
+    auto check = [](const std::string& s) {
+        auto res = parse_region_expression(s);
+        if (!res.has_value()) throw res.error();
+        return true;
+    };
+
+    std::vector<std::string>
+        args{"(region-nil) (region-nil)",
+             "(region-nil) (segment 1)",
+             "(segment 0) (segment 1)",
+             "(region-nil) (segment 0) (segment 1)"},
+        funs{"join",
+             "intersect"};
+    for (const auto& fun: funs) {
+        for (const auto& arg: args) {
+
+            EXPECT_TRUE(check("(" + fun + " " + arg + ")"));
+        }
+    }
+}
+
+TEST(regloc, loc_fold_expressions) {
+    auto check = [](const std::string& s) {
+        auto res = parse_locset_expression(s);
+        if (!res.has_value()) throw res.error();
+        return true;
+    };
+
+    std::vector<std::string>
+        args{"(locset-nil) (locset-nil)",
+             "(locset-nil) (locset-nil) (locset-nil)",
+             "(locset-nil) (terminal)",
+             "(root) (terminal)",
+             "(locset-nil) (root) (terminal)"},
+        funs{"sum",
+             "join"};
+    for (const auto& fun: funs) {
+        for (const auto& arg: args) {
+            EXPECT_TRUE(check("(" + fun + " " + arg + ")"));
+        }
+    }
+}
+
 TEST(regloc, errors) {
     for (auto expr: {"axon",         // unquoted region name
                      "(tag 1.2)",    // invalid argument in an otherwise valid region expression
-- 
GitLab