Skip to content

Commit abba1b4

Browse files
author
C.A.P. Linssen
committed
transform kernels and convolutions using a transformer before code generation
1 parent 026e34c commit abba1b4

File tree

19 files changed

+622
-847
lines changed

19 files changed

+622
-847
lines changed

pynestml/codegeneration/code_generator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def _setup_template_env(self, template_files: List[str], templates_root_dir: str
117117
# Environment for neuron templates
118118
env = Environment(loader=FileSystemLoader(_template_dirs))
119119
env.globals["raise"] = self.raise_helper
120-
env.globals["is_delta_kernel"] = ASTUtils.is_delta_kernel
121120

122121
# Load all the templates
123122
_templates = list()

pynestml/codegeneration/nest_code_generator.py

Lines changed: 14 additions & 148 deletions
Large diffs are not rendered by default.

pynestml/codegeneration/nest_compartmental_code_generator.py

Lines changed: 13 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -280,22 +280,16 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None:
280280

281281
def create_ode_indict(self,
282282
neuron: ASTModel,
283-
parameters_block: ASTBlockWithVariables,
284-
kernel_buffers: Mapping[ASTKernel,
285-
ASTInputPort]):
286-
odetoolbox_indict = self.transform_ode_and_kernels_to_json(
287-
neuron, parameters_block, kernel_buffers)
283+
parameters_block: ASTBlockWithVariables):
284+
odetoolbox_indict = self.transform_ode_and_kernels_to_json(neuron, parameters_block)
288285
odetoolbox_indict["options"] = {}
289286
odetoolbox_indict["options"]["output_timestep_symbol"] = "__h"
290287
return odetoolbox_indict
291288

292289
def ode_solve_analytically(self,
293290
neuron: ASTModel,
294-
parameters_block: ASTBlockWithVariables,
295-
kernel_buffers: Mapping[ASTKernel,
296-
ASTInputPort]):
297-
odetoolbox_indict = self.create_ode_indict(
298-
neuron, parameters_block, kernel_buffers)
291+
parameters_block: ASTBlockWithVariables):
292+
odetoolbox_indict = self.create_ode_indict(neuron, parameters_block)
299293

