Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -812,16 +812,18 @@ size_t ClusterSizeBasedLeaseRequestRateLimiter::
void ClusterSizeBasedLeaseRequestRateLimiter::OnNodeChanges(
const rpc::GcsNodeAddressAndLiveness &data) {
if (data.state() == rpc::GcsNodeInfo::DEAD) {
RAY_LOG(INFO) << "BRO: " << data.node_id() << " IS DEAD!!!";
if (num_alive_nodes_ != 0) {
num_alive_nodes_--;
} else {
RAY_LOG(WARNING) << "Node" << data.node_manager_address()
<< " change state to DEAD but num_alive_node is 0.";
}
} else {
RAY_LOG(INFO) << "BRO: " << data.node_id() << " IS ALIVE!!!";
num_alive_nodes_++;
}
RAY_LOG_EVERY_MS(INFO, 60000) << "Number of alive nodes:" << num_alive_nodes_.load();
RAY_LOG(INFO) << "Number of alive nodes: " << num_alive_nodes_.load();
}

} // namespace core
Expand Down
7 changes: 7 additions & 0 deletions src/ray/gcs/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,18 @@ ray_cc_test(
tags = ["team:core"],
deps = [
"//src/mock/ray/pubsub:mock_publisher",
"//src/ray/common:grpc_util",
"//src/ray/common:test_utils",
"//src/ray/gcs:gcs_node_manager",
"//src/ray/gcs/store_client:in_memory_store_client",
"//src/ray/observability:fake_ray_event_recorder",
"//src/ray/protobuf:pubsub_cc_grpc",
"//src/ray/pubsub:gcs_subscriber",
"//src/ray/pubsub:publisher",
"//src/ray/pubsub:subscriber",
"//src/ray/raylet_rpc_client:fake_raylet_client",
"//src/ray/util:network_util",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
],
)
Expand Down
317 changes: 317 additions & 0 deletions src/ray/gcs/tests/gcs_node_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@

#include <gtest/gtest.h>

#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "mock/ray/pubsub/publisher.h"
#include "ray/common/asio/periodical_runner.h"
#include "ray/common/ray_config.h"
#include "ray/common/test_utils.h"
#include "ray/gcs/store_client/in_memory_store_client.h"
#include "ray/observability/fake_ray_event_recorder.h"
#include "ray/pubsub/gcs_subscriber.h"
#include "ray/pubsub/publisher.h"
#include "ray/pubsub/subscriber.h"
#include "ray/raylet_rpc_client/fake_raylet_client.h"
#include "ray/util/network_util.h"
#include "src/ray/protobuf/pubsub.grpc.pb.h"

