Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new configs to support inter AZ stage-to-stage routing #755

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2025 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.mantisrx.common.util;

/**
 * Abstraction for looking up an availability zone identifier.
 *
 * <p>NOTE(review): presumably implementations report the zone of the host the
 * current process runs on; confirm against the concrete implementations.
 */
public interface AvailabilityZoneUtils {
/**
 * Returns the availability zone as a string.
 *
 * @return the availability zone identifier
 */
String getAvailabilityZone();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2025 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.mantisrx.common.util;

import java.util.Properties;

/**
 * Default {@link AvailabilityZoneUtils} implementation that always reports an
 * empty availability zone.
 */
public class DefaultAvailabilityZoneUtils implements AvailabilityZoneUtils {

    /** Value reported by {@link #getAvailabilityZone()}. */
    private static final String EMPTY_ZONE = "";

    /**
     * Static factory that builds a new instance. The supplied properties are
     * accepted but not read.
     *
     * @param properties configuration properties (unused)
     * @return a fresh {@code DefaultAvailabilityZoneUtils}
     */
    public static DefaultAvailabilityZoneUtils valueOf(Properties properties) {
        final DefaultAvailabilityZoneUtils instance = new DefaultAvailabilityZoneUtils();
        return instance;
    }

    /**
     * Returns the empty string.
     *
     * @return {@code ""}
     */
    @Override
    public String getAvailabilityZone() {
        return EMPTY_ZONE;
    }
}
Original file line number Diff line number Diff line change
@@ -87,6 +87,7 @@ public Observable<Void> call(
// sample state
boolean enableSampling = false;
long samplingTimeMsec = 0;
String availabilityZone = null;

// predicate state
Map<String, List<String>> predicateParams = null;
@@ -122,6 +123,9 @@ public Observable<Void> call(
}
enableSampling = true;
}
if (params.containsKey("availabilityZone") && !params.get("availabilityZone").isEmpty()) {
availabilityZone = params.get("availabilityZone");
}
}
Func1<T, Boolean> predicateFunction = null;
if (predicate != null) {
@@ -144,7 +148,7 @@ public Observable<Void> call(
groupId, slotId, id, null,
false, null, enableSampling, samplingTimeMsec,
predicateFunction, null, legacyMsgProcessedCounter, legacyDroppedWrites,
null);
null, availabilityZone);
}
return null;
}
Original file line number Diff line number Diff line change
@@ -180,9 +180,9 @@ protected Observable<Void> manageConnection(final DefaultChannelWriter<R> writer
final Subscription heartbeatSubscription, boolean applySampling, long samplingRateMSec,
Func1<T, Boolean> predicate, final Action0 connectionClosedCallback,
final Counter legacyMsgProcessedCounter, final Counter legacyDroppedWrites,
final Action0 connectionSubscribeCallback) {
final Action0 connectionSubscribeCallback, final String availabilityZone) {
return manageConnection(writer, host, port, groupId, slotId, id, lastWriteTime, applicationHeartbeats, heartbeatSubscription,
applySampling, samplingRateMSec, null, null, predicate, connectionClosedCallback, legacyMsgProcessedCounter, legacyDroppedWrites, connectionSubscribeCallback);
applySampling, samplingRateMSec, null, null, predicate, connectionClosedCallback, legacyMsgProcessedCounter, legacyDroppedWrites, connectionSubscribeCallback, availabilityZone);
}

protected Observable<Void> manageConnection(final DefaultChannelWriter<R> writer, String host, int port,
@@ -191,9 +191,9 @@ protected Observable<Void> manageConnection(final DefaultChannelWriter<R> writer
final SerializedSubject<String, String> metaMsgSubject, final Subscription metaMsgSubscription,
Func1<T, Boolean> predicate, final Action0 connectionClosedCallback,
final Counter legacyMsgProcessedCounter, final Counter legacyDroppedWrites,
final Action0 connectionSubscribeCallback) {
final Action0 connectionSubscribeCallback, final String availabilityZone) {
return manageConnectionWithCompression(writer, host, port, groupId, slotId, id, lastWriteTime, applicationHeartbeats, heartbeatSubscription,
applySampling, samplingRateMSec, null, null, predicate, connectionClosedCallback, legacyMsgProcessedCounter, legacyDroppedWrites, connectionSubscribeCallback, false, false, null, null);
applySampling, samplingRateMSec, null, null, predicate, connectionClosedCallback, legacyMsgProcessedCounter, legacyDroppedWrites, connectionSubscribeCallback, false, false, null, availabilityZone);

}

Original file line number Diff line number Diff line change
@@ -16,12 +16,17 @@

