diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b49298192800..666def90ce672 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added - Fix for hasInitiatedFetching to fix allocation explain and manual reroute APIs (([#14972](https://github.com/opensearch-project/OpenSearch/pull/14972)) +- [Workload Management] QueryGroup resource tracking framework changes ([#13897](https://github.com/opensearch-project/OpenSearch/pull/13897)) - [Workload Management] Add queryGroupId to Task ([14708](https://github.com/opensearch-project/OpenSearch/pull/14708)) - Add setting to ignore throttling nodes for allocation of unassigned primaries in remote restore ([#14991](https://github.com/opensearch-project/OpenSearch/pull/14991)) - Add basic aggregation support for derived fields ([#14618](https://github.com/opensearch-project/OpenSearch/pull/14618)) diff --git a/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java b/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java index 440b9e267cf0a..09bef2ddf9ee6 100644 --- a/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java @@ -1391,7 +1391,7 @@ public Builder put(final QueryGroup queryGroup) { return queryGroups(existing); } - private Map getQueryGroups() { + public Map getQueryGroups() { return Optional.ofNullable(this.customs.get(QueryGroupMetadata.TYPE)) .map(o -> (QueryGroupMetadata) o) .map(QueryGroupMetadata::queryGroups) diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupHelper.java b/server/src/main/java/org/opensearch/wlm/QueryGroupHelper.java new file mode 100644 index 0000000000000..13b87b5010bf3 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupHelper.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require 
contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; + +import java.util.Map; +import java.util.function.Function; + +/** + * Helper class for calculating resource usage for different resource types. + */ +public class QueryGroupHelper { + + /** + * A map that associates each {@link ResourceType} with a function that calculates the resource usage for a given {@link Task}. + */ + private static final Map> resourceUsageCalculator = Map.of( + ResourceType.MEMORY, + (task) -> task.getTotalResourceStats().getMemoryInBytes(), + ResourceType.CPU, + (task) -> task.getTotalResourceStats().getCpuTimeInNanos() + ); + + /** + * Gets the resource usage for a given resource type and task. + * + * @param resource the resource type + * @param task the task for which to calculate resource usage + * @return the resource usage + */ + public static long getResourceUsage(ResourceType resource, Task task) { + return resourceUsageCalculator.get(resource).apply(task); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java b/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java new file mode 100644 index 0000000000000..2fd743dc3f83f --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; + +import java.util.List; +import java.util.Map; + +/** + * Represents the point in time view of resource usage of a QueryGroup and + * has a 1:1 relation with a QueryGroup. 
+ * This class holds the resource usage data and the list of active tasks. + */ +public class QueryGroupLevelResourceUsageView { + // resourceUsage holds the resource usage data for a QueryGroup at a point in time + private final Map resourceUsage; + // activeTasks holds the list of active tasks for a QueryGroup at a point in time + private final List activeTasks; + + public QueryGroupLevelResourceUsageView(Map resourceUsage, List activeTasks) { + this.resourceUsage = resourceUsage; + this.activeTasks = activeTasks; + } + + /** + * Returns the resource usage data. + * + * @return The map of resource usage data + */ + public Map getResourceUsageData() { + return resourceUsage; + } + + /** + * Returns the list of active tasks. + * + * @return The list of active tasks + */ + public List getActiveTasks() { + return activeTasks; + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/AbstractTaskSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/AbstractTaskSelectionStrategy.java new file mode 100644 index 0000000000000..72966852a67ae --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/AbstractTaskSelectionStrategy.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.wlm.QueryGroupHelper; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Represents an abstract task selection strategy. 
+ * This class implements the TaskSelectionStrategy interface and provides a method to select tasks for cancellation based on a sorting condition. + * The specific sorting condition depends on the implementation. + */ +public abstract class AbstractTaskSelectionStrategy implements TaskSelectionStrategy { + + /** + * Returns a comparator that defines the sorting condition for tasks. + * The specific sorting condition depends on the implementation. + * + * @return The comparator + */ + public abstract Comparator sortingCondition(); + + /** + * Selects tasks for cancellation based on the provided limit and resource type. + * The tasks are sorted based on the sorting condition and then selected until the accumulated resource usage reaches the limit. + * + * @param tasks The list of tasks from which to select + * @param limit The limit on the accumulated resource usage + * @param resourceType The type of resource to consider + * @return The list of selected tasks + * @throws IllegalArgumentException If the limit is less than zero + */ + @Override + public List selectTasksForCancellation(List tasks, long limit, ResourceType resourceType) { + if (limit < 0) { + throw new IllegalArgumentException("reduceBy has to be greater than zero"); + } + if (limit == 0) { + return Collections.emptyList(); + } + + List sortedTasks = tasks.stream().sorted(sortingCondition()).collect(Collectors.toList()); + + List selectedTasks = new ArrayList<>(); + long accumulated = 0; + + for (Task task : sortedTasks) { + if (task instanceof CancellableTask) { + selectedTasks.add(createTaskCancellation((CancellableTask) task)); + accumulated += QueryGroupHelper.getResourceUsage(resourceType, task); + if (accumulated >= limit) { + break; + } + } + } + return selectedTasks; + } + + private TaskCancellation createTaskCancellation(CancellableTask task) { + // TODO add correct reason and callbacks + return new TaskCancellation(task, List.of(new TaskCancellation.Reason("limits exceeded", 5)), 
List.of(this::callbackOnCancel)); + } + + private void callbackOnCancel() { + // todo Implement callback logic here mostly used for Stats + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java b/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java new file mode 100644 index 0000000000000..d932d21e4affe --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java @@ -0,0 +1,218 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.cluster.metadata.QueryGroup; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.monitor.process.ProcessProbe; +import org.opensearch.search.ResourceType; +import org.opensearch.search.backpressure.settings.NodeDuressSettings; +import org.opensearch.search.backpressure.trackers.NodeDuressTrackers; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; + +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService.TRACKED_RESOURCES; + +/** + * Manages the cancellation of tasks enforced by QueryGroup thresholds on resource usage criteria. + * This class utilizes a strategy pattern through {@link TaskSelectionStrategy} to identify tasks that exceed + * predefined resource usage limits and are therefore eligible for cancellation. + * + *

The cancellation process is initiated by evaluating the resource usage of each QueryGroup against its + * resource limits. Tasks that contribute to exceeding these limits are selected for cancellation based on the + * implemented task selection strategy.

+ * + *

Instances of this class are configured with a map linking QueryGroup IDs to their corresponding resource usage + * views, a set of active QueryGroups, and a task selection strategy. These components collectively facilitate the + * identification and cancellation of tasks that threaten to breach QueryGroup resource limits.

+ * + * @see TaskSelectionStrategy + * @see QueryGroup + * @see ResourceType + */ +public class DefaultTaskCancellation { + private static final long HEAP_SIZE_BYTES = JvmStats.jvmStats().getMem().getHeapMax().getBytes(); + + protected final TaskSelectionStrategy taskSelectionStrategy; + // a map of QueryGroupId to its corresponding QueryGroupLevelResourceUsageView object + protected final Map queryGroupLevelResourceUsageViews; + protected final Set activeQueryGroups; + protected NodeDuressTrackers nodeDuressTrackers; + + public DefaultTaskCancellation( + TaskSelectionStrategy taskSelectionStrategy, + Map queryGroupLevelResourceUsageViews, + Set activeQueryGroups, + Settings settings, + ClusterSettings clusterSettings + ) { + this.taskSelectionStrategy = taskSelectionStrategy; + this.queryGroupLevelResourceUsageViews = queryGroupLevelResourceUsageViews; + this.activeQueryGroups = activeQueryGroups; + this.nodeDuressTrackers = setupNodeDuressTracker(settings, clusterSettings); + } + + /** + * Cancel tasks based on the implemented strategy. + */ + public final void cancelTasks() { + cancelTasksForMode(QueryGroup.ResiliencyMode.ENFORCED); + + if (nodeDuressTrackers.isNodeInDuress()) { + cancelTasksForMode(QueryGroup.ResiliencyMode.SOFT); + } + } + + private void cancelTasksForMode(QueryGroup.ResiliencyMode resiliencyMode) { + List cancellableTasks = getAllCancellableTasksFrom(resiliencyMode); + for (TaskCancellation taskCancellation : cancellableTasks) { + taskCancellation.cancel(); + } + } + + /** + * Get all cancellable tasks from the QueryGroups. + * + * @return List of tasks that can be cancelled + */ + protected List getAllCancellableTasksFrom(QueryGroup.ResiliencyMode resiliencyMode) { + return getQueryGroupsToCancelFrom(resiliencyMode).stream() + .flatMap(queryGroup -> getCancellableTasksFrom(queryGroup).stream()) + .collect(Collectors.toList()); + } + + /** + * returns the list of QueryGroups breaching their resource limits. 
+ * + * @param resiliencyMode only QueryGroups configured with this resiliency mode are considered + * @return List of QueryGroups + */ + private List getQueryGroupsToCancelFrom(QueryGroup.ResiliencyMode resiliencyMode) { + final List queryGroupsToCancelFrom = new ArrayList<>(); + + for (QueryGroup queryGroup : this.activeQueryGroups) { + if (queryGroup.getResiliencyMode() != resiliencyMode) { + continue; + } + /* an active QueryGroup may have no usage view recorded; skip it instead of dereferencing null, mirroring the containsKey guard in getResourceUsage */ + if (!queryGroupLevelResourceUsageViews.containsKey(queryGroup.get_id())) { + continue; + } + Map queryGroupResourceUsage = queryGroupLevelResourceUsageViews.get(queryGroup.get_id()) + .getResourceUsageData(); + + for (ResourceType resourceType : TRACKED_RESOURCES) { + if (queryGroup.getResourceLimits().containsKey(resourceType) && queryGroupResourceUsage.containsKey(resourceType)) { + Double resourceLimit = (Double) queryGroup.getResourceLimits().get(resourceType); + Long resourceUsage = queryGroupResourceUsage.get(resourceType); + + /* one breached resource is enough to mark the whole group, hence the break */ + if (isBreachingThreshold(resourceType, resourceLimit, resourceUsage)) { + queryGroupsToCancelFrom.add(queryGroup); + break; + } + } + } + } + + return queryGroupsToCancelFrom; + } + + /** + * Get cancellable tasks from a specific queryGroup. 
+ * + * @param queryGroup The QueryGroup from which to get cancellable tasks + * @return List of tasks that can be cancelled + */ + protected List getCancellableTasksFrom(QueryGroup queryGroup) { + return TRACKED_RESOURCES.stream() + .filter(resourceType -> shouldCancelTasks(queryGroup, resourceType)) + .flatMap(resourceType -> getTaskCancellations(queryGroup, resourceType).stream()) + .collect(Collectors.toList()); + } + + private boolean shouldCancelTasks(QueryGroup queryGroup, ResourceType resourceType) { + long reduceBy = getReduceBy(queryGroup, resourceType); + return reduceBy > 0; + } + + private List getTaskCancellations(QueryGroup queryGroup, ResourceType resourceType) { + return taskSelectionStrategy.selectTasksForCancellation( + // get the active tasks in the query group + queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks(), + getReduceBy(queryGroup, resourceType), + resourceType + ); + } + + private long getReduceBy(QueryGroup queryGroup, ResourceType resourceType) { + if (queryGroup.getResourceLimits().get(resourceType) == null) { + return 0; + } + Double threshold = (Double) queryGroup.getResourceLimits().get(resourceType); + return getResourceUsage(queryGroup, resourceType) - convertThresholdIntoLong(resourceType, threshold); + } + + private Long convertThresholdIntoLong(ResourceType resourceType, Double resourceThresholdInPercentage) { + Long threshold = null; + if (resourceType == ResourceType.MEMORY) { + // Absolute limit: the configured fraction multiplied by HEAP_SIZE_BYTES + threshold = (long) (resourceThresholdInPercentage * HEAP_SIZE_BYTES); + } else if (resourceType == ResourceType.CPU) { + // Get the total CPU time of the process in milliseconds + long cpuTotalTimeInMillis = ProcessProbe.getInstance().getProcessCpuTotalTime(); + // Absolute limit: the configured fraction multiplied by the CPU total time read above. NOTE(review): getReduceBy subtracts this from the stored usage without the division by 1_000_000 that isBreachingThreshold applies for CPU; confirm the intended units + threshold = (long) (resourceThresholdInPercentage * cpuTotalTimeInMillis); + } + return threshold; + } + + private Long getResourceUsage(QueryGroup 
queryGroup, ResourceType resourceType) { + if (!queryGroupLevelResourceUsageViews.containsKey(queryGroup.get_id())) { + return 0L; + } + return queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getResourceUsageData().get(resourceType); + } + + private boolean isBreachingThreshold(ResourceType resourceType, Double resourceThresholdInPercentage, long resourceUsage) { + if (resourceType == ResourceType.MEMORY) { + // Check if resource usage is breaching the threshold + return resourceUsage > convertThresholdIntoLong(resourceType, resourceThresholdInPercentage); + } + // Resource types should be CPU, resourceUsage is in nanoseconds, convert to milliseconds + long resourceUsageInMillis = resourceUsage / 1_000_000; + // Check if resource usage is breaching the threshold + return resourceUsageInMillis > convertThresholdIntoLong(resourceType, resourceThresholdInPercentage); + } + + private NodeDuressTrackers setupNodeDuressTracker(Settings settings, ClusterSettings clusterSettings) { + NodeDuressSettings nodeDuressSettings = new NodeDuressSettings(settings, clusterSettings); + return new NodeDuressTrackers(new EnumMap<>(ResourceType.class) { + { + put( + ResourceType.CPU, + new NodeDuressTrackers.NodeDuressTracker( + () -> ProcessProbe.getInstance().getProcessCpuPercent() / 100.0 >= nodeDuressSettings.getCpuThreshold(), + nodeDuressSettings::getNumSuccessiveBreaches + ) + ); + put( + ResourceType.MEMORY, + new NodeDuressTrackers.NodeDuressTracker( + () -> JvmStats.jvmStats().getMem().getHeapUsedPercent() / 100.0 >= nodeDuressSettings.getHeapThreshold(), + nodeDuressSettings::getNumSuccessiveBreaches + ) + ); + } + }); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstSelectionStrategy.java new file mode 100644 index 0000000000000..d36d55b25bb4a --- /dev/null +++ 
b/server/src/main/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstSelectionStrategy.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.tasks.Task; + +import java.util.Comparator; + +/** + * Represents a task selection strategy that prioritizes the longest running tasks first. + */ +public class LongestRunningTaskFirstSelectionStrategy extends AbstractTaskSelectionStrategy { + + /** + * Returns a comparator that sorts tasks based on their start time in ascending order (smallest start time first). + * + * @return The comparator + */ + @Override + public Comparator sortingCondition() { + return Comparator.comparingLong(Task::getStartTime); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstSelectionStrategy.java new file mode 100644 index 0000000000000..1e8e75b291d05 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstSelectionStrategy.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.tasks.Task; + +import java.util.Comparator; + +/** + * Represents a task selection strategy that prioritizes the shortest running tasks first. + */ +public class ShortestRunningTaskFirstSelectionStrategy extends AbstractTaskSelectionStrategy { + + /** + * Returns a comparator that sorts tasks based on their start time in descending order (largest start time first). 
+ * + * @return The comparator + */ + @Override + public Comparator sortingCondition() { + return Comparator.comparingLong(Task::getStartTime).reversed(); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java new file mode 100644 index 0000000000000..72161671186f2 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; + +import java.util.List; + +/** + * Interface for strategies to select tasks for cancellation. + * Implementations of this interface define how tasks are selected for cancellation based on resource usage. + */ +public interface TaskSelectionStrategy { + /** + * Determines which tasks should be cancelled based on the provided criteria. + * + * @param tasks List of tasks available for cancellation. + * @param limit The amount of tasks to select whose resources reach this limit + * @param resourceType The type of resource that needs to be reduced, guiding the selection process. + * + * @return List of tasks that should be cancelled. 
+ */ + List selectTasksForCancellation(List tasks, long limit, ResourceType resourceType); +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java b/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java new file mode 100644 index 0000000000000..9618d22c9d5e2 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * QueryGroup resource cancellation artifacts + */ +package org.opensearch.wlm.cancellation; diff --git a/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java b/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java new file mode 100644 index 0000000000000..94aafb9fea5f5 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.wlm.tracker; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.wlm.QueryGroupHelper; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; +import org.opensearch.wlm.QueryGroupTask; + +import java.util.EnumMap; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * This class tracks resource usage per QueryGroup + */ +public class QueryGroupResourceUsageTrackerService implements QueryGroupUsageTracker, TaskManager.TaskEventListeners { + + public static final List TRACKED_RESOURCES = List.of(ResourceType.values()); + private final TaskManager taskManager; + private final TaskResourceTrackingService taskResourceTrackingService; + + /** + * QueryGroupResourceUsageTrackerService constructor + * + * @param taskManager Task Manager service for keeping track of currently running tasks on the nodes + * @param taskResourceTrackingService Service that helps track resource usage of tasks running on a node. + */ + public QueryGroupResourceUsageTrackerService( + final TaskManager taskManager, + final TaskResourceTrackingService taskResourceTrackingService + ) { + this.taskManager = taskManager; + this.taskResourceTrackingService = taskResourceTrackingService; + } + + /** + * Constructs a map of QueryGroupLevelResourceUsageView instances for each QueryGroup. 
+ * + * @return Map of QueryGroup views + */ + @Override + public Map constructQueryGroupLevelUsageViews() { + final Map> tasksByQueryGroup = getTasksGroupedByQueryGroup(); + final Map queryGroupViews = new HashMap<>(); + + // Iterate over each QueryGroup entry + for (Map.Entry> queryGroupEntry : tasksByQueryGroup.entrySet()) { + // Compute the QueryGroup usage + final EnumMap queryGroupUsage = new EnumMap<>(ResourceType.class); + for (ResourceType resourceType : TRACKED_RESOURCES) { + long queryGroupResourceUsage = 0; + for (Task task : queryGroupEntry.getValue()) { + queryGroupResourceUsage += QueryGroupHelper.getResourceUsage(resourceType, task); + } + queryGroupUsage.put(resourceType, queryGroupResourceUsage); + } + + // Add to the QueryGroup View + queryGroupViews.put( + queryGroupEntry.getKey(), + new QueryGroupLevelResourceUsageView(queryGroupUsage, queryGroupEntry.getValue()) + ); + } + return queryGroupViews; + } + + /** + * Groups tasks by their associated QueryGroup. + * + * @return Map of tasks grouped by QueryGroup + */ + private Map> getTasksGroupedByQueryGroup() { + return taskResourceTrackingService.getResourceAwareTasks() + .values() + .stream() + .filter(QueryGroupTask.class::isInstance) + .map(QueryGroupTask.class::cast) + .collect(Collectors.groupingBy(QueryGroupTask::getQueryGroupId, Collectors.mapping(task -> (Task) task, Collectors.toList()))); + } + + /** + * Handles the completion of a task. 
+ * + * @param task The completed task + */ + @Override + public void onTaskCompleted(Task task) {} +} diff --git a/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupUsageTracker.java b/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupUsageTracker.java new file mode 100644 index 0000000000000..23fb8b1b45aac --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupUsageTracker.java @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.tracker; + +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; + +import java.util.Map; + +/** + * This interface is mainly for tracking the queryGroup level resource usages + */ +public interface QueryGroupUsageTracker { + /** + * Constructs the QueryGroup level resource usage views and returns them as a map + */ + + Map constructQueryGroupLevelUsageViews(); +} diff --git a/server/src/main/java/org/opensearch/wlm/tracker/package-info.java b/server/src/main/java/org/opensearch/wlm/tracker/package-info.java new file mode 100644 index 0000000000000..86efc99355d3d --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/tracker/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +/** + * QueryGroup resource tracking artifacts + */ +package org.opensearch.wlm.tracker; diff --git a/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java b/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java new file mode 100644 index 0000000000000..c965957516bad --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; + +import static org.opensearch.wlm.QueryGroupTestHelpers.getRandomTask; + +public class QueryGroupLevelResourceUsageViewTests extends OpenSearchTestCase { + Map resourceUsage; + List activeTasks; + + public void setUp() throws Exception { + super.setUp(); + resourceUsage = Map.of(ResourceType.fromName("memory"), 34L, ResourceType.fromName("cpu"), 12L); + activeTasks = List.of(getRandomTask(4321)); + } + + public void testGetResourceUsageData() { + QueryGroupLevelResourceUsageView queryGroupLevelResourceUsageView = new QueryGroupLevelResourceUsageView( + resourceUsage, + activeTasks + ); + Map resourceUsageData = queryGroupLevelResourceUsageView.getResourceUsageData(); + assertTrue(assertResourceUsageData(resourceUsageData)); + } + + public void testGetActiveTasks() { + QueryGroupLevelResourceUsageView queryGroupLevelResourceUsageView = new QueryGroupLevelResourceUsageView( + resourceUsage, + activeTasks + ); + List activeTasks = queryGroupLevelResourceUsageView.getActiveTasks(); + assertEquals(1, activeTasks.size()); + assertEquals(4321, activeTasks.get(0).getId()); + } + + private boolean 
assertResourceUsageData(Map resourceUsageData) { + return resourceUsageData.get(ResourceType.fromName("memory")) == 34L && resourceUsageData.get(ResourceType.fromName("cpu")) == 12L; + } +} diff --git a/server/src/test/java/org/opensearch/wlm/QueryGroupTestHelpers.java b/server/src/test/java/org/opensearch/wlm/QueryGroupTestHelpers.java new file mode 100644 index 0000000000000..b3fda77c662b2 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/QueryGroupTestHelpers.java @@ -0,0 +1,31 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.action.search.SearchAction; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.tasks.Task; + +import java.util.Collections; + +import static org.opensearch.test.OpenSearchTestCase.randomLong; + +public class QueryGroupTestHelpers { + + public static Task getRandomTask(long id) { + return new Task( + id, + "transport", + SearchAction.NAME, + "test description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap() + ); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java new file mode 100644 index 0000000000000..0c8f186ed425b --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java @@ -0,0 +1,340 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.action.search.SearchAction; +import org.opensearch.action.search.SearchTask; +import org.opensearch.cluster.metadata.QueryGroup; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.search.ResourceType; +import org.opensearch.search.backpressure.trackers.NodeDuressTrackers; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; +import org.junit.Before; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class DefaultTaskCancellationTests extends OpenSearchTestCase { + private static final String queryGroupId1 = "queryGroup1"; + private static final String queryGroupId2 = "queryGroup2"; + + private static class TestTaskCancellationImpl extends DefaultTaskCancellation { + + public TestTaskCancellationImpl( + TaskSelectionStrategy taskSelectionStrategy, + Map queryGroupLevelViews, + Set activeQueryGroups + ) { + super( + taskSelectionStrategy, + queryGroupLevelViews, + activeQueryGroups, + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + } + } + + private Map queryGroupLevelViews; + private Set activeQueryGroups; + private DefaultTaskCancellation taskCancellation; + + @Before + public void setup() { + queryGroupLevelViews = new HashMap<>(); + activeQueryGroups = new HashSet<>(); + taskCancellation = new TestTaskCancellationImpl( + new TaskSelectionStrategyTests.TestTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups + ); + } + + public void 
testGetCancellableTasksFrom_returnsTasksWhenBreachingThreshold() { + ResourceType resourceType = ResourceType.CPU; + long usage = 100_000_000L; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + + List cancellableTasksFrom = taskCancellation.getCancellableTasksFrom(queryGroup1); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_returnsTasksWhenBreachingThresholdForMemory() { + ResourceType resourceType = ResourceType.MEMORY; + long usage = 900_000_000_000L; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_returnsNoTasksWhenNotBreachingThreshold() { + ResourceType resourceType = ResourceType.CPU; + long usage = 500L; + Double threshold = 0.9; + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + 
QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + List cancellableTasksFrom = taskCancellation.getCancellableTasksFrom(queryGroup1); + assertTrue(cancellableTasksFrom.isEmpty()); + } + + public void testGetCancellableTasksFrom_filtersQueryGroupCorrectly() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + TestTaskCancellationImpl taskCancellation = new TestTaskCancellationImpl( + new TaskSelectionStrategyTests.TestTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups + ); + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.SOFT); + assertEquals(0, cancellableTasksFrom.size()); + } + + public void testCancelTasks_cancelsGivenTasks() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + TestTaskCancellationImpl taskCancellation = new TestTaskCancellationImpl( + new TaskSelectionStrategyTests.TestTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups + ); + + List cancellableTasksFrom = 
taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + + taskCancellation.cancelTasks(); + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + } + + public void testCancelTasks_cancelsGivenTasks_WhenNodeInDuress() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroup queryGroup2 = new QueryGroup( + "testQueryGroup", + queryGroupId2, + QueryGroup.ResiliencyMode.SOFT, + Map.of(resourceType, threshold), + 1L + ); + + queryGroupLevelViews.put(queryGroupId1, createResourceUsageViewMock(resourceType, usage)); + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + when(mockView.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(5678), getRandomSearchTask(8765))); + queryGroupLevelViews.put(queryGroupId2, mockView); + Collections.addAll(activeQueryGroups, queryGroup1, queryGroup2); + + TestTaskCancellationImpl taskCancellation = new TestTaskCancellationImpl( + new TaskSelectionStrategyTests.TestTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups + ); + + NodeDuressTrackers mock = mock(NodeDuressTrackers.class); + when(mock.isNodeInDuress()).thenReturn(true); + taskCancellation.nodeDuressTrackers = mock; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, 
cancellableTasksFrom.get(1).getTask().getId()); + + List cancellableTasksFrom1 = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.SOFT); + assertEquals(2, cancellableTasksFrom1.size()); + assertEquals(5678, cancellableTasksFrom1.get(0).getTask().getId()); + assertEquals(8765, cancellableTasksFrom1.get(1).getTask().getId()); + + taskCancellation.cancelTasks(); + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + assertTrue(cancellableTasksFrom1.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom1.get(1).getTask().isCancelled()); + } + + public void testGetAllCancellableTasks_ReturnsNoTasksFromWhenNotBreachingThresholds() { + ResourceType resourceType = ResourceType.CPU; + long usage = 1L; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + List allCancellableTasks = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertTrue(allCancellableTasks.isEmpty()); + } + + public void testGetAllCancellableTasks_ReturnsTasksFromWhenBreachingThresholds() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + List allCancellableTasks = 
taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, allCancellableTasks.size()); + assertEquals(1234, allCancellableTasks.get(0).getTask().getId()); + assertEquals(4321, allCancellableTasks.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_doesNotReturnTasksWhenQueryGroupIdNotFound() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + QueryGroup queryGroup2 = new QueryGroup( + "testQueryGroup", + queryGroupId2, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + activeQueryGroups.add(queryGroup2); + + List cancellableTasksFrom = taskCancellation.getCancellableTasksFrom(queryGroup2); + assertEquals(0, cancellableTasksFrom.size()); + } + + private QueryGroupLevelResourceUsageView createResourceUsageViewMock(ResourceType resourceType, Long usage) { + QueryGroupLevelResourceUsageView mockView = mock(QueryGroupLevelResourceUsageView.class); + when(mockView.getResourceUsageData()).thenReturn(Collections.singletonMap(resourceType, usage)); + when(mockView.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(1234), getRandomSearchTask(4321))); + return mockView; + } + + private Task getRandomSearchTask(long id) { + return new SearchTask( + id, + "transport", + SearchAction.NAME, + () -> "test description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap() + ); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstStrategySelectionStrategyTests.java 
b/server/src/test/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstStrategySelectionStrategyTests.java new file mode 100644 index 0000000000000..ad76a5021b175 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstStrategySelectionStrategyTests.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Arrays; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class LongestRunningTaskFirstStrategySelectionStrategyTests extends OpenSearchTestCase { + public void testSortingCondition() { + Task task1 = mock(Task.class); + Task task2 = mock(Task.class); + Task task3 = mock(Task.class); + when(task1.getStartTime()).thenReturn(100L); + when(task2.getStartTime()).thenReturn(200L); + when(task3.getStartTime()).thenReturn(300L); + + List tasks = Arrays.asList(task2, task1, task3); + tasks.sort(new LongestRunningTaskFirstSelectionStrategy().sortingCondition()); + + assertEquals(Arrays.asList(task1, task2, task3), tasks); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstStrategySelectionStrategyTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstStrategySelectionStrategyTests.java new file mode 100644 index 0000000000000..3c07df09f6f5e --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstStrategySelectionStrategyTests.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * 
compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Arrays; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ShortestRunningTaskFirstStrategySelectionStrategyTests extends OpenSearchTestCase { + public void testSortingCondition() { + Task task1 = mock(Task.class); + Task task2 = mock(Task.class); + Task task3 = mock(Task.class); + when(task1.getStartTime()).thenReturn(100L); + when(task2.getStartTime()).thenReturn(200L); + when(task3.getStartTime()).thenReturn(300L); + + List tasks = Arrays.asList(task1, task3, task2); + tasks.sort(new ShortestRunningTaskFirstSelectionStrategy().sortingCondition()); + + assertEquals(Arrays.asList(task3, task2, task1), tasks); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/TaskSelectionStrategyTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/TaskSelectionStrategyTests.java new file mode 100644 index 0000000000000..43ccbd0920068 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/TaskSelectionStrategyTests.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.action.search.SearchAction; +import org.opensearch.action.search.SearchTask; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; +import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +public class TaskSelectionStrategyTests extends OpenSearchTestCase { + + public static class TestTaskSelectionStrategy extends AbstractTaskSelectionStrategy { + @Override + public Comparator sortingCondition() { + return Comparator.comparingLong(Task::getId); + } + } + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsGreaterThanZero() { + TaskSelectionStrategy testTaskSelectionStrategy = new TestTaskSelectionStrategy(); + long threshold = 100L; + long reduceBy = 50L; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(threshold); + + List selectedTasks = testTaskSelectionStrategy.selectTasksForCancellation(tasks, reduceBy, resourceType); + assertFalse(selectedTasks.isEmpty()); + assertTrue(tasksUsageMeetsThreshold(selectedTasks, reduceBy)); + } + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsLesserThanZero() { + TaskSelectionStrategy testTaskSelectionStrategy = new TestTaskSelectionStrategy(); + long threshold = 100L; + long reduceBy = -50L; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(threshold); + + try { + testTaskSelectionStrategy.selectTasksForCancellation(tasks, reduceBy, resourceType); + } catch (Exception e) { + assertTrue(e instanceof 
IllegalArgumentException); + assertEquals("reduceBy has to be greater than zero", e.getMessage()); + } + } + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsEqualToZero() { + TaskSelectionStrategy testTaskSelectionStrategy = new TestTaskSelectionStrategy(); + long threshold = 100L; + long reduceBy = 0; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(threshold); + + List selectedTasks = testTaskSelectionStrategy.selectTasksForCancellation(tasks, reduceBy, resourceType); + assertTrue(selectedTasks.isEmpty()); + } + + private boolean tasksUsageMeetsThreshold(List selectedTasks, long threshold) { + long memory = 0; + for (TaskCancellation task : selectedTasks) { + memory += task.getTask().getTotalResourceUtilization(ResourceStats.MEMORY); + if (memory > threshold) { + return true; + } + } + return false; + } + + private List getListOfTasks(long totalMemory) { + List tasks = new ArrayList<>(); + + while (totalMemory > 0) { + long id = randomLong(); + final Task task = getRandomSearchTask(id); + long initial_memory = randomLongBetween(1, 100); + + ResourceUsageMetric[] initialTaskResourceMetrics = new ResourceUsageMetric[] { + new ResourceUsageMetric(ResourceStats.MEMORY, initial_memory) }; + task.startThreadResourceTracking(id, ResourceStatsType.WORKER_STATS, initialTaskResourceMetrics); + + long memory = initial_memory + randomLongBetween(1, 10000); + + totalMemory -= memory - initial_memory; + + ResourceUsageMetric[] taskResourceMetrics = new ResourceUsageMetric[] { + new ResourceUsageMetric(ResourceStats.MEMORY, memory), }; + task.updateThreadResourceStats(id, ResourceStatsType.WORKER_STATS, taskResourceMetrics); + task.stopThreadResourceTracking(id, ResourceStatsType.WORKER_STATS); + tasks.add(task); + } + + return tasks; + } + + private Task getRandomSearchTask(long id) { + return new SearchTask( + id, + "transport", + SearchAction.NAME, + () -> "test description", + new TaskId(randomLong() + ":" + 
randomLong()), + Collections.emptyMap() + ); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/tracking/QueryGroupResourceUsageTrackerServiceTests.java b/server/src/test/java/org/opensearch/wlm/tracking/QueryGroupResourceUsageTrackerServiceTests.java new file mode 100644 index 0000000000000..1c1893c4cef37 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/tracking/QueryGroupResourceUsageTrackerServiceTests.java @@ -0,0 +1,131 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.tracking; + +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.action.search.SearchTask; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; +import org.opensearch.wlm.QueryGroupTask; +import org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService; +import org.junit.After; +import org.junit.Before; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.opensearch.wlm.QueryGroupTask.QUERY_GROUP_ID_HEADER; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; 
+import static org.mockito.Mockito.when; + +public class QueryGroupResourceUsageTrackerServiceTests extends OpenSearchTestCase { + TestThreadPool threadPool; + TaskManager taskManager; + TaskResourceTrackingService mockTaskResourceTrackingService; + QueryGroupResourceUsageTrackerService queryGroupResourceUsageTrackerService; + + @Before + public void setup() { + threadPool = new TestThreadPool(getTestName()); + taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); + mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class); + queryGroupResourceUsageTrackerService = new QueryGroupResourceUsageTrackerService(taskManager, mockTaskResourceTrackingService); + } + + @After + public void cleanup() { + ThreadPool.terminate(threadPool, 5, TimeUnit.SECONDS); + } + + public void testConstructQueryGroupLevelViews_CreatesQueryGroupLevelUsageView_WhenTasksArePresent() { + List queryGroupIds = List.of("queryGroup1", "queryGroup2", "queryGroup3"); + + Map activeSearchShardTasks = createActiveSearchShardTasks(queryGroupIds); + when(mockTaskResourceTrackingService.getResourceAwareTasks()).thenReturn(activeSearchShardTasks); + Map stringQueryGroupLevelResourceUsageViewMap = queryGroupResourceUsageTrackerService + .constructQueryGroupLevelUsageViews(); + + for (String queryGroupId : queryGroupIds) { + assertEquals( + 400, + (long) stringQueryGroupLevelResourceUsageViewMap.get(queryGroupId).getResourceUsageData().get(ResourceType.MEMORY) + ); + assertEquals(2, stringQueryGroupLevelResourceUsageViewMap.get(queryGroupId).getActiveTasks().size()); + } + } + + public void testConstructQueryGroupLevelViews_CreatesQueryGroupLevelUsageView_WhenTasksAreNotPresent() { + Map stringQueryGroupLevelResourceUsageViewMap = queryGroupResourceUsageTrackerService + .constructQueryGroupLevelUsageViews(); + assertTrue(stringQueryGroupLevelResourceUsageViewMap.isEmpty()); + } + + public void 
testConstructQueryGroupLevelUsageViews_WithTasksHavingDifferentResourceUsage() { + Map activeSearchShardTasks = new HashMap<>(); + activeSearchShardTasks.put(1L, createMockTask(SearchShardTask.class, 100, 200, "queryGroup1")); + activeSearchShardTasks.put(2L, createMockTask(SearchShardTask.class, 200, 400, "queryGroup1")); + when(mockTaskResourceTrackingService.getResourceAwareTasks()).thenReturn(activeSearchShardTasks); + + Map queryGroupViews = queryGroupResourceUsageTrackerService + .constructQueryGroupLevelUsageViews(); + + assertEquals(600, (long) queryGroupViews.get("queryGroup1").getResourceUsageData().get(ResourceType.MEMORY)); + assertEquals(2, queryGroupViews.get("queryGroup1").getActiveTasks().size()); + } + + private Map createActiveSearchShardTasks(List queryGroupIds) { + Map activeSearchShardTasks = new HashMap<>(); + long task_id = 0; + for (String queryGroupId : queryGroupIds) { + for (int i = 0; i < 2; i++) { + activeSearchShardTasks.put(++task_id, createMockTask(SearchShardTask.class, 100, 200, queryGroupId)); + } + } + return activeSearchShardTasks; + } + + private T createMockTask(Class type, long cpuUsage, long heapUsage, String queryGroupId) { + T task = mock(type); + if (task instanceof SearchTask || task instanceof SearchShardTask) { + // Stash the current thread context to ensure that any existing context is preserved and restored after setting the query group + // ID. 
+ try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putHeader(QUERY_GROUP_ID_HEADER, queryGroupId); + ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); + } + } + when(task.getTotalResourceStats()).thenReturn(new TaskResourceUsage(cpuUsage, heapUsage)); + when(task.getStartTimeNanos()).thenReturn((long) 0); + + AtomicBoolean isCancelled = new AtomicBoolean(false); + doAnswer(invocation -> { + isCancelled.set(true); + return null; + }).when(task).cancel(anyString()); + doAnswer(invocation -> isCancelled.get()).when(task).isCancelled(); + + return task; + } +}