300294
full_solver_result = analysis(
301295
odetoolbox_indict,
@@ -314,8 +308,7 @@ def ode_solve_analytically(self,
314308

315309
return full_solver_result, analytic_solver
316310

317-
def ode_toolbox_analysis(self, neuron: ASTModel,
318-
kernel_buffers: Mapping[ASTKernel, ASTInputPort]):
311+
def ode_toolbox_analysis(self, neuron: ASTModel):
319312
"""
320313
Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output.
321314
"""
@@ -324,15 +317,13 @@ def ode_toolbox_analysis(self, neuron: ASTModel,
324317

325318
equations_block = neuron.get_equations_blocks()[0]
326319

327-
if len(equations_block.get_kernels()) == 0 and len(
328-
equations_block.get_ode_equations()) == 0:
320+
if len(equations_block.get_ode_equations()) == 0:
329321
# no equations defined -> no changes to the neuron
330322
return None, None
331323

332324
parameters_block = neuron.get_parameters_blocks()[0]
333325

334-
solver_result, analytic_solver = self.ode_solve_analytically(
335-
neuron, parameters_block, kernel_buffers)
326+
solver_result, analytic_solver = self.ode_solve_analytically(neuron, parameters_block)
336327

337328
# if numeric solver is required, generate a stepping function that
338329
# includes each state variable
@@ -341,8 +332,7 @@ def ode_toolbox_analysis(self, neuron: ASTModel,
341332
x for x in solver_result if x["solver"].startswith("numeric")]
342333

343334
if numeric_solvers:
344-
odetoolbox_indict = self.create_ode_indict(
345-
neuron, parameters_block, kernel_buffers)
335+
odetoolbox_indict = self.create_ode_indict(neuron, parameters_block)
346336
solver_result = analysis(
347337
odetoolbox_indict,
348338
disable_stiffness_check=True,
@@ -417,24 +407,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
417407

418408
return []
419409

420-
# goes through all convolve() inside ode's from equations block
421-
# if they have delta kernels, use sympy to expand the expression, then
422-
# find the convolve calls and replace them with constant value 1
423-
# then return every subexpression that had that convolve() replaced
424-
delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block)
425-
426-
# goes through all convolve() inside equations block
427-
# extracts what kernel is paired with what spike buffer
428-
# returns pairs (kernel, spike_buffer)
429-
kernel_buffers = ASTUtils.generate_kernel_buffers(
430-
neuron, equations_block)
431-
432-
# replace convolve(g_E, spikes_exc) with g_E__X__spikes_exc[__d]
433-
# done by searching for every ASTSimpleExpression inside equations_block
434-
# which is a convolve call and substituting that call with
435-
# newly created ASTVariable kernel__X__spike_buffer
436-
ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)
437-
438410
# substitute inline expressions with each other
439411
# such that no inline expression references another inline expression
440412
ASTUtils.make_inline_expressions_self_contained(
@@ -450,14 +422,13 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
450422
# "update_expressions" key in those solvers contains a mapping
451423
# {expression1: update_expression1, expression2: update_expression2}
452424

453-
analytic_solver, numeric_solver = self.ode_toolbox_analysis(
454-
neuron, kernel_buffers)
425+
analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron)
455426

456427
"""
457428
# separate analytic solutions by kernel
458429
# this is is needed for the synaptic case
459430
self.kernel_name_to_analytic_solver[neuron.get_name(
460-
)] = self.ode_toolbox_anaysis_cm_syns(neuron, kernel_buffers)
431+
)] = self.ode_toolbox_anaysis_cm_syns(neuron)
461432
"""
462433

463434
self.analytic_solver[neuron.get_name()] = analytic_solver
@@ -472,12 +443,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
472443
# by odetoolbox, higher order variables don't get deleted here
473444
ASTUtils.remove_initial_values_for_kernels(neuron)
474445

475-
# delete all kernels as they are all converted into buffers
476-
# and corresponding update formulas calculated by odetoolbox
477-
# Remember them in a variable though
478-
kernels = ASTUtils.remove_kernel_definitions_from_equations_block(
479-
neuron)
480-
481446
# Every ODE variable (a variable of order > 0) is renamed according to ODE-toolbox conventions
482447
# their initial values are replaced by expressions suggested by ODE-toolbox.
483448
# Differential order can now be set to 0, becase they can directly represent the value of the derivative now.
@@ -491,22 +456,11 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
491456
# corresponding updates
492457
ASTUtils.remove_ode_definitions_from_equations_block(neuron)
493458

494-
# restore state variables that were referenced by kernels
495-
# and set their initial values by those suggested by ODE-toolbox
496-
ASTUtils.create_initial_values_for_kernels(
497-
neuron, [analytic_solver, numeric_solver], kernels)
498-
499459
# Inside all remaining expressions, translate all remaining variable names
500460
# according to the naming conventions of ODE-toolbox.
501461
ASTUtils.replace_variable_names_in_expressions(
502462
neuron, [analytic_solver, numeric_solver])
503463

504-
# find all inline kernels defined as ASTSimpleExpression
505-
# that have a single kernel convolution aliasing variable ('__X__')
506-
# translate all remaining variable names according to the naming
507-
# conventions of ODE-toolbox
508-
ASTUtils.replace_convolution_aliasing_inlines(neuron)
509-
510464
# add variable __h to internals block
511465
ASTUtils.add_timestep_symbol(neuron)
512466

@@ -677,13 +631,9 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
677631
expr_ast.accept(ASTSymbolTableVisitor())
678632
namespace["numeric_update_expressions"][sym] = expr_ast
679633

680-
namespace["spike_updates"] = neuron.spike_updates
681-
682634
namespace["recordable_state_variables"] = [
683635
sym for sym in neuron.get_state_symbols() if namespace["declarations"].get_domain_from_type(
684-
sym.get_type_symbol()) == "double" and sym.is_recordable and not ASTUtils.is_delta_kernel(
685-
neuron.get_kernel_by_name(
686-
sym.name))]
636+
sym.get_type_symbol()) == "double" and sym.is_recordable]
687637
namespace["recordable_inline_expressions"] = [
688638
sym for sym in neuron.get_inline_expression_symbols() if namespace["declarations"].get_domain_from_type(
689639
sym.get_type_symbol()) == "double" and sym.is_recordable]
@@ -807,7 +757,7 @@ def get_spike_update_expressions(
807757
for var_order in range(
808758
ASTUtils.get_kernel_var_order_from_ode_toolbox_result(
809759
kernel_var.get_name(), solver_dicts)):
810-
kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(
760+
kernel_spike_buf_name = ASTUtils.construct_kernel_spike_buf_name(
811761
kernel_var.get_name(), spike_input_port, var_order)
812762
expr = ASTUtils.get_initial_value_from_ode_toolbox_result(
813763
kernel_spike_buf_name, solver_dicts)
@@ -849,18 +799,9 @@ def get_spike_update_expressions(
849799
def transform_ode_and_kernels_to_json(
850800
self,
851801
neuron: ASTModel,
852-
parameters_block,
853-
kernel_buffers):
802+
parameters_block):
854803
"""
855804
Converts AST node to a JSON representation suitable for passing to ode-toolbox.
856-
857-
Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements
858-
859-
convolve(G, ex_spikes)
860-
convolve(G, in_spikes)
861-
862-
then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`.
863-
864805
:param parameters_block: ASTBlockWithVariables
865806
:return: Dict
866807
"""
@@ -890,43 +831,6 @@ def transform_ode_and_kernels_to_json(
890831
iv_symbol_name)] = expr
891832
odetoolbox_indict["dynamics"].append(entry)
892833

893-
# write a copy for each (kernel, spike buffer) combination
894-
for kernel, spike_input_port in kernel_buffers:
895-
896-
if ASTUtils.is_delta_kernel(kernel):
897-
# delta function -- skip passing this to ode-toolbox
898-
continue
899-
900-
for kernel_var in kernel.get_variables():
901-
expr = ASTUtils.get_expr_from_kernel_var(
902-
kernel, kernel_var.get_complete_name())
903-
kernel_order = kernel_var.get_differential_order()
904-
kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name(
905-
kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'")
906-
907-
ASTUtils.replace_rhs_variables(expr, kernel_buffers)
908-
909-
entry = {}
910-
entry["expression"] = kernel_X_spike_buf_name_ticks + " = " + str(expr)
911-
912-
# initial values need to be declared for order 1 up to kernel
913-
# order (e.g. none for kernel function f(t) = ...; 1 for kernel
914-
# ODE f'(t) = ...; 2 for f''(t) = ... and so on)
915-
entry["initial_values"] = {}
916-
for order in range(kernel_order):
917-
iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name(
918-
kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
919-
symbol_name_ = kernel_var.get_name() + "'" * order
920-
symbol = equations_block.get_scope().resolve_to_symbol(
921-
symbol_name_, SymbolKind.VARIABLE)
922-
assert symbol is not None, "Could not find initial value for variable " + symbol_name_
923-
initial_value_expr = symbol.get_declaring_expression()
924-
assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_
925-
entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print(
926-
initial_value_expr)
927-
928-
odetoolbox_indict["dynamics"].append(entry)
929-
930834
odetoolbox_indict["parameters"] = {}
931835
if parameters_block is not None:
932836
for decl in parameters_block.get_declarations():

pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,8 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
262262

263263
// copy state struct S_
264264
{%- for init in neuron.get_state_symbols() %}
265-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(init.name)) %}
266265
{%- set node = utils.get_state_variable_by_name(astnode, init.get_symbol_name()) %}
267266
{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }} = __n.{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }};
268-
{%- endif %}
269267
{%- endfor %}
270268