package io.reactivex.mantis.network.push;

import java.lang.reflect.Constructor;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.functions.Func1;


public class Routers {
private static final Logger logger = LoggerFactory.getLogger(Routers.class);

private Routers() {}

@@ -85,6 +90,32 @@ public byte[] call(T data) {
});
}

/**
 * Reflectively builds a {@link Router} from a class name.
 *
 * <p>The named class must be assignable to {@link Router} and declare a
 * constructor taking {@code (String, Func1)}. Any exception raised while
 * loading, validating or instantiating the class is logged at error level and
 * the router returned by {@code roundRobinLegacyTcpProtocol(name, toBytes)} is
 * used instead, so this method does not propagate those exceptions.
 *
 * @param routerClassName fully qualified name of the router class to load
 * @param name            first constructor argument passed to the router
 * @param toBytes         second constructor argument passed to the router
 * @param <T>             element type handled by the router
 * @return the reflectively created router, or the round-robin fallback on failure
 */
public static <T> Router<T> createRouterInstance(String routerClassName, String name, final Func1<T, byte[]> toBytes) {
    try {
        final Class<?> routerType = Class.forName(routerClassName);

        if (Router.class.isAssignableFrom(routerType)) {
            // Look up the (String, Func1) constructor and invoke it in one chain.
            @SuppressWarnings("unchecked")
            final Router<T> created = (Router<T>) routerType
                .getDeclaredConstructor(String.class, Func1.class)
                .newInstance(name, toBytes);
            return created;
        }

        // Not a Router: raise here so the shared handler below logs and falls back.
        throw new IllegalArgumentException(routerClassName + " does not implement " + Router.class.getName());
    } catch (Exception e) {
        // Covers ClassNotFoundException, NoSuchMethodException, the check above, etc.
        logger.error("failed to create instance of " + routerClassName, e);
        return roundRobinLegacyTcpProtocol(name, toBytes);
    }
}

