Skip to content

Commit 5f77773

Browse files
committed
Merge memory allocations without long boxing
1 parent 5f3893c commit 5f77773

File tree

2 files changed

+44
-44
lines changed

2 files changed

+44
-44
lines changed

core/trino-main/src/main/java/io/trino/memory/ClusterMemoryPool.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap;
2525
import org.weakref.jmx.Managed;
2626

27-
import java.util.ArrayList;
2827
import java.util.HashMap;
2928
import java.util.List;
3029
import java.util.Map;
3130

3231
import static com.google.common.base.MoreObjects.toStringHelper;
32+
import static com.google.common.collect.ImmutableList.toImmutableList;
3333
import static java.util.Objects.requireNonNull;
3434

3535
@ThreadSafe
@@ -154,20 +154,22 @@ private List<MemoryAllocation> mergeQueryAllocations(List<MemoryAllocation> left
154154
requireNonNull(left, "left is null");
155155
requireNonNull(right, "right is null");
156156

157-
Map<String, MemoryAllocation> mergedAllocations = new HashMap<>();
157+
Object2LongMap<String> mergedAllocations = new Object2LongOpenHashMap<>();
158158

159159
for (MemoryAllocation allocation : left) {
160-
mergedAllocations.put(allocation.tag(), allocation);
160+
mergedAllocations.put(allocation.tag(), allocation.allocation());
161161
}
162162

163163
for (MemoryAllocation allocation : right) {
164-
mergedAllocations.merge(
164+
mergedAllocations.mergeLong(
165165
allocation.tag(),
166-
allocation,
167-
(a, b) -> new MemoryAllocation(a.tag(), a.allocation() + b.allocation()));
166+
allocation.allocation(),
167+
Long::sum);
168168
}
169169

170-
return new ArrayList<>(mergedAllocations.values());
170+
return mergedAllocations.object2LongEntrySet().stream()
171+
.map(entry -> new MemoryAllocation(entry.getKey(), entry.getLongValue()))
172+
.collect(toImmutableList());
171173
}
172174

173175
@Override

core/trino-main/src/main/java/io/trino/memory/MemoryPool.java

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,22 @@
2323
import io.trino.spi.QueryId;
2424
import io.trino.spi.memory.MemoryAllocation;
2525
import io.trino.spi.memory.MemoryPoolInfo;
26+
import it.unimi.dsi.fastutil.objects.Object2LongMap;
27+
import it.unimi.dsi.fastutil.objects.Object2LongMaps;
28+
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap;
2629
import jakarta.annotation.Nullable;
2730
import org.weakref.jmx.Managed;
2831

29-
import java.util.ArrayList;
3032
import java.util.HashMap;
3133
import java.util.List;
3234
import java.util.Map;
3335
import java.util.Map.Entry;
34-
import java.util.concurrent.ConcurrentHashMap;
3536
import java.util.concurrent.CopyOnWriteArrayList;
3637

3738
import static com.google.common.base.MoreObjects.toStringHelper;
3839
import static com.google.common.base.Preconditions.checkArgument;
3940
import static com.google.common.base.Preconditions.checkState;
41+
import static com.google.common.collect.ImmutableList.toImmutableList;
4042
import static com.google.common.collect.ImmutableMap.toImmutableMap;
4143
import static io.trino.operator.Operator.NOT_BLOCKED;
4244
import static java.util.Objects.requireNonNull;
@@ -56,17 +58,17 @@ public class MemoryPool
5658

5759
// TODO: It would be better if we just tracked QueryContexts, but their lifecycle is managed by a weak reference, so we can't do that
5860
// It is guarded for updates by this, but can be read without holding a lock
59-
private final Map<QueryId, Long> queryMemoryReservations = new ConcurrentHashMap<>();
61+
private final Object2LongMap<QueryId> queryMemoryReservations = new Object2LongOpenHashMap<>();
6062

6163
// This map keeps track of all the tagged allocations, e.g., query-1 -> ['TableScanOperator': 10MB, 'LazyOutputBuffer': 5MB, ...]
6264
@GuardedBy("this")
63-
private final Map<QueryId, Map<String, Long>> taggedMemoryAllocations = new HashMap<>();
65+
private final Map<QueryId, Object2LongMap<String>> taggedMemoryAllocations = new HashMap<>();
6466

6567
@GuardedBy("this")
66-
private final Map<TaskId, Long> taskMemoryReservations = new HashMap<>();
68+
private final Object2LongMap<TaskId> taskMemoryReservations = new Object2LongOpenHashMap<>();
6769

