Skip to content

Commit 4bd5b42

Browse files
author
C.A.P. Linssen
committed
transform kernels and convolutions using a transformer before code generation
1 parent 410e59a commit 4bd5b42

File tree

13 files changed

+768
-914
lines changed

13 files changed

+768
-914
lines changed

pynestml/codegeneration/code_generator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def _setup_template_env(self, template_files: List[str], templates_root_dir: str
120120
# Environment for neuron templates
121121
env = Environment(loader=FileSystemLoader(_template_dirs))
122122
env.globals["raise"] = self.raise_helper
123-
env.globals["is_delta_kernel"] = ASTUtils.is_delta_kernel
124123

125124
# Load all the templates
126125
_templates = list()

pynestml/codegeneration/nest_code_generator.py

Lines changed: 13 additions & 148 deletions
Large diffs are not rendered by default.

pynestml/codegeneration/nest_compartmental_code_generator.py

Lines changed: 19 additions & 219 deletions
Large diffs are not rendered by default.

pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,8 @@ std::vector< std::tuple< int, int > > {{ neuronName }}::rport_to_nestml_buffer_i
261261

262262
// copy state struct S_
263263
{%- for init in neuron.get_state_symbols() %}
264-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(init.name)) %}
265264
{%- set node = utils.get_state_variable_by_name(astnode, init.get_symbol_name()) %}
266265
{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }} = __n.{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }};
267-
{%- endif %}
268266
{%- endfor %}
269267

270268
// copy internals V_
@@ -723,28 +721,6 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
723721
update_delay_variables();
724722
{%- endif %}
725723

726-
/**
727-
* subthreshold updates of the convolution variables
728-
*
729-
* step 1: regardless of whether and how integrate_odes() will be called, update variables due to convolutions
730-
**/
731-
732-
{%- if uses_analytic_solver %}
733-
{%- for variable_name in analytic_state_variables: %}
734-
{%- if "__X__" in variable_name %}
735-
{%- set update_expr = update_expressions[variable_name] %}
736-
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
737-
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
738-
{%- if use_gap_junctions %}
739-
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) | replace("B_." + gap_junction_port + "_grid_sum_", "(B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }};
740-
{%- else %}
741-
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) }};
742-
{%- endif %}
743-
{%- endif %}
744-
{%- endfor %}
745-
{%- endif %}
746-
747-
748724
/**
749725
* Begin NESTML generated code for the update block(s)
750726
**/
@@ -774,30 +750,6 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
774750
}
775751
{%- endfor %}
776752

777-
/**
778-
* subthreshold updates of the convolution variables
779-
*
780-
* step 2: regardless of whether and how integrate_odes() was called, update variables due to convolutions. Set to the updated values at the end of the timestep.
781-
**/
782-
{% if uses_analytic_solver %}
783-
{%- for variable_name in analytic_state_variables: %}
784-
{%- if "__X__" in variable_name %}
785-
{%- set update_expr = update_expressions[variable_name] %}
786-
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
787-
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
788-
{{ printer.print(var_ast) }} = {{variable_name}}__tmp_;
789-
{%- endif %}
790-
{%- endfor %}
791-
{%- endif %}
792-
793-
794-
/**
795-
* spike updates due to convolutions
796-
**/
797-
{% filter indent(4) %}
798-
{%- include "directives_cpp/ApplySpikesFromBuffers.jinja2" %}
799-
{%- endfilter %}
800-
801753
/**
802754
* Begin NESTML generated code for the onCondition block(s)
803755
**/

pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,8 @@ public:
348348

349349
{% filter indent(2, True) -%}
350350
{%- for variable_symbol in neuron.get_state_symbols() %}
351-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
352-
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
353-
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
354-
{% endif %}
351+
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
352+
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
355353
{% endfor %}
356354
{%- endfilter %}
357355
{%- endif %}
@@ -970,14 +968,12 @@ inline void {{neuronName}}::get_status(DictionaryDatum &__d) const
970968
{%- endfilter %}
971969
{%- endfor %}
972970

973-
// initial values for state variables in ODE or kernel
971+
// initial values for state variables in ODEs
974972
{%- for variable_symbol in neuron.get_state_symbols() %}
975973
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
976-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
977974
{%- filter indent(2) %}
978975
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
979976
{%- endfilter %}
980-
{%- endif -%}
981977
{%- endfor %}
982978

983979
{{neuron_parent_class}}::get_status( __d );
@@ -1021,14 +1017,12 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
10211017
{%- endfilter %}
10221018
{%- endfor %}
10231019