271269
// copy internals V_
@@ -786,14 +784,6 @@ const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}_
786784
{%- endfor %}
787785
{%- endif %}
788786

789-
790-
/**
791-
* spike updates due to convolutions
792-
**/
793-
{% filter indent(4) %}
794-
{%- include "directives_cpp/ApplySpikesFromBuffers.jinja2" %}
795-
{%- endfilter %}
796-
797787
/**
798788
* Begin NESTML generated code for the onCondition block(s)
799789
**/
@@ -1149,13 +1139,9 @@ void
11491139
{%- endfor %}
11501140

11511141
/**
1152-
* print updates due to convolutions
1142+
* push back spike history
11531143
**/
11541144

1155-
{%- for _, spike_update in post_spike_updates.items() %}
1156-
{{ printer.print(utils.get_variable_by_name(astnode, spike_update.get_variable().get_complete_name())) }} += 1.;
1157-
{%- endfor %}
1158-
11591145
last_spike_ = t_sp_ms;
11601146
history_.push_back( histentry__{{neuronName}}( last_spike_
11611147
{%- for var in purely_numeric_state_variables_moved|sort %}

pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,11 @@ public:
346346
// Getters/setters for state block
347347
// -------------------------------------------------------------------------
348348

349-
{% filter indent(2, True) -%}
349+
{% filter indent(2, True) -%}
350350
{%- for variable_symbol in neuron.get_state_symbols() %}
351-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
352-
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
353-
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
354-
{% endif %}
355-
{% endfor %}
351+
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
352+
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
353+
{% endfor %}
356354
{%- endfilter %}
357355
{%- endif %}
358356

@@ -962,22 +960,20 @@ inline nest_port_t {{neuronName}}::handles_test_event(nest::DataLoggingRequest&
962960
inline void {{neuronName}}::get_status(DictionaryDatum &__d) const
963961
{
964962
// parameters
965-
{%- for variable_symbol in neuron.get_parameter_symbols() %}
966-
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
967-
{%- filter indent(2) %}
963+
{%- filter indent(2) %}
964+
{%- for variable_symbol in neuron.get_parameter_symbols() %}
965+
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
968966
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
969-
{%- endfilter %}
970-
{%- endfor %}
967+
{%- endfor %}
968+
{%- endfilter %}
971969

972970
// initial values for state variables in ODE or kernel
973-
{%- for variable_symbol in neuron.get_state_symbols() %}
974-
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
975-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
976-
{%- filter indent(2) %}
977-
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
978-
{%- endfilter %}
979-
{%- endif -%}
980-
{%- endfor %}
971+
{%- filter indent(2) %}
972+
{%- for variable_symbol in neuron.get_state_symbols() %}
973+
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
974+
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
975+
{%- endfor %}
976+
{%- endfilter %}
981977

982978
{{neuron_parent_class}}::get_status( __d );
983979

@@ -1023,11 +1019,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
10231019
// initial values for state variables in ODE or kernel
10241020
{%- for variable_symbol in neuron.get_state_symbols() %}
10251021
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
1026-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
1027-
{%- filter indent(2) %}
1028-
{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
1029-
{%- endfilter %}
1030-
{%- endif %}
1022+
{%- filter indent(2) %}
1023+
{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
1024+
{%- endfilter %}
10311025
{%- endfor %}
10321026

10331027
// We now know that (ptmp, stmp) are consistent. We do not
@@ -1046,11 +1040,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
10461040

10471041
{%- for variable_symbol in neuron.get_state_symbols() -%}
10481042
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
1049-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
1050-
{%- filter indent(2) %}
1051-
{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
1052-
{%- endfilter %}
1053-
{%- endif %}
1043+
{%- filter indent(2) %}
1044+
{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
1045+
{%- endfilter %}
10541046
{%- endfor %}
10551047

10561048
{% for invariant in neuron.get_parameter_invariants() %}

0 commit comments

Comments
 (0)