namespace ray {
class GcsNodeManagerTest : public ::testing::Test {
Expand Down Expand Up @@ -535,4 +543,313 @@ TEST_F(GcsNodeManagerTest, TestHandleGetAllNodeAddressAndLiveness) {
}
}

// Subscriber service implementation for integration testing
class SubscriberServiceImpl final : public rpc::SubscriberService::CallbackService {
public:
explicit SubscriberServiceImpl(std::unique_ptr<pubsub::Publisher> publisher)
: publisher_(std::move(publisher)) {}

grpc::ServerUnaryReactor *PubsubLongPolling(
grpc::CallbackServerContext *context,
const rpc::PubsubLongPollingRequest *request,
rpc::PubsubLongPollingReply *reply) override {
auto *reactor = context->DefaultReactor();
publisher_->ConnectToSubscriber(*request,
reply->mutable_publisher_id(),
reply->mutable_pub_messages(),
[reactor](ray::Status status,
std::function<void()> success_cb,
std::function<void()> failure_cb) {
RAY_CHECK_OK(status);
reactor->Finish(grpc::Status::OK);
});
return reactor;
}

grpc::ServerUnaryReactor *PubsubCommandBatch(
grpc::CallbackServerContext *context,
const rpc::PubsubCommandBatchRequest *request,
rpc::PubsubCommandBatchReply *reply) override {
const auto subscriber_id = UniqueID::FromBinary(request->subscriber_id());
auto *reactor = context->DefaultReactor();
for (const auto &command : request->commands()) {
if (command.has_unsubscribe_message()) {
publisher_->UnregisterSubscription(command.channel_type(),
subscriber_id,
command.key_id().empty()
? std::nullopt
: std::make_optional(command.key_id()));
} else if (command.has_subscribe_message()) {
publisher_->RegisterSubscription(command.channel_type(),
subscriber_id,
command.key_id().empty()
? std::nullopt
: std::make_optional(command.key_id()));
}
}
reactor->Finish(grpc::Status::OK);
return reactor;
}

pubsub::Publisher &GetPublisher() { return *publisher_; }

private:
std::unique_ptr<pubsub::Publisher> publisher_;
};

// Subscriber client for integration testing
class CallbackSubscriberClient final : public pubsub::SubscriberClientInterface {
public:
explicit CallbackSubscriberClient(const std::string &address) {
auto channel = grpc::CreateChannel(address, grpc::InsecureChannelCredentials());
stub_ = rpc::SubscriberService::NewStub(std::move(channel));
}

~CallbackSubscriberClient() final = default;

void PubsubLongPolling(
rpc::PubsubLongPollingRequest &&request,
const rpc::ClientCallback<rpc::PubsubLongPollingReply> &callback) final {
auto *context = new grpc::ClientContext;
auto *reply = new rpc::PubsubLongPollingReply;
stub_->async()->PubsubLongPolling(
context, &request, reply, [callback, context, reply](grpc::Status s) {
callback(GrpcStatusToRayStatus(s), std::move(*reply));
delete reply;
delete context;
});
}

void PubsubCommandBatch(
rpc::PubsubCommandBatchRequest &&request,
const rpc::ClientCallback<rpc::PubsubCommandBatchReply> &callback) final {
auto *context = new grpc::ClientContext;
auto *reply = new rpc::PubsubCommandBatchReply;
stub_->async()->PubsubCommandBatch(
context, &request, reply, [callback, context, reply](grpc::Status s) {
callback(GrpcStatusToRayStatus(s), std::move(*reply));
delete reply;
delete context;
});
}

std::string DebugString() const { return ""; }

private:
std::unique_ptr<rpc::SubscriberService::Stub> stub_;
};

// Integration test for GCS NodeManager with NodeAddressAndLiveness pubsub
TEST_F(GcsNodeManagerTest, TestNodeAddressAndLivenessHighChurn) {
// Set up pubsub infrastructure
IOServicePool io_service_pool(3);
io_service_pool.Run();
auto periodical_runner = PeriodicalRunner::Create(*io_service_pool.Get());

const std::string address = "127.0.0.1:7929";
rpc::Address address_proto;
address_proto.set_ip_address("127.0.0.1");
address_proto.set_port(7929);
address_proto.set_worker_id(UniqueID::FromRandom().Binary());

// Create publisher - we'll keep this alive throughout the test
std::shared_ptr<pubsub::Publisher> publisher = std::make_shared<pubsub::Publisher>(
/*channels=*/
std::vector<rpc::ChannelType>{
rpc::ChannelType::GCS_NODE_ADDRESS_AND_LIVENESS_CHANNEL,
rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
},
/*periodical_runner=*/*periodical_runner,
/*get_time_ms=*/[]() -> double { return absl::ToUnixMicros(absl::Now()); },
/*subscriber_timeout_ms=*/absl::ToInt64Microseconds(absl::Seconds(30)),
/*batch_size=*/100);

// Non-owning deleter struct
struct NoOpDeleter {
void operator()(pubsub::Publisher *) const {}
void operator()(pubsub::PublisherInterface *) const {}
};

// Set up gRPC server with a non-owning reference to the publisher
auto subscriber_service = std::make_unique<SubscriberServiceImpl>(
std::unique_ptr<pubsub::Publisher, NoOpDeleter>(publisher.get()));
grpc::ServerBuilder builder;
builder.AddListeningPort(address, grpc::InsecureServerCredentials());
builder.RegisterService(subscriber_service.get());
auto server = builder.BuildAndStart();

// Create GCS publisher that wraps the real pubsub publisher (also non-owning)
auto gcs_publisher_wrapper = std::make_unique<pubsub::GcsPublisher>(
std::unique_ptr<pubsub::PublisherInterface, NoOpDeleter>(publisher.get()));

// Create GCS NodeManager with the real publisher
gcs::GcsNodeManager node_manager(gcs_publisher_wrapper.get(),
gcs_table_storage_.get(),
*io_context_,
client_pool_.get(),
ClusterID::Nil(),
*fake_ray_event_recorder_,
"test_session_name");

// Create subscriber
auto subscriber = std::make_unique<pubsub::Subscriber>(
UniqueID::FromRandom(),
/*channels=*/
std::vector<rpc::ChannelType>{
rpc::ChannelType::GCS_NODE_ADDRESS_AND_LIVENESS_CHANNEL,
},
/*max_command_batch_size=*/3,
/*get_client=*/
[](const rpc::Address &address) {
return std::make_shared<CallbackSubscriberClient>(
BuildAddress(address.ip_address(), address.port()));
},
io_service_pool.Get());

// Track received messages
absl::Mutex mu;
std::vector<rpc::GcsNodeAddressAndLiveness> received_messages;

// Subscribe to NodeAddressAndLiveness channel
std::atomic<bool> subscribed(false);
subscriber->Subscribe(
std::make_unique<rpc::SubMessage>(),
rpc::ChannelType::GCS_NODE_ADDRESS_AND_LIVENESS_CHANNEL,
address_proto,
/*key_id=*/std::nullopt,
/*subscribe_done_callback=*/
[&subscribed](Status status) {
RAY_CHECK_OK(status);
subscribed = true;
},
/*subscribe_item_callback=*/
[&mu, &received_messages](const rpc::PubMessage &msg) {
absl::MutexLock lock(&mu);
received_messages.push_back(msg.node_address_and_liveness_message());
},
/*subscription_failure_callback=*/
[](const std::string &, const Status &status) { RAY_CHECK_OK(status); });

// Wait for subscription to complete
while (!subscribed) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}

// Simulate high node churn: 1000 nodes being added, removed, and re-added
const int num_nodes = 1000;
std::vector<std::shared_ptr<rpc::GcsNodeInfo>> nodes;
for (int i = 0; i < num_nodes; i++) {
auto node = GenNodeInfo();
nodes.push_back(node);
}

RAY_LOG(INFO) << "Adding " << num_nodes << " nodes...";
// Phase 1: Add all nodes
for (int i = 0; i < num_nodes; i++) {
rpc::RegisterNodeRequest register_request;
register_request.mutable_node_info()->CopyFrom(*nodes[i]);
rpc::RegisterNodeReply register_reply;
auto send_reply_callback =
[](ray::Status status, std::function<void()> f1, std::function<void()> f2) {};
node_manager.HandleRegisterNode(
register_request, &register_reply, send_reply_callback);
// Process async operations
while (io_context_->poll() > 0) {
}
}

RAY_LOG(INFO) << "Removing " << num_nodes << " nodes...";
// Phase 2: Remove all nodes
for (int i = 0; i < num_nodes; i++) {
rpc::UnregisterNodeRequest unregister_request;
unregister_request.set_node_id(nodes[i]->node_id());
unregister_request.mutable_node_death_info()->set_reason(
rpc::NodeDeathInfo::UNEXPECTED_TERMINATION);
rpc::UnregisterNodeReply unregister_reply;
auto send_reply_callback =
[](ray::Status status, std::function<void()> f1, std::function<void()> f2) {};
node_manager.HandleUnregisterNode(
unregister_request, &unregister_reply, send_reply_callback);
// Process async operations
while (io_context_->poll() > 0) {
}
}

RAY_LOG(INFO) << "Re-adding " << num_nodes << " nodes...";
// Phase 3: Re-add all nodes
for (int i = 0; i < num_nodes; i++) {
rpc::RegisterNodeRequest register_request;
register_request.mutable_node_info()->CopyFrom(*nodes[i]);
rpc::RegisterNodeReply register_reply;
auto send_reply_callback =
[](ray::Status status, std::function<void()> f1, std::function<void()> f2) {};
node_manager.HandleRegisterNode(
register_request, &register_reply, send_reply_callback);
// Process async operations
while (io_context_->poll() > 0) {
}
}

// Wait for all messages to be received
const size_t expected_messages = num_nodes * 3; // ALIVE + DEAD + ALIVE
RAY_LOG(INFO) << "Waiting for " << expected_messages << " messages...";

// Wait up to 60 seconds for all messages
auto start_time = std::chrono::steady_clock::now();
while (true) {
{
absl::MutexLock lock(&mu);
if (received_messages.size() >= expected_messages) {
break;
}
}
auto elapsed = std::chrono::steady_clock::now() - start_time;
if (elapsed > std::chrono::seconds(60)) {
absl::MutexLock lock(&mu);
FAIL() << "Timeout: Expected " << expected_messages
<< " messages, but received only " << received_messages.size();
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}

absl::MutexLock lock(&mu);
RAY_LOG(INFO) << "Received all " << received_messages.size() << " messages.";

// Verify that we received all expected messages
EXPECT_EQ(received_messages.size(), expected_messages);

// Verify the sequence: for each node, we should see ALIVE -> DEAD -> ALIVE
std::map<std::string, std::vector<rpc::GcsNodeInfo::GcsNodeState>> node_states;
for (const auto &msg : received_messages) {
node_states[msg.node_id()].push_back(msg.state());
}

EXPECT_EQ(node_states.size(), static_cast<size_t>(num_nodes));
int correct_sequences = 0;
for (const auto &[node_id, states] : node_states) {
if (states.size() == 3 && states[0] == rpc::GcsNodeInfo::ALIVE &&
states[1] == rpc::GcsNodeInfo::DEAD && states[2] == rpc::GcsNodeInfo::ALIVE) {
correct_sequences++;
} else {
RAY_LOG(ERROR)
<< "Node " << NodeID::FromBinary(node_id).Hex()
<< " has incorrect state sequence. Size: " << states.size() << " States: "
<< (states.size() > 0 ? rpc::GcsNodeInfo::GcsNodeState_Name(states[0]) : "none")
<< " -> "
<< (states.size() > 1 ? rpc::GcsNodeInfo::GcsNodeState_Name(states[1]) : "none")
<< " -> "
<< (states.size() > 2 ? rpc::GcsNodeInfo::GcsNodeState_Name(states[2])
: "none");
}
}

EXPECT_EQ(correct_sequences, num_nodes)
<< "Only " << correct_sequences << " out of " << num_nodes
<< " nodes have correct ALIVE->DEAD->ALIVE sequence";

// Clean up
server->Shutdown();
io_service_pool.Stop();
}

} // namespace ray
1 change: 1 addition & 0 deletions src/ray/pubsub/subscriber.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ class Subscriber : public SubscriberInterface {
///

FRIEND_TEST(IntegrationTest, SubscribersToOneIDAndAllIDs);
FRIEND_TEST(IntegrationTest, NodeAddressAndLivenessHighChurn);
FRIEND_TEST(SubscriberTest, TestBasicSubscription);
FRIEND_TEST(SubscriberTest, TestSingleLongPollingWithMultipleSubscriptions);
FRIEND_TEST(SubscriberTest, TestMultiLongPollingWithTheSameSubscription);
Expand Down
Loading