15
15
from . import bee_simulator
16
16
17
17
class StoneResults (ExperimentResults ):
18
- def __init__ (self , name : str , parameters : dict , headings , velocities , log , cpu4_snapshot ) -> None :
18
+ def __init__ (self , name : str , parameters : dict , headings , velocities , log , cpu4_snapshot , recordings : dict ) -> None :
19
19
super ().__init__ (name , parameters )
20
20
self .headings = headings
21
21
self .velocities = velocities
22
22
self .log = log
23
23
self .cpu4_snapshot = cpu4_snapshot
24
+ self .recordings = recordings
24
25
25
26
def report (self ):
26
27
logger .info ("plotting route" )
@@ -40,7 +41,7 @@ def serialize(self):
40
41
return {
41
42
"headings" : self .headings .tolist (),
42
43
"velocities" : self .headings .tolist (),
43
-
44
+ "recordings" : { layer : [ entry . tolist () for entry in entries ] for layer , entries in self . recordings . items ()}
44
45
# annoying to serialize:
45
46
#"log": self.log,
46
47
#"cpu4_snapshot": self.cpu4_snapshot,
@@ -66,20 +67,20 @@ def __init__(self, parameters: dict) -> None:
66
67
noise = self .parameters ["noise" ]
67
68
cx_type = self .parameters ["cx" ]
68
69
69
- phi = self .parameters ["phi" ]
70
- beta = self .parameters ["beta" ]
71
- T_half = self .parameters ["T_half" ]
72
- epsilon = self .parameters ["epsilon" ]
73
- length = self .parameters ["length" ]
74
- c_tot = self .parameters ["c_tot" ]
75
-
76
70
if cx_type == "basic" :
77
71
cx = basic .CXBasic ()
78
72
elif cx_type == "rate" :
79
73
cx = rate .CXRate (noise = noise )
80
74
elif cx_type == "pontine" :
81
75
cx = rate .CXRatePontine (noise = noise )
82
76
elif cx_type == "dye" :
77
+ phi = self .parameters ["phi" ]
78
+ beta = self .parameters ["beta" ]
79
+ T_half = self .parameters ["T_half" ]
80
+ epsilon = self .parameters ["epsilon" ]
81
+ length = self .parameters ["length" ]
82
+ c_tot = self .parameters ["c_tot" ]
83
+
83
84
cx = dye .CXDye (
84
85
noise = noise ,
85
86
phi = phi ,
@@ -103,6 +104,9 @@ def run(self, name: str) -> ExperimentResults:
103
104
cx_type = self .parameters ["cx" ]
104
105
time_subdivision = self .parameters ["time_subdivision" ] if "time_subdivision" in self .parameters else 1
105
106
107
+ layers_to_record = self .parameters ["record" ] if "record" in self .parameters else []
108
+ recordings = {layer : [] for layer in layers_to_record }
109
+
106
110
logger .info ("initializing central complex" )
107
111
108
112
headings = np .zeros (T_outbound + T_inbound )
@@ -120,6 +124,8 @@ def run(self, name: str) -> ExperimentResults:
120
124
for heading , velocity in zip (headings [0 :T_outbound ], velocities [0 :T_outbound , :]):
121
125
for ts in range (time_subdivision ):
122
126
self .cx .update (dt , heading , velocity )
127
+ for layer in layers_to_record :
128
+ recordings [layer ].append (self .cx .network .output (layer ))
123
129
124
130
for t in range (T_outbound , T_outbound + T_inbound ):
125
131
heading = headings [t - 1 ]
@@ -139,6 +145,8 @@ def run(self, name: str) -> ExperimentResults:
139
145
)
140
146
141
147
headings [t ], velocities [t ,:] = heading , velocity
148
+ for layer in layers_to_record :
149
+ recordings [layer ].append (self .cx .network .output (layer ))
142
150
143
- return StoneResults (name , self .parameters , headings , velocities , log = None , cpu4_snapshot = None )
151
+ return StoneResults (name , self .parameters , headings , velocities , log = None , cpu4_snapshot = None , recordings = recordings )
144
152
0 commit comments