6870
@GuardedBy("this")
69-
private final Map<TaskId, Long> taskRevocableMemoryReservations = new HashMap<>();
71+
private final Object2LongMap<TaskId> taskRevocableMemoryReservations = new Object2LongOpenHashMap<>();
7072

7173
private final List<MemoryPoolListener> listeners = new CopyOnWriteArrayList<>();
7274

@@ -78,21 +80,20 @@ public MemoryPool(DataSize size)
7880

7981
public synchronized MemoryPoolInfo getInfo()
8082
{
81-
Map<QueryId, List<MemoryAllocation>> memoryAllocations = new HashMap<>();
82-
for (Entry<QueryId, Map<String, Long>> entry : taggedMemoryAllocations.entrySet()) {
83-
List<MemoryAllocation> allocations = new ArrayList<>();
84-
if (entry.getValue() != null) {
85-
entry.getValue().forEach((tag, allocation) -> allocations.add(new MemoryAllocation(tag, allocation)));
86-
}
83+
ImmutableMap.Builder<QueryId, List<MemoryAllocation>> memoryAllocations = ImmutableMap.builder();
84+
for (Entry<QueryId, Object2LongMap<String>> entry : taggedMemoryAllocations.entrySet()) {
85+
List<MemoryAllocation> allocations = entry.getValue().object2LongEntrySet().stream()
86+
.map(allocation -> new MemoryAllocation(allocation.getKey(), allocation.getLongValue()))
87+
.collect(toImmutableList());
8788
memoryAllocations.put(entry.getKey(), allocations);
8889
}
8990

90-
Map<String, Long> stringKeyedTaskMemoryReservations = taskMemoryReservations.entrySet().stream()
91+
Map<String, Long> stringKeyedTaskMemoryReservations = taskMemoryReservations.object2LongEntrySet().stream()
9192
.collect(toImmutableMap(
9293
entry -> entry.getKey().toString(),
9394
Entry::getValue));
9495

95-
Map<String, Long> stringKeyedTaskRevocableMemoryReservations = taskRevocableMemoryReservations.entrySet().stream()
96+
Map<String, Long> stringKeyedTaskRevocableMemoryReservations = taskRevocableMemoryReservations.object2LongEntrySet().stream()
9697
.collect(toImmutableMap(
9798
entry -> entry.getKey().toString(),
9899
Entry::getValue));
@@ -102,7 +103,7 @@ public synchronized MemoryPoolInfo getInfo()
102103
reservedBytes,
103104
reservedRevocableBytes,
104105
queryMemoryReservations,
105-
memoryAllocations,
106+
memoryAllocations.buildOrThrow(),
106107
stringKeyedTaskMemoryReservations,
107108
stringKeyedTaskRevocableMemoryReservations);
108109
}
@@ -127,9 +128,9 @@ public ListenableFuture<Void> reserve(TaskId taskId, String allocationTag, long
127128
synchronized (this) {
128129
if (bytes != 0) {
129130
QueryId queryId = taskId.queryId();
130-
queryMemoryReservations.merge(queryId, bytes, Long::sum);
131+
queryMemoryReservations.mergeLong(queryId, bytes, Long::sum);
131132
updateTaggedMemoryAllocations(queryId, allocationTag, bytes);
132-
taskMemoryReservations.merge(taskId, bytes, Long::sum);
133+
taskMemoryReservations.mergeLong(taskId, bytes, Long::sum);
133134
}
134135
reservedBytes += bytes;
135136
if (getFreeBytes() <= 0) {
@@ -192,9 +193,9 @@ public boolean tryReserve(TaskId taskId, String allocationTag, long bytes)
192193
reservedBytes += bytes;
193194
if (bytes != 0) {
194195
QueryId queryId = taskId.queryId();
195-
queryMemoryReservations.merge(queryId, bytes, Long::sum);
196+
queryMemoryReservations.mergeLong(queryId, bytes, Long::sum);
196197
updateTaggedMemoryAllocations(queryId, allocationTag, bytes);
197-
taskMemoryReservations.merge(taskId, bytes, Long::sum);
198+
taskMemoryReservations.mergeLong(taskId, bytes, Long::sum);
198199
}
199200
}
200201

@@ -226,17 +227,15 @@ public synchronized void free(TaskId taskId, String allocationTag, long bytes)
226227
}
227228

228229
QueryId queryId = taskId.queryId();
229-
Long queryReservation = queryMemoryReservations.get(queryId);
230-
requireNonNull(queryReservation, "queryReservation is null");
230+
long queryReservation = queryMemoryReservations.getLong(queryId);
231231
checkArgument(queryReservation >= bytes, "tried to free more memory than is reserved by query");
232232

233-
Long taskReservation = taskMemoryReservations.get(taskId);
234-
requireNonNull(taskReservation, "taskReservation is null");
233+
long taskReservation = taskMemoryReservations.getLong(taskId);
235234
checkArgument(taskReservation >= bytes, "tried to free more memory than is reserved by task");
236235

237236
queryReservation -= bytes;
238237
if (queryReservation == 0) {
239-
queryMemoryReservations.remove(queryId);
238+
queryMemoryReservations.removeLong(queryId);
240239
taggedMemoryAllocations.remove(queryId);
241240
}
242241
else {
@@ -246,7 +245,7 @@ public synchronized void free(TaskId taskId, String allocationTag, long bytes)
246245

247246
taskReservation -= bytes;
248247
if (taskReservation == 0) {
249-
taskMemoryReservations.remove(taskId);
248+
taskMemoryReservations.removeLong(taskId);
250249
}
251250
else {
252251
taskMemoryReservations.put(taskId, taskReservation);
@@ -268,13 +267,12 @@ public synchronized void freeRevocable(TaskId taskId, long bytes)
268267
return;
269268
}
270269

271-
Long taskReservation = taskRevocableMemoryReservations.get(taskId);
272-
requireNonNull(taskReservation, "taskReservation is null");
270+
long taskReservation = taskRevocableMemoryReservations.getLong(taskId);
273271
checkArgument(taskReservation >= bytes, "tried to free more revocable memory than is reserved by task");
274272

275273
taskReservation -= bytes;
276274
if (taskReservation == 0) {
277-
taskRevocableMemoryReservations.remove(taskId);
275+
taskRevocableMemoryReservations.removeLong(taskId);
278276
}
279277
else {
280278
taskRevocableMemoryReservations.put(taskId, taskReservation);
@@ -386,12 +384,12 @@ private synchronized void updateTaggedMemoryAllocations(QueryId queryId, String
386384
return;
387385
}
388386

389-
Map<String, Long> allocations = taggedMemoryAllocations.computeIfAbsent(queryId, _ -> new HashMap<>());
390-
allocations.compute(allocationTag, (ignored, oldValue) -> {
387+
Object2LongMap<String> allocations = taggedMemoryAllocations.computeIfAbsent(queryId, _ -> new Object2LongOpenHashMap<>());
388+
allocations.computeLong(allocationTag, (_, oldValue) -> {
391389
if (oldValue == null) {
392390
return delta;
393391
}
394-
long newValue = oldValue.longValue() + delta;
392+
long newValue = oldValue + delta;
395393
if (newValue == 0) {
396394
return null;
397395
}
@@ -400,9 +398,9 @@ private synchronized void updateTaggedMemoryAllocations(QueryId queryId, String
400398
}
401399

402400
@VisibleForTesting
403-
public synchronized Map<QueryId, Long> getQueryMemoryReservations()
401+
public synchronized Object2LongMap<QueryId> getQueryMemoryReservations()
404402
{
405-
return ImmutableMap.copyOf(queryMemoryReservations);
403+
return Object2LongMaps.unmodifiable(queryMemoryReservations);
406404
}
407405

408406
@VisibleForTesting
@@ -412,14 +410,14 @@ public synchronized Map<QueryId, Map<String, Long>> getTaggedMemoryAllocations()
412410
}
413411

414412
@VisibleForTesting
415-
public synchronized Map<TaskId, Long> getTaskMemoryReservations()
413+
public synchronized Object2LongMap<TaskId> getTaskMemoryReservations()
416414
{
417-
return ImmutableMap.copyOf(taskMemoryReservations);
415+
return Object2LongMaps.unmodifiable(taskMemoryReservations);
418416
}
419417

420418
@VisibleForTesting
421-
public synchronized Map<TaskId, Long> getTaskRevocableMemoryReservations()
419+
public synchronized Object2LongMap<TaskId> getTaskRevocableMemoryReservations()
422420
{
423-
return ImmutableMap.copyOf(taskRevocableMemoryReservations);
421+
return Object2LongMaps.unmodifiable(taskRevocableMemoryReservations);
424422
}
425423
}

0 commit comments

Comments
 (0)