diff --git a/modcc/module.cpp b/modcc/module.cpp index 7a4aaaee8a22bf441ed278f6aa88f8b256289b46..279fc357452e905693f940cd8c2f3b112ebf8869 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -277,6 +277,7 @@ bool Module::semantic() { auto initial_api = make_empty_api_method("nrn_init", "initial"); auto api_init = initial_api.first; auto proc_init = initial_api.second; + auto& init_body = api_init->body()->statements(); for(auto& e : *proc_init->body()) { @@ -284,7 +285,7 @@ bool Module::semantic() { if (solve_expression) { // Grab SOLVE statements, put them in `body` after translation. std::set<std::string> solved_ids; - std::unique_ptr<SolverVisitorBase> solver = std::make_unique<SparseSolverVisitor>(); + std::unique_ptr<SolverVisitorBase> solver; // The solve expression inside an initial block can only refer to a linear block auto solve_proc = solve_expression->procedure(); @@ -292,9 +293,14 @@ bool Module::semantic() { if (solve_proc->kind() == procedureKind::linear) { solver = std::make_unique<LinearSolverVisitor>(state_vars); linear_rewrite(solve_proc->body(), state_vars)->accept(solver.get()); + } else if (solve_proc->kind() == procedureKind::kinetic && + solve_expression->variant() == solverVariant::steadystate) { + solver = std::make_unique<SparseSolverVisitor>(solverVariant::steadystate); + kinetic_rewrite(solve_proc->body())->accept(solver.get()); } else { - error("A SOLVE expression in an INITIAL block can only be used to solve a LINEAR block, which" + - solve_expression->name() + "is not.", solve_expression->location()); + error("A SOLVE expression in an INITIAL block can only be used to solve a " + "LINEAR block or a KINETIC block at steadystate and " + + solve_expression->name() + " is neither.", solve_expression->location()); return false; } @@ -307,6 +313,9 @@ bool Module::semantic() { } solved_ids.insert(id); } + + solve_block = remove_unused_locals(solve_block->is_block()); + // Copy body into nrn_init. for (auto &stmt: solve_block->is_block()->statements()) { init_body.emplace_back(stmt->clone());