diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index fce90c519..3b6abb835 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -580,6 +580,7 @@ def __init__( if self.rail_length <= 0: raise ValueError("Rail length must be a positive value.") self.parachutes = self.rocket.parachutes[:] + self.controllers = self.rocket.controllers[:] self.inclination = inclination self.heading = heading self.max_time = max_time @@ -654,6 +655,9 @@ def __init__( phase.TimeNodes.add_parachutes( self.parachutes, phase.t, phase.time_bound ) + phase.TimeNodes.add_controllers( + self.controllers, phase.t, phase.time_bound + ) # Add lst time node to permanent list phase.TimeNodes.add_node(phase.time_bound, [], []) # Sort time nodes @@ -685,6 +689,9 @@ def __init__( for callback in node.callbacks: callback(self) + for controller in node.controllers: + controller(self.t, self.y_sol) + for parachute in node.parachutes: # Calculate and save pressure signal pressure = self.env.pressure.get_value_opt(self.y_sol[2]) @@ -1698,6 +1705,8 @@ def u_dot_generalized(self, t, u, post_processing=False): drag_coeff = self.rocket.power_on_drag.get_value_opt(free_stream_mach) else: drag_coeff = self.rocket.power_off_drag.get_value_opt(free_stream_mach) + for airbrakes in self.rocket.airbrakes: + drag_coeff += airbrakes.cd_s R3 += -0.5 * rho * (free_stream_speed**2) * self.rocket.area * (drag_coeff) ## Off center moment @@ -3472,6 +3481,20 @@ def add_parachutes(self, parachutes, t_init, t_end): ] self.list += parachute_node_list + def add_controllers(self, controllers, t_init, t_end): + # Iterate over controllers + for controller in controllers: + # Calculate start of sampling time nodes + controller_time_step = 1 / controller.sampling_rate + controller_node_list = [ + self.TimeNode(i * controller_time_step, [], [controller]) + for i in range( + math.ceil(t_init / controller_time_step), + math.floor(t_end / controller_time_step) + 1, + ) + ] + self.list += controller_node_list + def sort(self): self.list.sort(key=(lambda node: node.t)) @@ -3495,10 +3518,11 @@ def flush_after(self, index): del self.list[index + 1 :] class TimeNode: - def __init__(self, t, parachutes, callbacks): + def __init__(self, t, parachutes, controllers): self.t = t self.parachutes = parachutes - self.callbacks = callbacks + self.callbacks = [] + self.controllers = controllers def __repr__(self): return (