diff --git a/core/trino-main/src/main/java/io/trino/connector/system/QuerySystemTable.java b/core/trino-main/src/main/java/io/trino/connector/system/QuerySystemTable.java index 49c147a1a2cf..ddf1a428314c 100644 --- a/core/trino-main/src/main/java/io/trino/connector/system/QuerySystemTable.java +++ b/core/trino-main/src/main/java/io/trino/connector/system/QuerySystemTable.java @@ -104,12 +104,8 @@ public ConnectorTableMetadata getTableMetadata() public RecordCursor cursor(ConnectorTransactionHandle transactionHandle, ConnectorSession session, TupleDomain constraint) { checkState(dispatchManager.isPresent(), "Query system table can return results only on coordinator"); - - List queries = dispatchManager.get().getQueries(); - queries = filterQueries(((FullConnectorSession) session).getSession().getIdentity(), queries, accessControl); - Builder table = InMemoryRecordSet.builder(QUERY_TABLE); - for (BasicQueryInfo queryInfo : queries) { + for (BasicQueryInfo queryInfo : filterQueries(dispatchManager.get().getQueries(), ((FullConnectorSession) session).getSession().getIdentity(), accessControl)) { Optional fullQueryInfo = dispatchManager.get().getFullQueryInfo(queryInfo.getQueryId()); if (fullQueryInfo.isEmpty()) { continue; diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java index 745c485d8241..550c29f2de63 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryManager.java @@ -344,7 +344,7 @@ private void enforceMemoryLimits() List runningQueries = queryTracker.getAllQueries().stream() .filter(query -> query.getState() == RUNNING) .collect(toImmutableList()); - memoryManager.process(runningQueries, this::getQueries); + memoryManager.process(runningQueries, queryTracker::tryGetQuery); } /** diff --git a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryLeakDetector.java b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryLeakDetector.java index 3bc2372195a0..abdb10eebef0 100644 --- a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryLeakDetector.java +++ b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryLeakDetector.java @@ -14,21 +14,19 @@ package io.trino.memory; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Maps; import com.google.errorprone.annotations.ThreadSafe; import com.google.errorprone.annotations.concurrent.GuardedBy; import io.airlift.log.Logger; -import io.trino.server.BasicQueryInfo; +import io.trino.execution.QueryExecution; +import io.trino.execution.QueryState; import io.trino.spi.QueryId; +import it.unimi.dsi.fastutil.objects.Object2LongMap; import java.time.Instant; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; -import java.util.function.Supplier; +import java.util.function.Function; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.execution.QueryState.RUNNING; import static java.time.Instant.now; import static java.util.Objects.requireNonNull; @@ -46,47 +44,48 @@ public class ClusterMemoryLeakDetector private Set leakedQueries; /** - * @param queryInfoSupplier All queries that the coordinator knows about. + * @param executionInfoSupplier Provided QueryId returns a QueryExecution if the query is still tracked by the coordinator. * @param queryMemoryReservations The memory reservations of queries in the cluster memory pool. */ - void checkForMemoryLeaks(Supplier> queryInfoSupplier, Map queryMemoryReservations) + void checkForMemoryLeaks(Function> executionInfoSupplier, Object2LongMap queryMemoryReservations) { - requireNonNull(queryInfoSupplier); requireNonNull(queryMemoryReservations); - Map queryIdToInfo = Maps.uniqueIndex(queryInfoSupplier.get(), BasicQueryInfo::getQueryId); - - Map leakedQueryReservations = queryMemoryReservations.entrySet() - .stream() - .filter(entry -> entry.getValue() > 0) - .filter(entry -> isLeaked(queryIdToInfo, entry.getKey())) - .collect(toImmutableMap(Entry::getKey, Entry::getValue)); + ImmutableSet.Builder leakedQueriesBuilder = ImmutableSet.builder(); + queryMemoryReservations.forEach((queryId, reservation) -> { + if (reservation > 0) { + if (isLeaked(executionInfoSupplier.apply(queryId))) { + leakedQueriesBuilder.add(queryId); + } + } + }); + Set leakedQueryReservations = leakedQueriesBuilder.build(); if (!leakedQueryReservations.isEmpty()) { log.debug("Memory leak detected. The following queries are already finished, " + "but they have memory reservations on some worker node(s): %s", leakedQueryReservations); } synchronized (this) { - leakedQueries = ImmutableSet.copyOf(leakedQueryReservations.keySet()); + leakedQueries = ImmutableSet.copyOf(leakedQueryReservations); } } - private static boolean isLeaked(Map queryIdToInfo, QueryId queryId) + private static boolean isLeaked(Optional execution) { - BasicQueryInfo queryInfo = queryIdToInfo.get(queryId); - - if (queryInfo == null) { + if (execution.isEmpty()) { + // We have a memory reservation but query isn't tracked return true; } - Instant queryEndTime = queryInfo.getQueryStats().getEndTime(); + Optional queryEndTime = execution.orElseThrow().getEndTime(); + QueryState state = execution.orElseThrow().getState(); - if (queryInfo.getState() == RUNNING || queryEndTime == null) { + if (state == RUNNING || queryEndTime.isEmpty()) { return false; } - return queryEndTime.plusSeconds(DEFAULT_LEAK_CLAIM_DELTA_SEC).isBefore(now()); + return queryEndTime.orElseThrow().plusSeconds(DEFAULT_LEAK_CLAIM_DELTA_SEC).isBefore(now()); } synchronized boolean wasQueryPossiblyLeaked(QueryId queryId) diff --git a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java index d118d0e56ebf..82144c603ced 100644 --- a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java +++ b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java @@ -18,8 +18,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; -import com.google.common.collect.Streams; import com.google.common.io.Closer; import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.inject.Inject; @@ -30,6 +28,7 @@ import io.trino.execution.LocationFactory; import io.trino.execution.QueryExecution; import io.trino.execution.QueryInfo; +import io.trino.execution.QueryState; import io.trino.execution.TaskId; import io.trino.execution.TaskInfo; import io.trino.execution.scheduler.NodeSchedulerConfig; @@ -39,7 +38,6 @@ import io.trino.node.InternalNode; import io.trino.node.InternalNodeManager; import io.trino.operator.RetryPolicy; -import io.trino.server.BasicQueryInfo; import io.trino.server.ServerConfig; import io.trino.spi.QueryId; import io.trino.spi.TrinoException; @@ -51,6 +49,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -60,14 +59,13 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; -import java.util.function.Supplier; +import java.util.function.Function; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.MoreCollectors.toOptional; import static com.google.common.collect.Sets.difference; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.units.DataSize.succinctBytes; @@ -174,18 +172,23 @@ public synchronized void addChangeListener(Consumer listener) changeListeners.add(listener); } - public synchronized void process(Iterable runningQueries, Supplier> allQueryInfoSupplier) + public synchronized void process(Collection allQueries, Function> executionInfoSupplier) { // TODO revocable memory reservations can also leak and may need to be detected in the future // We are only concerned about the leaks in the memory pool. - memoryLeakDetector.checkForMemoryLeaks(allQueryInfoSupplier, pool.getQueryMemoryReservations()); + memoryLeakDetector.checkForMemoryLeaks(executionInfoSupplier, pool.getQueryMemoryReservations()); boolean outOfMemory = isClusterOutOfMemory(); boolean queryKilled = false; long totalUserMemoryBytes = 0L; long totalMemoryBytes = 0L; - for (QueryExecution query : runningQueries) { + int queriesCount = 0; + for (QueryExecution query : allQueries) { + if (query.getState() != QueryState.RUNNING) { + continue; + } + queriesCount++; boolean resourceOvercommit = resourceOvercommit(query.getSession()); long userMemoryReservation = query.getUserMemoryReservation().toBytes(); long totalMemoryReservation = query.getTotalMemoryReservation().toBytes(); @@ -226,20 +229,21 @@ public synchronized void process(Iterable runningQueries, Suppli if (!lowMemoryKillers.isEmpty() && outOfMemory && !queryKilled) { if (isLastKillTargetGone()) { - callOomKiller(runningQueries); + callOomKiller(allQueries, executionInfoSupplier); } else { log.debug("Last killed target is still not gone: %s", lastKillTarget); } } - updateMemoryPool(Iterables.size(runningQueries)); + updateMemoryPool(queriesCount); updateNodes(); } - private synchronized void callOomKiller(Iterable runningQueries) + private synchronized void callOomKiller(Collection allQueries, Function> queryExecutionSupplier) { - List runningQueryInfos = Streams.stream(runningQueries) + List runningQueryInfos = allQueries + .stream() .map(this::createQueryMemoryInfo) .collect(toImmutableList()); @@ -257,7 +261,7 @@ private synchronized void callOomKiller(Iterable runningQueries) if (killTarget.get().isWholeQuery()) { QueryId queryId = killTarget.get().getQuery(); log.debug("Low memory killer chose %s", queryId); - Optional chosenQuery = findRunningQuery(runningQueries, killTarget.get().getQuery()); + Optional chosenQuery = queryExecutionSupplier.apply(killTarget.get().getQuery()); if (chosenQuery.isPresent()) { // See comments in isQueryGone for why chosenQuery might be absent. chosenQuery.get().fail(new TrinoException(CLUSTER_OUT_OF_MEMORY, "Query killed because the cluster is out of memory. Please try again in a few minutes.")); @@ -271,7 +275,7 @@ private synchronized void callOomKiller(Iterable runningQueries) log.debug("Low memory killer chose %s", tasks); ImmutableSet.Builder killedTasksBuilder = ImmutableSet.builder(); for (TaskId task : tasks) { - Optional runningQuery = findRunningQuery(runningQueries, task.queryId()); + Optional runningQuery = queryExecutionSupplier.apply(task.queryId()); if (runningQuery.isPresent()) { runningQuery.get().failTask(task, new TrinoException(CLUSTER_OUT_OF_MEMORY, "Task killed because the cluster is out of memory.")); tasksKilledDueToOutOfMemory.incrementAndGet(); @@ -343,11 +347,6 @@ private Set getRunningTasks() .collect(toImmutableSet()); } - private Optional findRunningQuery(Iterable runningQueries, QueryId queryId) - { - return Streams.stream(runningQueries).filter(query -> queryId.equals(query.getQueryId())).collect(toOptional()); - } - private void logQueryKill(QueryId killedQueryId, Map nodeMemoryInfosByNode) { if (!log.isInfoEnabled()) { diff --git a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryPool.java b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryPool.java index 901748099a09..cd9263e29b78 100644 --- a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryPool.java +++ b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryPool.java @@ -19,14 +19,17 @@ import io.trino.spi.QueryId; import io.trino.spi.memory.MemoryAllocation; import io.trino.spi.memory.MemoryPoolInfo; +import it.unimi.dsi.fastutil.objects.Object2LongMap; +import it.unimi.dsi.fastutil.objects.Object2LongMaps; +import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap; import org.weakref.jmx.Managed; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @ThreadSafe @@ -52,14 +55,11 @@ public class ClusterMemoryPool // Does not include queries with zero memory usage @GuardedBy("this") - private final Map queryMemoryReservations = new HashMap<>(); + private final Object2LongMap queryMemoryReservations = new Object2LongOpenHashMap<>(); @GuardedBy("this") private final Map> queryMemoryAllocations = new HashMap<>(); - @GuardedBy("this") - private final Map queryMemoryRevocableReservations = new HashMap<>(); - public synchronized MemoryPoolInfo getInfo() { return new MemoryPoolInfo( @@ -68,7 +68,6 @@ public synchronized MemoryPoolInfo getInfo() reservedRevocableDistributedBytes, ImmutableMap.copyOf(queryMemoryReservations), ImmutableMap.copyOf(queryMemoryAllocations), - ImmutableMap.copyOf(queryMemoryRevocableReservations), // not providing per-task memory info for cluster-wide pool ImmutableMap.of(), ImmutableMap.of()); @@ -116,14 +115,9 @@ public synchronized int getAssignedQueries() return assignedQueries; } - public synchronized Map getQueryMemoryReservations() - { - return ImmutableMap.copyOf(queryMemoryReservations); - } - - public synchronized Map getQueryMemoryRevocableReservations() + public synchronized Object2LongMap getQueryMemoryReservations() { - return ImmutableMap.copyOf(queryMemoryRevocableReservations); + return Object2LongMaps.unmodifiable(queryMemoryReservations); } public synchronized void update(List memoryInfos, int assignedQueries) @@ -136,7 +130,6 @@ public synchronized void update(List memoryInfos, int assignedQuerie this.assignedQueries = assignedQueries; this.queryMemoryReservations.clear(); this.queryMemoryAllocations.clear(); - this.queryMemoryRevocableReservations.clear(); for (MemoryInfo info : memoryInfos) { MemoryPoolInfo poolInfo = info.getPool(); @@ -148,14 +141,11 @@ public synchronized void update(List memoryInfos, int assignedQuerie reservedDistributedBytes += poolInfo.getReservedBytes(); reservedRevocableDistributedBytes += poolInfo.getReservedRevocableBytes(); for (Map.Entry entry : poolInfo.getQueryMemoryReservations().entrySet()) { - queryMemoryReservations.merge(entry.getKey(), entry.getValue(), Long::sum); + queryMemoryReservations.mergeLong(entry.getKey(), entry.getValue(), Long::sum); } for (Map.Entry> entry : poolInfo.getQueryMemoryAllocations().entrySet()) { queryMemoryAllocations.merge(entry.getKey(), entry.getValue(), this::mergeQueryAllocations); } - for (Map.Entry entry : poolInfo.getQueryMemoryRevocableReservations().entrySet()) { - queryMemoryRevocableReservations.merge(entry.getKey(), entry.getValue(), Long::sum); - } } } @@ -164,20 +154,22 @@ private List mergeQueryAllocations(List left requireNonNull(left, "left is null"); requireNonNull(right, "right is null"); - Map mergedAllocations = new HashMap<>(); + Object2LongMap mergedAllocations = new Object2LongOpenHashMap<>(); for (MemoryAllocation allocation : left) { - mergedAllocations.put(allocation.getTag(), allocation); + mergedAllocations.put(allocation.tag(), allocation.allocation()); } for (MemoryAllocation allocation : right) { - mergedAllocations.merge( - allocation.getTag(), - allocation, - (a, b) -> new MemoryAllocation(a.getTag(), a.getAllocation() + b.getAllocation())); + mergedAllocations.mergeLong( + allocation.tag(), + allocation.allocation(), + Long::sum); } - return new ArrayList<>(mergedAllocations.values()); + return mergedAllocations.object2LongEntrySet().stream() + .map(entry -> new MemoryAllocation(entry.getKey(), entry.getLongValue())) + .collect(toImmutableList()); } @Override @@ -193,7 +185,6 @@ public synchronized String toString() .add("assignedQueries", assignedQueries) .add("queryMemoryReservations", queryMemoryReservations) .add("queryMemoryAllocations", queryMemoryAllocations) - .add("queryMemoryRevocableReservations", queryMemoryRevocableReservations) .toString(); } } diff --git a/core/trino-main/src/main/java/io/trino/memory/MemoryPool.java b/core/trino-main/src/main/java/io/trino/memory/MemoryPool.java index b6cf4c315182..0941f4d00491 100644 --- a/core/trino-main/src/main/java/io/trino/memory/MemoryPool.java +++ b/core/trino-main/src/main/java/io/trino/memory/MemoryPool.java @@ -23,20 +23,22 @@ import io.trino.spi.QueryId; import io.trino.spi.memory.MemoryAllocation; import io.trino.spi.memory.MemoryPoolInfo; +import it.unimi.dsi.fastutil.objects.Object2LongMap; +import it.unimi.dsi.fastutil.objects.Object2LongMaps; +import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap; import jakarta.annotation.Nullable; import org.weakref.jmx.Managed; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.operator.Operator.NOT_BLOCKED; import static java.util.Objects.requireNonNull; @@ -56,20 +58,17 @@ public class MemoryPool // TODO: It would be better if we just tracked QueryContexts, but their lifecycle is managed by a weak reference, so we can't do that // It is guarded for updates by this, but can be read without holding a lock - private final Map queryMemoryReservations = new ConcurrentHashMap<>(); + private final Object2LongMap queryMemoryReservations = new Object2LongOpenHashMap<>(); // This map keeps track of all the tagged allocations, e.g., query-1 -> ['TableScanOperator': 10MB, 'LazyOutputBuffer': 5MB, ...] @GuardedBy("this") - private final Map> taggedMemoryAllocations = new HashMap<>(); + private final Map> taggedMemoryAllocations = new HashMap<>(); @GuardedBy("this") - private final Map queryRevocableMemoryReservations = new HashMap<>(); + private final Object2LongMap taskMemoryReservations = new Object2LongOpenHashMap<>(); @GuardedBy("this") - private final Map taskMemoryReservations = new HashMap<>(); - - @GuardedBy("this") - private final Map taskRevocableMemoryReservations = new HashMap<>(); + private final Object2LongMap taskRevocableMemoryReservations = new Object2LongOpenHashMap<>(); private final List listeners = new CopyOnWriteArrayList<>(); @@ -81,21 +80,20 @@ public MemoryPool(DataSize size) public synchronized MemoryPoolInfo getInfo() { - Map> memoryAllocations = new HashMap<>(); - for (Entry> entry : taggedMemoryAllocations.entrySet()) { - List allocations = new ArrayList<>(); - if (entry.getValue() != null) { - entry.getValue().forEach((tag, allocation) -> allocations.add(new MemoryAllocation(tag, allocation))); - } + ImmutableMap.Builder> memoryAllocations = ImmutableMap.builder(); + for (Entry> entry : taggedMemoryAllocations.entrySet()) { + List allocations = entry.getValue().object2LongEntrySet().stream() + .map(allocation -> new MemoryAllocation(allocation.getKey(), allocation.getLongValue())) + .collect(toImmutableList()); memoryAllocations.put(entry.getKey(), allocations); } - Map stringKeyedTaskMemoryReservations = taskMemoryReservations.entrySet().stream() + Map stringKeyedTaskMemoryReservations = taskMemoryReservations.object2LongEntrySet().stream() .collect(toImmutableMap( entry -> entry.getKey().toString(), Entry::getValue)); - Map stringKeyedTaskRevocableMemoryReservations = taskRevocableMemoryReservations.entrySet().stream() + Map stringKeyedTaskRevocableMemoryReservations = taskRevocableMemoryReservations.object2LongEntrySet().stream() .collect(toImmutableMap( entry -> entry.getKey().toString(), Entry::getValue)); @@ -105,8 +103,7 @@ public synchronized MemoryPoolInfo getInfo() reservedBytes, reservedRevocableBytes, queryMemoryReservations, - memoryAllocations, - queryRevocableMemoryReservations, + memoryAllocations.buildOrThrow(), stringKeyedTaskMemoryReservations, stringKeyedTaskRevocableMemoryReservations); } @@ -131,9 +128,9 @@ public ListenableFuture reserve(TaskId taskId, String allocationTag, long synchronized (this) { if (bytes != 0) { QueryId queryId = taskId.queryId(); - queryMemoryReservations.merge(queryId, bytes, Long::sum); + queryMemoryReservations.mergeLong(queryId, bytes, Long::sum); updateTaggedMemoryAllocations(queryId, allocationTag, bytes); - taskMemoryReservations.merge(taskId, bytes, Long::sum); + taskMemoryReservations.mergeLong(taskId, bytes, Long::sum); } reservedBytes += bytes; if (getFreeBytes() <= 0) { @@ -164,7 +161,6 @@ public ListenableFuture reserveRevocable(TaskId taskId, long bytes) ListenableFuture result; synchronized (this) { if (bytes != 0) { - queryRevocableMemoryReservations.merge(taskId.queryId(), bytes, Long::sum); taskRevocableMemoryReservations.merge(taskId, bytes, Long::sum); } reservedRevocableBytes += bytes; @@ -197,9 +193,9 @@ public boolean tryReserve(TaskId taskId, String allocationTag, long bytes) reservedBytes += bytes; if (bytes != 0) { QueryId queryId = taskId.queryId(); - queryMemoryReservations.merge(queryId, bytes, Long::sum); + queryMemoryReservations.mergeLong(queryId, bytes, Long::sum); updateTaggedMemoryAllocations(queryId, allocationTag, bytes); - taskMemoryReservations.merge(taskId, bytes, Long::sum); + taskMemoryReservations.mergeLong(taskId, bytes, Long::sum); } } @@ -231,17 +227,15 @@ public synchronized void free(TaskId taskId, String allocationTag, long bytes) } QueryId queryId = taskId.queryId(); - Long queryReservation = queryMemoryReservations.get(queryId); - requireNonNull(queryReservation, "queryReservation is null"); + long queryReservation = queryMemoryReservations.getLong(queryId); checkArgument(queryReservation >= bytes, "tried to free more memory than is reserved by query"); - Long taskReservation = taskMemoryReservations.get(taskId); - requireNonNull(taskReservation, "taskReservation is null"); + long taskReservation = taskMemoryReservations.getLong(taskId); checkArgument(taskReservation >= bytes, "tried to free more memory than is reserved by task"); queryReservation -= bytes; if (queryReservation == 0) { - queryMemoryReservations.remove(queryId); + queryMemoryReservations.removeLong(queryId); taggedMemoryAllocations.remove(queryId); } else { @@ -251,7 +245,7 @@ public synchronized void free(TaskId taskId, String allocationTag, long bytes) taskReservation -= bytes; if (taskReservation == 0) { - taskMemoryReservations.remove(taskId); + taskMemoryReservations.removeLong(taskId); } else { taskMemoryReservations.put(taskId, taskReservation); @@ -273,26 +267,12 @@ public synchronized void freeRevocable(TaskId taskId, long bytes) return; } - QueryId queryId = taskId.queryId(); - Long queryReservation = queryRevocableMemoryReservations.get(queryId); - requireNonNull(queryReservation, "queryReservation is null"); - checkArgument(queryReservation >= bytes, "tried to free more revocable memory than is reserved by query"); - - Long taskReservation = taskRevocableMemoryReservations.get(taskId); - requireNonNull(taskReservation, "taskReservation is null"); + long taskReservation = taskRevocableMemoryReservations.getLong(taskId); checkArgument(taskReservation >= bytes, "tried to free more revocable memory than is reserved by task"); - queryReservation -= bytes; - if (queryReservation == 0) { - queryRevocableMemoryReservations.remove(queryId); - } - else { - queryRevocableMemoryReservations.put(queryId, queryReservation); - } - taskReservation -= bytes; if (taskReservation == 0) { - taskRevocableMemoryReservations.remove(taskId); + taskRevocableMemoryReservations.removeLong(taskId); } else { taskRevocableMemoryReservations.put(taskId, taskReservation); @@ -353,12 +333,6 @@ long getQueryMemoryReservation(QueryId queryId) return queryMemoryReservations.getOrDefault(queryId, 0L); } - @VisibleForTesting - synchronized long getQueryRevocableMemoryReservation(QueryId queryId) - { - return queryRevocableMemoryReservations.getOrDefault(queryId, 0L); - } - @VisibleForTesting synchronized long getTaskMemoryReservation(TaskId taskId) { @@ -410,12 +384,12 @@ private synchronized void updateTaggedMemoryAllocations(QueryId queryId, String return; } - Map allocations = taggedMemoryAllocations.computeIfAbsent(queryId, _ -> new HashMap<>()); - allocations.compute(allocationTag, (ignored, oldValue) -> { + Object2LongMap allocations = taggedMemoryAllocations.computeIfAbsent(queryId, _ -> new Object2LongOpenHashMap<>()); + allocations.computeLong(allocationTag, (_, oldValue) -> { if (oldValue == null) { return delta; } - long newValue = oldValue.longValue() + delta; + long newValue = oldValue + delta; if (newValue == 0) { return null; } @@ -424,9 +398,9 @@ private synchronized void updateTaggedMemoryAllocations(QueryId queryId, String } @VisibleForTesting - public synchronized Map getQueryMemoryReservations() + public synchronized Object2LongMap getQueryMemoryReservations() { - return ImmutableMap.copyOf(queryMemoryReservations); + return Object2LongMaps.unmodifiable(queryMemoryReservations); } @VisibleForTesting @@ -436,20 +410,14 @@ public synchronized Map> getTaggedMemoryAllocations() } @VisibleForTesting - public synchronized Map getQueryRevocableMemoryReservations() - { - return ImmutableMap.copyOf(queryRevocableMemoryReservations); - } - - @VisibleForTesting - public synchronized Map getTaskMemoryReservations() + public synchronized Object2LongMap getTaskMemoryReservations() { - return ImmutableMap.copyOf(taskMemoryReservations); + return Object2LongMaps.unmodifiable(taskMemoryReservations); } @VisibleForTesting - public synchronized Map getTaskRevocableMemoryReservations() + public synchronized Object2LongMap getTaskRevocableMemoryReservations() { - return ImmutableMap.copyOf(taskRevocableMemoryReservations); + return Object2LongMaps.unmodifiable(taskRevocableMemoryReservations); } } diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControlUtil.java b/core/trino-main/src/main/java/io/trino/security/AccessControlUtil.java index 8e74c5ca5c0a..fa730303bbe5 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControlUtil.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControlUtil.java @@ -34,7 +34,7 @@ public static void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwn accessControl.checkCanViewQueryOwnedBy(identity, queryOwner); } - public static List filterQueries(Identity identity, List queries, AccessControl accessControl) + public static List filterQueries(Collection queries, Identity identity, AccessControl accessControl) { Collection owners = queries.stream() .map(BasicQueryInfo::getSession) diff --git a/core/trino-main/src/main/java/io/trino/server/QueryResource.java b/core/trino-main/src/main/java/io/trino/server/QueryResource.java index 46d71f8ea1b3..dbb9b5317065 100644 --- a/core/trino-main/src/main/java/io/trino/server/QueryResource.java +++ b/core/trino-main/src/main/java/io/trino/server/QueryResource.java @@ -80,9 +80,7 @@ public List getAllQueryInfo(@QueryParam("state") Set sta .map(QueryState::valueOf) .collect(toImmutableSet()); - List queries = dispatchManager.getQueries(); - queries = filterQueries(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders), queries, accessControl); - + List queries = filterQueries(dispatchManager.getQueries(), sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders), accessControl); if (expectedStates.isEmpty()) { return queries; } diff --git a/core/trino-main/src/main/java/io/trino/server/ui/UiQueryResource.java b/core/trino-main/src/main/java/io/trino/server/ui/UiQueryResource.java index 35c13102b2ef..ef5aa6fa17b1 100644 --- a/core/trino-main/src/main/java/io/trino/server/ui/UiQueryResource.java +++ b/core/trino-main/src/main/java/io/trino/server/ui/UiQueryResource.java @@ -89,9 +89,7 @@ public UiQueryResource(ObjectMapper objectMapper, DispatchManager dispatchManage public List getAllQueryInfo(@QueryParam("state") String stateFilter, @Context HttpServletRequest servletRequest, @Context HttpHeaders httpHeaders) { QueryState expectedState = stateFilter == null ? null : QueryState.valueOf(stateFilter.toUpperCase(Locale.ENGLISH)); - - List queries = dispatchManager.getQueries(); - queries = filterQueries(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders), queries, accessControl); + List queries = filterQueries(dispatchManager.getQueries(), sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders), accessControl); ImmutableList.Builder builder = ImmutableList.builder(); for (BasicQueryInfo queryInfo : queries) { diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/BenchmarkBinPackingNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/BenchmarkBinPackingNodeAllocator.java index b97743a39175..908d2457b03f 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/BenchmarkBinPackingNodeAllocator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/BenchmarkBinPackingNodeAllocator.java @@ -190,8 +190,7 @@ private MemoryInfo buildWorkerMemoryInfo(DataSize usedMemory, Map entry.getKey().toString(), - entry -> entry.getValue().toBytes())), - ImmutableMap.of())); + entry -> entry.getValue().toBytes())))); } private void assertAcquired(NodeAllocator.NodeLease lease) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestBinPackingNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestBinPackingNodeAllocator.java index f5de85da004b..3c03b0ef948f 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestBinPackingNodeAllocator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestBinPackingNodeAllocator.java @@ -138,7 +138,6 @@ private MemoryInfo buildWorkerMemoryInfo(DataSize usedMemory, Map entry.getKey().toString(), diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestExponentialGrowthPartitionMemoryEstimator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestExponentialGrowthPartitionMemoryEstimator.java index 30149806a8e0..b9e65b692fbe 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestExponentialGrowthPartitionMemoryEstimator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestExponentialGrowthPartitionMemoryEstimator.java @@ -287,7 +287,6 @@ private MemoryInfo buildWorkerMemoryInfo(DataSize usedMemory) ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), - ImmutableMap.of(), ImmutableMap.of())); } } diff --git a/core/trino-main/src/test/java/io/trino/memory/LowMemoryKillerTestingUtils.java b/core/trino-main/src/test/java/io/trino/memory/LowMemoryKillerTestingUtils.java index 6c73d3b781a2..b21a5f6adcc5 100644 --- a/core/trino-main/src/test/java/io/trino/memory/LowMemoryKillerTestingUtils.java +++ b/core/trino-main/src/test/java/io/trino/memory/LowMemoryKillerTestingUtils.java @@ -70,7 +70,6 @@ static List toNodeMemoryInfoList(long memoryPoolMaxBytes, Map testPool.freeRevocable(q1task1, 9)) .hasMessage("tried to free more revocable memory than is reserved by task"); - assertThat(testPool.getQueryRevocableMemoryReservations().keySet()).hasSize(2); - assertThat(testPool.getQueryRevocableMemoryReservation(query1)).isEqualTo(15L); - assertThat(testPool.getQueryRevocableMemoryReservation(query2)).isEqualTo(9L); assertThat(testPool.getTaskRevocableMemoryReservations().keySet()).hasSize(3); assertThat(testPool.getTaskRevocableMemoryReservation(q1task1)).isEqualTo(8L); assertThat(testPool.getTaskRevocableMemoryReservation(q1task2)).isEqualTo(7L); @@ -400,9 +385,6 @@ void testPerTaskRevocableAllocations() // zero memory for one of the tasks testPool.freeRevocable(q1task1, 8); - assertThat(testPool.getQueryRevocableMemoryReservations().keySet()).hasSize(2); - assertThat(testPool.getQueryRevocableMemoryReservation(query1)).isEqualTo(7L); - assertThat(testPool.getQueryRevocableMemoryReservation(query2)).isEqualTo(9L); assertThat(testPool.getTaskRevocableMemoryReservations().keySet()).hasSize(2); assertThat(testPool.getTaskRevocableMemoryReservation(q1task1)).isEqualTo(0L); assertThat(testPool.getTaskRevocableMemoryReservation(q1task2)).isEqualTo(7L); @@ -410,9 +392,6 @@ void testPerTaskRevocableAllocations() // zero memory for all query the tasks testPool.freeRevocable(q1task2, 7); - assertThat(testPool.getQueryRevocableMemoryReservations().keySet()).hasSize(1); - assertThat(testPool.getQueryRevocableMemoryReservation(query1)).isEqualTo(0L); - assertThat(testPool.getQueryRevocableMemoryReservation(query2)).isEqualTo(9L); assertThat(testPool.getTaskRevocableMemoryReservations().keySet()).hasSize(1); assertThat(testPool.getTaskRevocableMemoryReservation(q1task1)).isEqualTo(0L); assertThat(testPool.getTaskRevocableMemoryReservation(q1task2)).isEqualTo(0L); diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index 14ab18f30ce2..b5e5dc1876c0 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -297,6 +297,83 @@ method void io.trino.spi.block.PageBuilderStatus::<init>() Method is unnecessary + + java.annotation.attributeValueChanged + parameter void io.trino.spi.memory.MemoryPoolInfo::<init>(long, long, long, java.util.Map<io.trino.spi.QueryId, java.lang.Long>, java.util.Map<io.trino.spi.QueryId, java.util.List<io.trino.spi.memory.MemoryAllocation>>, ===java.util.Map<io.trino.spi.QueryId, java.lang.Long>===, java.util.Map<java.lang.String, java.lang.Long>, java.util.Map<java.lang.String, java.lang.Long>) + parameter void io.trino.spi.memory.MemoryPoolInfo::<init>(long, long, long, java.util.Map<io.trino.spi.QueryId, java.lang.Long>, java.util.Map<io.trino.spi.QueryId, java.util.List<io.trino.spi.memory.MemoryAllocation>>, ===java.util.Map<java.lang.String, java.lang.Long>===, java.util.Map<java.lang.String, java.lang.Long>) + com.fasterxml.jackson.annotation.JsonProperty + value + "queryMemoryRevocableReservations" + "taskMemoryReservations" + Removed unused field + + + true + java.annotation.attributeValueChanged + parameter void io.trino.spi.memory.MemoryPoolInfo::<init>(long, long, long, java.util.Map<io.trino.spi.QueryId, java.lang.Long>, java.util.Map<io.trino.spi.QueryId, java.util.List<io.trino.spi.memory.MemoryAllocation>>, java.util.Map<io.trino.spi.QueryId, java.lang.Long>, ===java.util.Map<java.lang.String, java.lang.Long>===, java.util.Map<java.lang.String, java.lang.Long>) + parameter void io.trino.spi.memory.MemoryPoolInfo::<init>(long, long, long, java.util.Map<io.trino.spi.QueryId, java.lang.Long>, java.util.Map<io.trino.spi.QueryId, java.util.List<io.trino.spi.memory.MemoryAllocation>>, java.util.Map<java.lang.String, java.lang.Long>, ===java.util.Map<java.lang.String, java.lang.Long>===) + com.fasterxml.jackson.annotation.JsonProperty + value + "taskMemoryReservations" + "taskMemoryRevocableReservations" + Removed unused field + + + true + java.method.numberOfParametersChanged + method void io.trino.spi.memory.MemoryPoolInfo::<init>(long, long, long, java.util.Map<io.trino.spi.QueryId, java.lang.Long>, java.util.Map<io.trino.spi.QueryId, java.util.List<io.trino.spi.memory.MemoryAllocation>>, java.util.Map<io.trino.spi.QueryId, java.lang.Long>, java.util.Map<java.lang.String, java.lang.Long>, java.util.Map<java.lang.String, java.lang.Long>) + method void io.trino.spi.memory.MemoryPoolInfo::<init>(long, long, long, java.util.Map<io.trino.spi.QueryId, java.lang.Long>, java.util.Map<io.trino.spi.QueryId, java.util.List<io.trino.spi.memory.MemoryAllocation>>, java.util.Map<java.lang.String, java.lang.Long>, java.util.Map<java.lang.String, java.lang.Long>) + Removed unused field + + + true + java.method.removed + method java.util.Map<io.trino.spi.QueryId, java.lang.Long> io.trino.spi.memory.MemoryPoolInfo::getQueryMemoryRevocableReservations() + Removed unused field + + + true + java.annotation.removed + parameter void io.trino.spi.memory.MemoryAllocation::<init>(===java.lang.String===, long) + parameter void io.trino.spi.memory.MemoryAllocation::<init>(===java.lang.String===, long) + @com.fasterxml.jackson.annotation.JsonProperty("tag") + Class converted to a record + + + true + java.annotation.removed + parameter void io.trino.spi.memory.MemoryAllocation::<init>(java.lang.String, ===long===) + parameter void io.trino.spi.memory.MemoryAllocation::<init>(java.lang.String, ===long===) + @com.fasterxml.jackson.annotation.JsonProperty("allocation") + Class converted to a record + + + true + java.annotation.removed + method void io.trino.spi.memory.MemoryAllocation::<init>(java.lang.String, long) + method void io.trino.spi.memory.MemoryAllocation::<init>(java.lang.String, long) + @com.fasterxml.jackson.annotation.JsonCreator + Class converted to a record + + + true + java.method.removed + method long io.trino.spi.memory.MemoryAllocation::getAllocation() + Class converted to a record + + + true + java.method.removed + method java.lang.String io.trino.spi.memory.MemoryAllocation::getTag() + Class converted to a record + + + true + java.class.kindChanged + class io.trino.spi.memory.MemoryAllocation + class io.trino.spi.memory.MemoryAllocation + Class converted to a record + diff --git a/core/trino-spi/src/main/java/io/trino/spi/memory/MemoryAllocation.java b/core/trino-spi/src/main/java/io/trino/spi/memory/MemoryAllocation.java index 6bbaa2d1027f..001f30ca835b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/memory/MemoryAllocation.java +++ b/core/trino-spi/src/main/java/io/trino/spi/memory/MemoryAllocation.java @@ -13,34 +13,12 @@ */ package io.trino.spi.memory; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; - import static java.util.Objects.requireNonNull; -public final class MemoryAllocation +public record MemoryAllocation(String tag, long allocation) { - private final String tag; - private final long allocation; - - @JsonCreator - public MemoryAllocation( - @JsonProperty("tag") String tag, - @JsonProperty("allocation") long allocation) - { - this.tag = requireNonNull(tag, "tag is null"); - this.allocation = allocation; - } - - @JsonProperty - public String getTag() - { - return tag; - } - - @JsonProperty - public long getAllocation() + public MemoryAllocation { - return allocation; + requireNonNull(tag, "tag is null"); } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/memory/MemoryPoolInfo.java b/core/trino-spi/src/main/java/io/trino/spi/memory/MemoryPoolInfo.java index 59ff3f0e7541..99c4a8e0566f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/memory/MemoryPoolInfo.java +++ b/core/trino-spi/src/main/java/io/trino/spi/memory/MemoryPoolInfo.java @@ -29,7 +29,6 @@ public final class MemoryPoolInfo private final long reservedRevocableBytes; private final Map queryMemoryReservations; private final Map> queryMemoryAllocations; - private final Map queryMemoryRevocableReservations; private final Map taskMemoryReservations; private final Map taskMemoryRevocableReservations; @@ -40,7 +39,6 @@ public MemoryPoolInfo( @JsonProperty("reservedRevocableBytes") long reservedRevocableBytes, @JsonProperty("queryMemoryReservations") Map queryMemoryReservations, @JsonProperty("queryMemoryAllocations") Map> queryMemoryAllocations, - @JsonProperty("queryMemoryRevocableReservations") Map queryMemoryRevocableReservations, @JsonProperty("taskMemoryReservations") Map taskMemoryReservations, @JsonProperty("taskMemoryRevocableReservations") Map taskMemoryRevocableReservations) { @@ -49,7 +47,6 @@ public MemoryPoolInfo( this.reservedRevocableBytes = reservedRevocableBytes; this.queryMemoryReservations = Map.copyOf(queryMemoryReservations); this.queryMemoryAllocations = Map.copyOf(queryMemoryAllocations); - this.queryMemoryRevocableReservations = Map.copyOf(queryMemoryRevocableReservations); this.taskMemoryReservations = Map.copyOf(taskMemoryReservations); this.taskMemoryRevocableReservations = Map.copyOf(taskMemoryRevocableReservations); } @@ -90,12 +87,6 @@ public Map> getQueryMemoryAllocations() return queryMemoryAllocations; } - @JsonProperty - public Map getQueryMemoryRevocableReservations() - { - return queryMemoryRevocableReservations; - } - @JsonProperty public Map getTaskMemoryReservations() { diff --git a/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/TestFileResourceGroupConfigurationManager.java b/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/TestFileResourceGroupConfigurationManager.java index 62c3165636c4..1eafcf758827 100644 --- a/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/TestFileResourceGroupConfigurationManager.java +++ b/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/TestFileResourceGroupConfigurationManager.java @@ -250,7 +250,7 @@ public void testExtractVariableConfiguration() public void testDocsExample() { FileResourceGroupConfigurationManager manager = new FileResourceGroupConfigurationManager( - (listener) -> listener.accept(new MemoryPoolInfo(MEMORY_POOL_SIZE, 0, 0, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of())), + (listener) -> listener.accept(new MemoryPoolInfo(MEMORY_POOL_SIZE, 0, 0, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of())), new FileResourceGroupConfig() // TODO: figure out a better way to validate documentation .setConfigFile("../../docs/src/main/sphinx/admin/resource-groups-example.json")); @@ -326,7 +326,7 @@ private static FileResourceGroupConfigurationManager parse(String fileName) FileResourceGroupConfig config = new FileResourceGroupConfig(); config.setConfigFile(getResource(fileName).getPath()); return new FileResourceGroupConfigurationManager( - listener -> listener.accept(new MemoryPoolInfo(MEMORY_POOL_SIZE, 0, 0, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of())), + listener -> listener.accept(new MemoryPoolInfo(MEMORY_POOL_SIZE, 0, 0, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of())), config); } diff --git a/testing/trino-testing/pom.xml b/testing/trino-testing/pom.xml index 520c9485e0f0..b35f885f2a23 100644 --- a/testing/trino-testing/pom.xml +++ b/testing/trino-testing/pom.xml @@ -183,6 +183,11 @@ tpch + + it.unimi.dsi + fastutil + + jakarta.annotation jakarta.annotation-api diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index 3a6f49ccc90c..a8f393959d12 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -321,6 +321,12 @@ test + + it.unimi.dsi + fastutil + test + + joda-time joda-time diff --git a/testing/trino-tests/src/test/java/io/trino/memory/TestClusterMemoryLeakDetector.java b/testing/trino-tests/src/test/java/io/trino/memory/TestClusterMemoryLeakDetector.java index 79bd2325fd5b..9033b174cf31 100644 --- a/testing/trino-tests/src/test/java/io/trino/memory/TestClusterMemoryLeakDetector.java +++ b/testing/trino-tests/src/test/java/io/trino/memory/TestClusterMemoryLeakDetector.java @@ -13,29 +13,31 @@ */ package io.trino.memory; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.airlift.units.Duration; +import io.trino.Session; +import io.trino.execution.QueryExecution; +import io.trino.execution.QueryInfo; import io.trino.execution.QueryState; -import io.trino.operator.RetryPolicy; +import io.trino.execution.StageId; +import io.trino.execution.StateMachine; +import io.trino.execution.TaskId; import io.trino.server.BasicQueryInfo; -import io.trino.server.BasicQueryStats; +import io.trino.server.ResultQueryInfo; +import io.trino.server.protocol.Slug; import io.trino.spi.QueryId; -import io.trino.spi.resourcegroups.ResourceGroupId; +import io.trino.sql.planner.Plan; +import it.unimi.dsi.fastutil.objects.Object2LongArrayMap; +import it.unimi.dsi.fastutil.objects.Object2LongMaps; import org.junit.jupiter.api.Test; -import java.net.URI; import java.time.Instant; import java.util.Optional; -import java.util.OptionalDouble; +import java.util.function.Consumer; -import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.execution.QueryState.FINISHED; import static io.trino.execution.QueryState.RUNNING; -import static io.trino.operator.BlockedReason.WAITING_FOR_MEMORY; -import static java.util.concurrent.TimeUnit.MINUTES; import static org.assertj.core.api.Assertions.assertThat; public class TestClusterMemoryLeakDetector @@ -46,77 +48,224 @@ public void testLeakDetector() QueryId testQuery = new QueryId("test"); ClusterMemoryLeakDetector leakDetector = new ClusterMemoryLeakDetector(); - leakDetector.checkForMemoryLeaks(ImmutableList::of, ImmutableMap.of()); + leakDetector.checkForMemoryLeaks(_ -> Optional.empty(), new Object2LongArrayMap<>()); assertThat(leakDetector.getNumberOfLeakedQueries()).isEqualTo(0); // the leak detector should report no leaked queries as the query is still running - leakDetector.checkForMemoryLeaks(() -> ImmutableList.of(createQueryInfo(testQuery.id(), RUNNING)), ImmutableMap.of(testQuery, 1L)); + leakDetector.checkForMemoryLeaks(queryId -> Optional.of(createQueryExecution(queryId, testQuery, RUNNING)), Object2LongMaps.singleton(testQuery, 1L)); assertThat(leakDetector.getNumberOfLeakedQueries()).isEqualTo(0); // the leak detector should report exactly one leaked query since the query is finished, and its end time is way in the past - leakDetector.checkForMemoryLeaks(() -> ImmutableList.of(createQueryInfo(testQuery.id(), FINISHED)), ImmutableMap.of(testQuery, 1L)); + leakDetector.checkForMemoryLeaks(queryId -> Optional.of(createQueryExecution(queryId, testQuery, FINISHED)), Object2LongMaps.singleton(testQuery, 1L)); assertThat(leakDetector.getNumberOfLeakedQueries()).isEqualTo(1); // the leak detector should report no leaked queries as the query doesn't have any memory reservation - leakDetector.checkForMemoryLeaks(() -> ImmutableList.of(createQueryInfo(testQuery.id(), FINISHED)), ImmutableMap.of(testQuery, 0L)); + leakDetector.checkForMemoryLeaks(queryId -> Optional.of(createQueryExecution(queryId, testQuery, FINISHED)), Object2LongMaps.singleton(testQuery, 0L)); assertThat(leakDetector.getNumberOfLeakedQueries()).isEqualTo(0); // the leak detector should report exactly one leaked query since the coordinator doesn't know of any query - leakDetector.checkForMemoryLeaks(ImmutableList::of, ImmutableMap.of(testQuery, 1L)); + leakDetector.checkForMemoryLeaks(_ -> Optional.empty(), Object2LongMaps.singleton(testQuery, 1L)); assertThat(leakDetector.getNumberOfLeakedQueries()).isEqualTo(1); } - private static BasicQueryInfo createQueryInfo(String queryId, QueryState state) + private static QueryExecution createQueryExecution(QueryId queryId, QueryId expectedQueryId, QueryState state) { - return new BasicQueryInfo( - new QueryId(queryId), - TEST_SESSION.toSessionRepresentation(), - Optional.of(new ResourceGroupId("global")), - state, - true, - URI.create("1"), - "", - Optional.empty(), - Optional.empty(), - new BasicQueryStats( - Instant.parse("2025-05-11T13:32:17.751968Z"), - Instant.parse("2025-05-11T13:32:17.751968Z"), - new Duration(8, MINUTES), - new Duration(7, MINUTES), - new Duration(9, MINUTES), - new Duration(34, MINUTES), - 99, - 13, - 14, - 15, - 100, - 0, - 22, - DataSize.valueOf("23GB"), - DataSize.valueOf("23GB"), - DataSize.valueOf("23GB"), - DataSize.valueOf("23GB"), - 24, - 25, - DataSize.valueOf("26GB"), - DataSize.valueOf("27GB"), - DataSize.valueOf("28GB"), - DataSize.valueOf("29GB"), - new Duration(30, MINUTES), - new Duration(31, MINUTES), - new Duration(32, MINUTES), - new Duration(33, MINUTES), - new Duration(34, MINUTES), - new Duration(35, MINUTES), - new Duration(36, MINUTES), - new Duration(37, MINUTES), - true, - ImmutableSet.of(WAITING_FOR_MEMORY), - OptionalDouble.of(20), - OptionalDouble.of(0)), - null, - null, - Optional.empty(), - RetryPolicy.NONE); + assertThat(queryId).isEqualTo(expectedQueryId); + + // This ensures that only expected methods are ever called by the leak detector + return new QueryExecution() + { + @Override + public QueryId getQueryId() + { + return queryId; + } + + @Override + public boolean isDone() + { + return state.isDone(); + } + + @Override + public QueryState getState() + { + return state; + } + + @Override + public Instant getCreateTime() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getEndTime() + { + return Optional.of(Instant.parse("2025-05-11T13:32:17.751968Z")); + } + + @Override + public Session getSession() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getExecutionStartTime() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getPlanningTime() + { + throw new UnsupportedOperationException(); + } + + @Override + public Instant getLastHeartbeat() + { + throw new UnsupportedOperationException(); + } + + @Override + public void fail(Throwable cause) + { + throw new UnsupportedOperationException(); + } + + @Override + public void pruneInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isInfoPruned() + { + throw new UnsupportedOperationException(); + } + + @Override + public ListenableFuture getStateChange(QueryState currentState) + { + throw new UnsupportedOperationException(); + } + + @Override + public void addStateChangeListener(StateMachine.StateChangeListener stateChangeListener) + { + throw new UnsupportedOperationException(); + } + + @Override + public void setOutputInfoListener(Consumer listener) + { + throw new UnsupportedOperationException(); + } + + @Override + public void outputTaskFailed(TaskId taskId, Throwable failure) + { + throw new UnsupportedOperationException(); + } + + @Override + public void resultsConsumed() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getQueryPlan() + { + throw new UnsupportedOperationException(); + } + + @Override + public BasicQueryInfo getBasicQueryInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public QueryInfo getQueryInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public ResultQueryInfo getResultQueryInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public Slug getSlug() + { + throw new UnsupportedOperationException(); + } + + @Override + public Duration getTotalCpuTime() + { + throw new UnsupportedOperationException(); + } + + @Override + public DataSize getUserMemoryReservation() + { + throw new UnsupportedOperationException(); + } + + @Override + public DataSize getTotalMemoryReservation() + { + throw new UnsupportedOperationException(); + } + + @Override + public void start() + { + throw new UnsupportedOperationException(); + } + + @Override + public void cancelQuery() + { + throw new UnsupportedOperationException(); + } + + @Override + public void cancelStage(StageId stageId) + { + throw new UnsupportedOperationException(); + } + + @Override + public void failTask(TaskId taskId, Exception reason) + { + throw new UnsupportedOperationException(); + } + + @Override + public void recordHeartbeat() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean shouldWaitForMinWorkers() + { + throw new UnsupportedOperationException(); + } + + @Override + public void addFinalQueryInfoListener(StateMachine.StateChangeListener stateChangeListener) + { + throw new UnsupportedOperationException(); + } + }; } }