1
1
package us .ihmc .perception .cuda ;
2
2
3
+ import org .bytedeco .cuda .cudart .CUevent_st ;
3
4
import org .bytedeco .cuda .cudart .CUfunc_st ;
4
5
import org .bytedeco .cuda .cudart .CUmod_st ;
5
6
import org .bytedeco .cuda .cudart .CUstream_st ;
11
12
import org .bytedeco .javacpp .LongPointer ;
12
13
import org .bytedeco .javacpp .Pointer ;
13
14
import org .bytedeco .javacpp .PointerPointer ;
15
+ import us .ihmc .log .LogTools ;
14
16
15
17
import java .util .ArrayList ;
18
+ import java .util .LinkedList ;
16
19
import java .util .List ;
20
+ import java .util .Optional ;
17
21
18
- import static org .bytedeco .cuda .global .cudart .cuLaunchKernel ;
19
- import static org .bytedeco .cuda .global .cudart .cuModuleGetFunction ;
22
+ import static org .bytedeco .cuda .global .cudart .*;
20
23
import static us .ihmc .perception .cuda .CUDATools .throwCUDAError ;
21
24
22
25
@ SuppressWarnings ("resource" )
23
26
public class CUDAKernel implements AutoCloseable
24
27
{
28
+ private final String name ;
25
29
private final CUfunc_st kernelFunction = new CUfunc_st ();
26
30
private final List <Pointer > parameters = new ArrayList <>();
27
31
private boolean retainParameters = false ;
32
+ private boolean enableKernelTimings = false ;
33
+
34
+ private CUDAKernelTimings kernelTimings ;
35
+ private final CUevent_st start = new CUevent_st ();
36
+ private final CUevent_st end = new CUevent_st ();
28
37
29
38
private int error ;
30
39
31
40
public CUDAKernel (String name , CUmod_st kernelModule ) throws Exception
32
41
{
42
+ this .name = name ;
33
43
error = cuModuleGetFunction (kernelFunction , kernelModule , name );
34
44
throwCUDAError (error );
35
45
}
36
46
47
+ /**
48
+ * Setting this to true enables the ability to run timings on the specific kernel.
49
+ * The timing checks perform synchronization calls.
50
+ */
51
+ public void enableKernelTimings (boolean enableKernelTimings )
52
+ {
53
+ this .enableKernelTimings = enableKernelTimings ;
54
+ kernelTimings = new CUDAKernelTimings ();
55
+ }
56
+
37
57
public void retainParameters (boolean retainParameters )
38
58
{
39
59
this .retainParameters = retainParameters ;
@@ -53,6 +73,13 @@ public void run(CUstream_st stream, dim3 gridSize, dim3 blockSize, int sharedMem
53
73
for (int i = 0 ; i < parameters .size (); ++i )
54
74
parametersPointer .put (i , parameters .get (i ));
55
75
76
+ if (enableKernelTimings )
77
+ {
78
+ cudaEventCreate (start );
79
+ cudaEventCreate (end );
80
+ cudaEventRecord (start );
81
+ }
82
+
56
83
error = cuLaunchKernel (kernelFunction ,
57
84
gridSize .x (),
58
85
gridSize .y (),
@@ -64,6 +91,16 @@ public void run(CUstream_st stream, dim3 gridSize, dim3 blockSize, int sharedMem
64
91
stream ,
65
92
parametersPointer ,
66
93
new PointerPointer <>());
94
+
95
+ if (enableKernelTimings )
96
+ {
97
+ cudaEventRecord (end );
98
+ cudaEventSynchronize (end );
99
+
100
+ kernelTimings .addExecutionTime (start , end );
101
+ kernelTimings .printTimesForKernel ();
102
+ }
103
+
67
104
CUDATools .checkCUDAError (error );
68
105
69
106
if (!retainParameters )
@@ -123,4 +160,96 @@ public void close()
123
160
clearParameters ();
124
161
kernelFunction .close ();
125
162
}
163
+
164
+ /**
165
+ * This class handles the kernel timings.
166
+ * With options to compute the min/max, average, and variance of the dataset
167
+ */
168
+ private class CUDAKernelTimings
169
+ {
170
+ private static final int MAX_ENTRIES = 250 ;
171
+ private final LinkedList <Float > executionTimes = new LinkedList <>();
172
+
173
+ private void addExecutionTime (CUevent_st start , CUevent_st end )
174
+ {
175
+ float [] milliseconds = new float [1 ];
176
+ milliseconds [0 ] = 0.0f ;
177
+ cudaEventElapsedTime (milliseconds , start , end );
178
+ executionTimes .add (milliseconds [0 ]);
179
+
180
+ if (executionTimes .size () > MAX_ENTRIES )
181
+ {
182
+ executionTimes .pollFirst ();
183
+ }
184
+ }
185
+
186
+ public double getAverageTime (String kernelName )
187
+ {
188
+ if (executionTimes .isEmpty ())
189
+ {
190
+ LogTools .info ("No recorded times for " + kernelName );
191
+ return Float .NaN ;
192
+ }
193
+ else
194
+ {
195
+ return executionTimes .stream ().mapToDouble (Float ::doubleValue ).average ().orElse (0.0 );
196
+ }
197
+ }
198
+
199
+ public Float getMinTime (String kernelName )
200
+ {
201
+ if (executionTimes .isEmpty ())
202
+ {
203
+ LogTools .info ("No recorded times for " + kernelName );
204
+ return Float .NaN ;
205
+ }
206
+ Optional <Float > min = executionTimes .stream ().min (Float ::compareTo );
207
+ return min .orElse (null );
208
+ }
209
+
210
+ public Float getMaxTime (String kernelName )
211
+ {
212
+ if (executionTimes .isEmpty ())
213
+ {
214
+ LogTools .info ("No recorded times for " + kernelName );
215
+ return Float .NaN ;
216
+ }
217
+
218
+ Optional <Float > max = executionTimes .stream ().max (Float ::compareTo );
219
+ return max .orElse (null );
220
+ }
221
+
222
+ public double getStandardDeviation (String kernelName )
223
+ {
224
+ if (executionTimes .isEmpty ())
225
+ {
226
+ LogTools .info ("No recorded times for " + kernelName );
227
+ return Float .NaN ;
228
+ }
229
+
230
+ double average = executionTimes .stream ().mapToDouble (Float ::doubleValue ).average ().orElse (0.0 );
231
+ double variance = executionTimes .stream ().mapToDouble (time -> Math .pow (time - average , 2 )).average ().orElse (0.0 );
232
+ return Math .sqrt (variance );
233
+ }
234
+
235
+ public void printTimesForKernel ()
236
+ {
237
+ if (executionTimes .isEmpty ())
238
+ {
239
+ LogTools .info ("No recorded times for " + CUDAKernel .this .name );
240
+ }
241
+
242
+ double average = getAverageTime (CUDAKernel .this .name );
243
+ double variance = getStandardDeviation (CUDAKernel .this .name );
244
+ double min = getMinTime (CUDAKernel .this .name );
245
+ double max = getMaxTime (CUDAKernel .this .name );
246
+
247
+ LogTools .info ("Timings for kernel " + CUDAKernel .this .name + " in milliseconds!" );
248
+ LogTools .info ("| Average time: " + average );
249
+ LogTools .info ("| Variance time: " + variance );
250
+ LogTools .info ("| Min time: " + min );
251
+ LogTools .info ("| Max time: " + max );
252
+ LogTools .warn ("******************************************" );
253
+ }
254
+ }
126
255
}
0 commit comments