2323import io .trino .spi .QueryId ;
2424import io .trino .spi .memory .MemoryAllocation ;
2525import io .trino .spi .memory .MemoryPoolInfo ;
26+ import it .unimi .dsi .fastutil .objects .Object2LongMap ;
27+ import it .unimi .dsi .fastutil .objects .Object2LongMaps ;
28+ import it .unimi .dsi .fastutil .objects .Object2LongOpenHashMap ;
2629import jakarta .annotation .Nullable ;
2730import org .weakref .jmx .Managed ;
2831
29- import java .util .ArrayList ;
3032import java .util .HashMap ;
3133import java .util .List ;
3234import java .util .Map ;
3335import java .util .Map .Entry ;
34- import java .util .concurrent .ConcurrentHashMap ;
3536import java .util .concurrent .CopyOnWriteArrayList ;
3637
3738import static com .google .common .base .MoreObjects .toStringHelper ;
3839import static com .google .common .base .Preconditions .checkArgument ;
3940import static com .google .common .base .Preconditions .checkState ;
41+ import static com .google .common .collect .ImmutableList .toImmutableList ;
4042import static com .google .common .collect .ImmutableMap .toImmutableMap ;
4143import static io .trino .operator .Operator .NOT_BLOCKED ;
4244import static java .util .Objects .requireNonNull ;
@@ -56,17 +58,17 @@ public class MemoryPool
5658
5759 // TODO: It would be better if we just tracked QueryContexts, but their lifecycle is managed by a weak reference, so we can't do that
5860 // It is guarded for updates by this, but can be read without holding a lock
59- private final Map <QueryId , Long > queryMemoryReservations = new ConcurrentHashMap <>();
61+ private final Object2LongMap <QueryId > queryMemoryReservations = new Object2LongOpenHashMap <>();
6062
6163 // This map keeps track of all the tagged allocations, e.g., query-1 -> ['TableScanOperator': 10MB, 'LazyOutputBuffer': 5MB, ...]
6264 @ GuardedBy ("this" )
63- private final Map <QueryId , Map <String , Long >> taggedMemoryAllocations = new HashMap <>();
65+ private final Map <QueryId , Object2LongMap <String >> taggedMemoryAllocations = new HashMap <>();
6466
6567 @ GuardedBy ("this" )
66- private final Map <TaskId , Long > taskMemoryReservations = new HashMap <>();
68+ private final Object2LongMap <TaskId > taskMemoryReservations = new Object2LongOpenHashMap <>();
6769
6870 @ GuardedBy ("this" )
69- private final Map <TaskId , Long > taskRevocableMemoryReservations = new HashMap <>();
71+ private final Object2LongMap <TaskId > taskRevocableMemoryReservations = new Object2LongOpenHashMap <>();
7072
7173 private final List <MemoryPoolListener > listeners = new CopyOnWriteArrayList <>();
7274
@@ -78,21 +80,20 @@ public MemoryPool(DataSize size)
7880
7981 public synchronized MemoryPoolInfo getInfo ()
8082 {
81- Map <QueryId , List <MemoryAllocation >> memoryAllocations = new HashMap <>();
82- for (Entry <QueryId , Map <String , Long >> entry : taggedMemoryAllocations .entrySet ()) {
83- List <MemoryAllocation > allocations = new ArrayList <>();
84- if (entry .getValue () != null ) {
85- entry .getValue ().forEach ((tag , allocation ) -> allocations .add (new MemoryAllocation (tag , allocation )));
86- }
83+ ImmutableMap .Builder <QueryId , List <MemoryAllocation >> memoryAllocations = ImmutableMap .builder ();
84+ for (Entry <QueryId , Object2LongMap <String >> entry : taggedMemoryAllocations .entrySet ()) {
85+ List <MemoryAllocation > allocations = entry .getValue ().object2LongEntrySet ().stream ()
86+ .map (allocation -> new MemoryAllocation (allocation .getKey (), allocation .getLongValue ()))
87+ .collect (toImmutableList ());
8788 memoryAllocations .put (entry .getKey (), allocations );
8889 }
8990
90- Map <String , Long > stringKeyedTaskMemoryReservations = taskMemoryReservations .entrySet ().stream ()
91+ Map <String , Long > stringKeyedTaskMemoryReservations = taskMemoryReservations .object2LongEntrySet ().stream ()
9192 .collect (toImmutableMap (
9293 entry -> entry .getKey ().toString (),
9394 Entry ::getValue ));
9495
95- Map <String , Long > stringKeyedTaskRevocableMemoryReservations = taskRevocableMemoryReservations .entrySet ().stream ()
96+ Map <String , Long > stringKeyedTaskRevocableMemoryReservations = taskRevocableMemoryReservations .object2LongEntrySet ().stream ()
9697 .collect (toImmutableMap (
9798 entry -> entry .getKey ().toString (),
9899 Entry ::getValue ));
@@ -102,7 +103,7 @@ public synchronized MemoryPoolInfo getInfo()
102103 reservedBytes ,
103104 reservedRevocableBytes ,
104105 queryMemoryReservations ,
105- memoryAllocations ,
106+ memoryAllocations . buildOrThrow () ,
106107 stringKeyedTaskMemoryReservations ,
107108 stringKeyedTaskRevocableMemoryReservations );
108109 }
@@ -127,9 +128,9 @@ public ListenableFuture<Void> reserve(TaskId taskId, String allocationTag, long
127128 synchronized (this ) {
128129 if (bytes != 0 ) {
129130 QueryId queryId = taskId .queryId ();
130- queryMemoryReservations .merge (queryId , bytes , Long ::sum );
131+ queryMemoryReservations .mergeLong (queryId , bytes , Long ::sum );
131132 updateTaggedMemoryAllocations (queryId , allocationTag , bytes );
132- taskMemoryReservations .merge (taskId , bytes , Long ::sum );
133+ taskMemoryReservations .mergeLong (taskId , bytes , Long ::sum );
133134 }
134135 reservedBytes += bytes ;
135136 if (getFreeBytes () <= 0 ) {
@@ -192,9 +193,9 @@ public boolean tryReserve(TaskId taskId, String allocationTag, long bytes)
192193 reservedBytes += bytes ;
193194 if (bytes != 0 ) {
194195 QueryId queryId = taskId .queryId ();
195- queryMemoryReservations .merge (queryId , bytes , Long ::sum );
196+ queryMemoryReservations .mergeLong (queryId , bytes , Long ::sum );
196197 updateTaggedMemoryAllocations (queryId , allocationTag , bytes );
197- taskMemoryReservations .merge (taskId , bytes , Long ::sum );
198+ taskMemoryReservations .mergeLong (taskId , bytes , Long ::sum );
198199 }
199200 }
200201
@@ -226,17 +227,15 @@ public synchronized void free(TaskId taskId, String allocationTag, long bytes)
226227 }
227228
228229 QueryId queryId = taskId .queryId ();
229- Long queryReservation = queryMemoryReservations .get (queryId );
230- requireNonNull (queryReservation , "queryReservation is null" );
230+ long queryReservation = queryMemoryReservations .getLong (queryId );
231231 checkArgument (queryReservation >= bytes , "tried to free more memory than is reserved by query" );
232232
233- Long taskReservation = taskMemoryReservations .get (taskId );
234- requireNonNull (taskReservation , "taskReservation is null" );
233+ long taskReservation = taskMemoryReservations .getLong (taskId );
235234 checkArgument (taskReservation >= bytes , "tried to free more memory than is reserved by task" );
236235
237236 queryReservation -= bytes ;
238237 if (queryReservation == 0 ) {
239- queryMemoryReservations .remove (queryId );
238+ queryMemoryReservations .removeLong (queryId );
240239 taggedMemoryAllocations .remove (queryId );
241240 }
242241 else {
@@ -246,7 +245,7 @@ public synchronized void free(TaskId taskId, String allocationTag, long bytes)
246245
247246 taskReservation -= bytes ;
248247 if (taskReservation == 0 ) {
249- taskMemoryReservations .remove (taskId );
248+ taskMemoryReservations .removeLong (taskId );
250249 }
251250 else {
252251 taskMemoryReservations .put (taskId , taskReservation );
@@ -268,13 +267,12 @@ public synchronized void freeRevocable(TaskId taskId, long bytes)
268267 return ;
269268 }
270269
271- Long taskReservation = taskRevocableMemoryReservations .get (taskId );
272- requireNonNull (taskReservation , "taskReservation is null" );
270+ long taskReservation = taskRevocableMemoryReservations .getLong (taskId );
273271 checkArgument (taskReservation >= bytes , "tried to free more revocable memory than is reserved by task" );
274272
275273 taskReservation -= bytes ;
276274 if (taskReservation == 0 ) {
277- taskRevocableMemoryReservations .remove (taskId );
275+ taskRevocableMemoryReservations .removeLong (taskId );
278276 }
279277 else {
280278 taskRevocableMemoryReservations .put (taskId , taskReservation );
@@ -386,12 +384,12 @@ private synchronized void updateTaggedMemoryAllocations(QueryId queryId, String
386384 return ;
387385 }
388386
389- Map <String , Long > allocations = taggedMemoryAllocations .computeIfAbsent (queryId , _ -> new HashMap <>());
390- allocations .compute (allocationTag , (ignored , oldValue ) -> {
387+ Object2LongMap <String > allocations = taggedMemoryAllocations .computeIfAbsent (queryId , _ -> new Object2LongOpenHashMap <>());
388+ allocations .computeLong (allocationTag , (_ , oldValue ) -> {
391389 if (oldValue == null ) {
392390 return delta ;
393391 }
394- long newValue = oldValue . longValue () + delta ;
392+ long newValue = oldValue + delta ;
395393 if (newValue == 0 ) {
396394 return null ;
397395 }
@@ -400,9 +398,9 @@ private synchronized void updateTaggedMemoryAllocations(QueryId queryId, String
400398 }
401399
402400 @ VisibleForTesting
403- public synchronized Map <QueryId , Long > getQueryMemoryReservations ()
401+ public synchronized Object2LongMap <QueryId > getQueryMemoryReservations ()
404402 {
405- return ImmutableMap . copyOf (queryMemoryReservations );
403+ return Object2LongMaps . unmodifiable (queryMemoryReservations );
406404 }
407405
408406 @ VisibleForTesting
@@ -412,14 +410,14 @@ public synchronized Map<QueryId, Map<String, Long>> getTaggedMemoryAllocations()
412410 }
413411
414412 @ VisibleForTesting
415- public synchronized Map <TaskId , Long > getTaskMemoryReservations ()
413+ public synchronized Object2LongMap <TaskId > getTaskMemoryReservations ()
416414 {
417- return ImmutableMap . copyOf (taskMemoryReservations );
415+ return Object2LongMaps . unmodifiable (taskMemoryReservations );
418416 }
419417
420418 @ VisibleForTesting
421- public synchronized Map <TaskId , Long > getTaskRevocableMemoryReservations ()
419+ public synchronized Object2LongMap <TaskId > getTaskRevocableMemoryReservations ()
422420 {
423- return ImmutableMap . copyOf (taskRevocableMemoryReservations );
421+ return Object2LongMaps . unmodifiable (taskRevocableMemoryReservations );
424422 }
425423}
0 commit comments