@@ -17,17 +17,18 @@ def test_state_interface(physics: Physics):
1717 sim = Sim (physics = physics , control = Control .state )
1818
1919 # Simple P controller for attitude to reach target height
20+ target_height = 0.5
2021 cmd = np .zeros ((1 , 1 , 13 ), dtype = np .float32 )
21- cmd [0 , 0 , 2 ] = 1.0 # Set z position target to 1.0
22+ cmd [0 , 0 , 2 ] = target_height
23+ steps = int (2 * sim .control_freq ) # Run simulation for 2 seconds
2224
23- for _ in range (int (2 * sim .control_freq )): # Run simulation for 2 seconds
25+ for i in range (steps ): # Run simulation for 2 seconds
26+ cmd [..., 2 ] = target_height * i / steps # Linearly interpolate target height
2427 sim .state_control (cmd )
2528 sim .step (sim .freq // sim .control_freq )
26- if np .linalg .norm (sim .data .states .pos [0 , 0 ] - np .array ([0.0 , 0.0 , 1.0 ])) < 0.1 :
27- break
2829
2930 # Check if drone reached target position
30- distance = np .linalg .norm (sim .data .states .pos [0 , 0 ] - np .array ([0.0 , 0.0 , 1.0 ]))
31+ distance = np .linalg .norm (sim .data .states .pos [0 , 0 ] - np .array ([0.0 , 0.0 , target_height ]))
3132 assert distance < 0.1 , f"Failed to reach target height with { physics } physics"
3233
3334
@@ -36,13 +37,15 @@ def test_state_interface(physics: Physics):
3637def test_attitude_interface (physics : Physics ):
3738 sim = Sim (physics = physics , control = Control .attitude )
3839 target_pos = np .array ([0.0 , 0.0 , 1.0 ])
39- jit_state2attitude = jax .jit (parametrize (state2attitude , drone_model = "cf2x_L250" ))
40+ jit_state2attitude = jax .jit (parametrize (state2attitude , drone_model = sim . drone_model ))
4041
4142 i_error = np .zeros ((1 , 1 , 3 ))
4243 cmd = np .zeros ((1 , 1 , 13 ))
43- cmd [0 , 0 , 2 ] = 1.0 # Set z position target to 1.0
44+ cmd [0 , 0 , 2 ] = 1.0
45+ steps = int (3 * sim .control_freq )
4446
45- for _ in range (int (2 * sim .control_freq )): # Run simulation for 2 seconds
47+ for i in range (steps ):
48+ cmd [..., :3 ] = target_pos * i / steps # Linearly interpolate target position
4649 pos , vel , quat = sim .data .states .pos , sim .data .states .vel , sim .data .states .quat
4750 rpyt , i_error = jit_state2attitude (pos , quat , vel , cmd , (i_error ,), ctrl_freq = 100 )
4851 sim .attitude_control (rpyt )
@@ -57,7 +60,7 @@ def test_attitude_interface(physics: Physics):
5760@pytest .mark .integration
5861def test_rotor_vel_interface ():
5962 sim = Sim (physics = Physics .first_principles , control = Control .rotor_vel )
60- params = load_params ("first_principles" , "cf2x_L250" )
63+ params = load_params ("first_principles" , sim . drone_model )
6164 max_rpm = motor_force2rotor_vel (np .array ([params ["thrust_max" ]]), params ["rpm2thrust" ])[0 ]
6265
6366 sim .data = sim .data .replace (
@@ -77,15 +80,19 @@ def test_rotor_vel_interface():
7780def test_swarm_control (physics : Physics ):
7881 n_worlds , n_drones = 2 , 3
7982 sim = Sim (n_worlds = n_worlds , n_drones = n_drones , physics = physics , control = Control .state )
83+ start_pos = np .asarray (sim .data .states .pos )
8084 target_pos = sim .data .states .pos + np .array ([0.3 , 0.3 , 0.3 ])
81-
8285 cmd = np .zeros ((n_worlds , n_drones , 13 ))
83- cmd [..., :3 ] = target_pos
84- sim .state_control (cmd )
85- sim .step (3 * sim .freq )
86- # Check if drone maintained hover position
86+ steps = int (3 * sim .control_freq )
87+
88+ for i in range (steps ):
89+ alpha = i / (steps )
90+ cmd [..., :3 ] = start_pos * (1 - alpha ) + target_pos * alpha
91+ sim .state_control (cmd )
92+ sim .step (sim .freq // sim .control_freq )
93+
8794 max_dist = np .max (np .linalg .norm (sim .data .states .pos - target_pos , axis = - 1 ))
88- assert max_dist < 0.05 , f"Failed to reach target, max dist: { max_dist } "
95+ assert max_dist < 0.08 , f"Failed to reach target, max dist: { max_dist } "
8996
9097
9198@pytest .mark .integration
0 commit comments