private static Func1<String, byte[]> stringWithEncoding(String encoding) {
final Charset charset = Charset.forName(encoding);
return new Func1<String, byte[]>() {
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2025 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.reactivex.mantis.network.push;

import org.junit.jupiter.api.Test;
import rx.functions.Func1;

import static io.reactivex.mantis.network.push.Routers.createRouterInstance;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests for {@code Routers.createRouterInstance}: reflective creation of a
 * router by class name and the fallback taken when the class cannot be loaded.
 */
public class RoutersTest {

    /** A loadable router class name yields an instance of that class. */
    @Test
    void testCreateRouterInstanceSuccessfully() {
        final Func1<String, byte[]> toBytes = s -> s.getBytes();

        Router<String> router = createRouterInstance(
            RoundRobinRouter.class.getName(),
            "testRouter",
            toBytes
        );

        assertTrue(router instanceof RoundRobinRouter, "Expected instance of RoundRobinRouter");
    }

    /** An unknown class name does not throw; the round-robin fallback is returned. */
    @Test
    void testCreateRouterInstanceClassNotFound() {
        final Func1<String, byte[]> toBytes = s -> s.getBytes();

        Router<String> router = createRouterInstance(
            "NonExistentRouterClass",
            "testRouter",
            toBytes
        );

        assertTrue(router instanceof RoundRobinRouter,
            "Expected fallback to RoundRobinRouter when the router class cannot be loaded");
    }
}
Original file line number Diff line number Diff line change
@@ -120,6 +120,11 @@ public Builder<T> slotId(String slotId) {
return this;
}

/**
 * Stores the given availability zone in the subscribe parameters under the
 * {@code "availabilityZone"} key.
 *
 * @param availabilityZone value to store
 * @return this builder, for chaining
 */
public Builder<T> availabilityZone(String availabilityZone) {
    subscribeParameters.put("availabilityZone", availabilityZone);
    return this;
}

public Builder<T> decoder(Decoder<T> decoder) {
this.decoder = decoder;
return this;
Original file line number Diff line number Diff line change
@@ -448,11 +448,10 @@ public int acquirePort() {
// execute source stage
String remoteObservableName = rw.getJobId() + "_" + rw.getStageNum();

StageSchedulingInfo currentStageSchedulingInfo = rw.getSchedulingInfo().forStage(1);
WorkerPublisherRemoteObservable publisher
= new WorkerPublisherRemoteObservable<>(rw.getPorts().next(),
remoteObservableName, numWorkersAtStage(selfSchedulingInfo, rw.getJobId(), rw.getStageNum() + 1),
rw.getJobName());
rw.getJobName(), this.config);

closeables.add(StageExecutors.executeSource(rw.getWorkerIndex(), rw.getJob().getSource(),
rw.getStage(), publisher, rw.getContext(), rw.getSourceStageTotalWorkersObservable()));
@@ -561,14 +560,13 @@ public void call() {
// intermediate stage
logger.info("JobId: " + rw.getJobId() + ", executing intermediate stage: " + rw.getStageNum());


int stageNumToExecute = rw.getStageNum();
String jobId = rw.getJobId();
String remoteObservableName = jobId + "_" + stageNumToExecute;

WorkerPublisherRemoteObservable publisher
= new WorkerPublisherRemoteObservable<>(workerPort, remoteObservableName,
numWorkersAtStage(selfSchedulingInfo, rw.getJobId(), rw.getStageNum() + 1), rw.getJobName());
numWorkersAtStage(selfSchedulingInfo, rw.getJobId(), rw.getStageNum() + 1), rw.getJobName(), this.config);
closeables.add(StageExecutors.executeIntermediate(consumer, rw.getStage(), publisher,
rw.getContext()));
RemoteRxServer server = publisher.getServer();
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@


public class StaticPropertiesConfigurationFactory implements ConfigurationFactory {

private final WorkerConfiguration config;

public StaticPropertiesConfigurationFactory(Properties props) {
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@

package io.mantisrx.runtime.loader.config;

import io.mantisrx.common.util.AvailabilityZoneUtils;
import io.mantisrx.server.core.CoreConfiguration;
import io.mantisrx.shaded.com.fasterxml.jackson.annotation.JsonIgnore;
import io.mantisrx.shaded.com.google.common.base.Splitter;
@@ -192,4 +193,16 @@ default Map<String, String> getTaskExecutorAttributes() {
.filter(entry -> !entry.getValue().matches("\\$\\{.*\\}"))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

// ------------------------------------------------------------------------
// Routing related configurations
// ------------------------------------------------------------------------
/**
 * Router class name bound to the
 * {@code mantis.taskexecutor.router.scalar-stage-to-stage} property; defaults to
 * {@code io.reactivex.mantis.network.push.RoundRobinRouter}.
 *
 * <p>NOTE(review): presumably the router used for scalar stage-to-stage
 * connections is instantiated from this name; confirm at the call site.
 *
 * @return the configured router class name
 */
@Config("mantis.taskexecutor.router.scalar-stage-to-stage")
@Default("io.reactivex.mantis.network.push.RoundRobinRouter")
String getScalarStageToStageRouterClassName();


/**
 * {@link AvailabilityZoneUtils} bound to the
 * {@code mantis.availabilityZoneUtils.class} property; the default value is
 * {@code io.mantisrx.common.util.DefaultAvailabilityZoneUtils}.
 *
 * <p>NOTE(review): the property value is a class name, so a coercer must be
 * registered to turn it into an instance; verify against the configuration
 * factory setup.
 *
 * @return the configured availability zone helper
 */
@Config("mantis.availabilityZoneUtils.class")
@Default("io.mantisrx.common.util.DefaultAvailabilityZoneUtils")
AvailabilityZoneUtils getAvailabilityZoneUtils();
}
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
package io.mantisrx.runtime.loader.config;

import io.mantisrx.common.JsonSerializer;
import io.mantisrx.common.util.AvailabilityZoneUtils;
import io.mantisrx.server.core.MetricsCoercer;
import io.mantisrx.server.master.client.config.PluginCoercible;
import java.io.IOException;
@@ -29,6 +30,7 @@ public static <T extends WorkerConfiguration> T frmProperties(Properties propert
properties);
configurationObjectFactory.addCoercible(new MetricsCoercer(properties));
configurationObjectFactory.addCoercible(new PluginCoercible<>(MetricsCollector.class, properties));
configurationObjectFactory.addCoercible(new PluginCoercible<>(AvailabilityZoneUtils.class, properties));
return configurationObjectFactory.build(tClass);
}

@@ -67,6 +69,8 @@ public static <T extends WorkerConfiguration> WorkerConfigurationWritable toWrit
.zkConnectionRetrySleepMs(configSource.getZkConnectionRetrySleepMs())
.zkRoot(configSource.getZkRoot())
.leaderMonitorFactory(configSource.getLeaderMonitorFactoryName())
.scalarStageToStageRouterClassName(configSource.getScalarStageToStageRouterClassName())
.availabilityZoneUtils(configSource.getAvailabilityZoneUtils())
.build();
}

Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@

package io.mantisrx.runtime.loader.config;

import io.mantisrx.common.util.AvailabilityZoneUtils;
import io.mantisrx.common.metrics.MetricsPublisher;
import io.mantisrx.server.core.ILeaderMonitorFactory;
import io.mantisrx.server.core.utils.ConfigUtils;
@@ -79,6 +80,10 @@ public class WorkerConfigurationWritable implements WorkerConfiguration {
String leaderMonitorFactory;
String metricsCollectorClass;
String jobAutoscalerManagerClassName;
String scalarStageToStageRouterClassName;

@JsonIgnore
AvailabilityZoneUtils availabilityZoneUtils;

@JsonIgnore
MetricsPublisher metricsPublisher;
@@ -139,6 +144,11 @@ public boolean getAsyncHttpClientFollowRedirect() {
/**
 * Returns the value of the {@code leaderMonitorFactory} field.
 *
 * @return the leader monitor factory name held by this configuration
 */
@Override
public String getLeaderMonitorFactoryName() {
    return leaderMonitorFactory;
}

/**
 * Returns the {@link AvailabilityZoneUtils} held by this configuration.
 *
 * <p>NOTE(review): the value is returned as-is with no null check; confirm
 * whether it can be unset and whether callers handle that.
 *
 * @return the value of the {@code availabilityZoneUtils} field
 */
@Override
public AvailabilityZoneUtils getAvailabilityZoneUtils() {
    return availabilityZoneUtils;
}

/**
 * Creates an {@link ILeaderMonitorFactory} by passing the
 * {@code leaderMonitorFactory} field to {@code ConfigUtils.createInstance}.
 *
 * @return the factory produced by {@code ConfigUtils.createInstance}
 */
public ILeaderMonitorFactory getLeaderMonitorFactoryImpl() {
    final String factoryClassName = this.leaderMonitorFactory;
    return ConfigUtils.createInstance(factoryClassName, ILeaderMonitorFactory.class);
}
@@ -288,6 +298,11 @@ public String taskExecutorAttributes() {
return this.taskExecutorAttributesStr;
}

/**
 * Returns the value of the {@code scalarStageToStageRouterClassName} field.
 *
 * @return the router class name held by this configuration
 */
@Override
public String getScalarStageToStageRouterClassName() {
    return scalarStageToStageRouterClassName;
}

@Override
public File getRegistrationStoreDir() {
return null;
1 change: 1 addition & 0 deletions mantis-runtime/build.gradle
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ dependencies {
api project(':mantis-remote-observable')
api project(':mantis-network')
api project(':mantis-common')
api project(":mantis-runtime-loader")
api libraries.slf4jApi
compileOnly libraries.jsr305
compileOnly libraries.spectatorApi
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
import io.mantisrx.common.metrics.Metrics;
import io.mantisrx.common.metrics.MetricsRegistry;
import io.mantisrx.runtime.*;
import io.mantisrx.runtime.loader.config.WorkerConfiguration;
import io.reactivex.mantis.remote.observable.ConnectToGroupedObservable;
import io.reactivex.mantis.remote.observable.ConnectToObservable;
import io.reactivex.mantis.remote.observable.DynamicConnectionSet;
@@ -38,11 +39,18 @@ public class WorkerConsumerRemoteObservable<T, R> implements WorkerConsumer<T> {

private DynamicConnectionSet<T> connectionSet;
private Reconciliator<T> reconciliator;
private final WorkerConfiguration config;

/**
 * Creates a consumer without a worker configuration by delegating to the
 * three-argument constructor with a {@code null} config.
 *
 * @param name             name assigned to this consumer
 * @param endpointInjector injector assigned to this consumer
 */
public WorkerConsumerRemoteObservable(String name,
EndpointInjector endpointInjector) {
this(name, endpointInjector, null);
}

/**
 * Creates a consumer, storing the three arguments in the corresponding fields.
 *
 * @param name             name assigned to this consumer
 * @param endpointInjector injector assigned to this consumer
 * @param config           worker configuration; the two-argument constructor
 *                         passes {@code null} here
 */
public WorkerConsumerRemoteObservable(String name,
                                      EndpointInjector endpointInjector,
                                      WorkerConfiguration config) {
    this.config = config;
    this.injector = endpointInjector;
    this.name = name;
}

@SuppressWarnings( {"rawtypes", "unchecked"})
@@ -68,6 +76,10 @@ public Observable<Observable<T>> start(StageConfig<T, ?> stage) {
.decoder(stage.getInputCodec())
.subscribeAttempts(30); // max retry before failure

if (config != null) {
connectToBuilder.availabilityZone(config.getAvailabilityZoneUtils().getAvailabilityZone());
}

connectionSet = DynamicConnectionSet.create(connectToBuilder);
} else {
throw new RuntimeException("Unsupported stage type: " + stage);
Loading