1024-
// initial values for state variables in ODE or kernel
1020+
// initial values for state variables in ODEs
10251021
{%- for variable_symbol in neuron.get_state_symbols() %}
10261022
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
1027-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
1028-
{%- filter indent(2) %}
1029-
{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
1030-
{%- endfilter %}
1031-
{%- endif %}
1023+
{%- filter indent(2) %}
1024+
{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
1025+
{%- endfilter %}
10321026
{%- endfor %}
10331027

10341028
// We now know that (ptmp, stmp) are consistent. We do not
@@ -1047,11 +1041,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
10471041

10481042
{%- for variable_symbol in neuron.get_state_symbols() -%}
10491043
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
1050-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
1051-
{%- filter indent(2) %}
1052-
{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
1053-
{%- endfilter %}
1054-
{%- endif %}
1044+
{%- filter indent(2) %}
1045+
{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
1046+
{%- endfilter %}
10551047
{%- endfor %}
10561048

10571049
{% for invariant in neuron.get_parameter_invariants() %}

pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -822,17 +822,6 @@ public:
822822
{%- endfilter %}
823823
}
824824

825-
/**
826-
* update all convolutions with pre spikes
827-
**/
828-
829-
{%- for spike_updates_for_port in spike_updates.values() %}
830-
{%- for spike_update in spike_updates_for_port %}
831-
{{ printer.print(spike_update.get_variable()) }} += 1.; // XXX: TODO: increment with initial value instead of 1
832-
{%- endfor %}
833-
{%- endfor %}
834-
835-
836825
/**
837826
* in case pre and post spike time coincide and pre update takes priority
838827
**/
@@ -1025,7 +1014,7 @@ void
10251014
{%- for variable_symbol in synapse.get_state_symbols() + synapse.get_parameter_symbols() %}
10261015
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
10271016
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
1028-
{%- if not isHomogeneous and not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) and not variable_symbol.is_inline_expression %}
1017+
{%- if not isHomogeneous and not variable_symbol.is_inline_expression %}
10291018
{%- if variable.get_name() == nest_codegen_opt_delay_variable %}
10301019
// special treatment of NEST delay
10311020
double tmp_{{ nest_codegen_opt_delay_variable }} = get_delay();
@@ -1061,7 +1050,7 @@ if (__d->known(nest::names::weight))
10611050
{%- for variable_symbol in synapse.get_state_symbols() + synapse.get_parameter_symbols() %}
10621051
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
10631052
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
1064-
{%- if not isHomogeneous and not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) %}
1053+
{%- if not isHomogeneous %}
10651054
{%- if variable.get_name() == nest_codegen_opt_delay_variable %}
10661055
// special treatment of NEST delay
10671056
set_delay(tmp_{{ nest_codegen_opt_delay_variable }});

pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2

Lines changed: 0 additions & 6 deletions
This file was deleted.

pynestml/codegeneration/resources_python_standalone/point_neuron/@[email protected]

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class Neuron_{{neuronName}}(Neuron):
191191
{%- endif %}
192192
{%- endfor %}
193193
{%- endfilter %}
194+
pass
194195
else:
195196
# internals V_
196197
{%- filter indent(6) %}
@@ -220,10 +221,8 @@ class Neuron_{{neuronName}}(Neuron):
220221
# -------------------------------------------------------------------------
221222
{% filter indent(2, True) -%}
222223
{%- for variable_symbol in neuron.get_state_symbols() %}
223-
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.get_symbol_name())) %}
224224
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
225225
{%- include "directives_py/MemberVariableGetterSetter.jinja2" %}
226-
{%- endif %}
227226
{%- endfor %}
228227
{%- endfilter %}
229228

@@ -264,13 +263,6 @@ class Neuron_{{neuronName}}(Neuron):
264263
{%- set analytic_state_variables_ = utils.filter_variables_list(analytic_state_variables_, ast.get_args()) %}
265264
{%- endif %}
266265

267-
{#- always integrate convolutions in time #}
268-
{%- for var in analytic_state_variables %}
269-
{%- if "__X__" in var %}
270-
{%- set tmp = analytic_state_variables_.append(var) %}
271-
{%- endif %}
272-
{%- endfor %}
273-
274266
{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %}
275267

276268
{%- if uses_numeric_solver %}
@@ -285,14 +277,6 @@ class Neuron_{{neuronName}}(Neuron):
285277
def step(self, origin: float, timestep: float) -> None:
286278
__resolution: float = timestep # do not remove, this is necessary for the resolution() function
287279

288-
# -------------------------------------------------------------------------
289-
# integrate variables related to convolutions
290-
# -------------------------------------------------------------------------
291-
292-
{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %}
293-
{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %}
294-
{%- endwith %}
295-
296280
# -------------------------------------------------------------------------
297281
# NESTML generated code for the update block
298282
# -------------------------------------------------------------------------
@@ -306,21 +290,6 @@ class Neuron_{{neuronName}}(Neuron):
306290
{%- endfilter %}
307291
{%- endif %}
308292

309-
# -------------------------------------------------------------------------
310-
# integrate variables related to convolutions
311-
# -------------------------------------------------------------------------
312-
313-
{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %}
314-
{%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %}
315-
{%- endwith %}
316-
317-
# -------------------------------------------------------------------------
318-
# process spikes from buffers
319-
# -------------------------------------------------------------------------
320-
{%- filter indent(4, True) -%}
321-
{%- include "directives_py/ApplySpikesFromBuffers.jinja2" %}
322-
{%- endfilter %}
323-
324293
# -------------------------------------------------------------------------
325294
# begin NESTML generated code for the onReceive block(s)
326295
# -------------------------------------------------------------------------

pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2

Lines changed: 0 additions & 6 deletions
This file was deleted.

pynestml/frontend/pynestml_frontend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from pynestml.symbols.predefined_types import PredefinedTypes
3838
from pynestml.symbols.predefined_units import PredefinedUnits
3939
from pynestml.symbols.predefined_variables import PredefinedVariables
40+
from pynestml.transformers.convolutions_transformer import ConvolutionsTransformer
4041
from pynestml.transformers.transformer import Transformer
4142
from pynestml.utils.logger import Logger, LoggingLevel
4243
from pynestml.utils.messages import Messages
@@ -59,6 +60,9 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st
5960
if options is None:
6061
options = {}
6162

63+
# for all targets, add the convolutions transformer
64+
transformers.append(ConvolutionsTransformer())
65+
6266
if target_name.upper() in ["NEST", "SPINNAKER"]:
6367
from pynestml.transformers.illegal_variable_name_transformer import IllegalVariableNameTransformer
6468

0 commit comments

Comments
 (0)