Skip to content

[WIP] Support for specifying RDMA devices when multiple RDMA devices are present. #2006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -696,6 +696,7 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
--disable-perf
--disable-efa
--disable-mrail
--with-cuda=no
--enable-verbs > /dev/null
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/thirdparty/libfabric
)
@@ -719,6 +720,7 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
--disable-perf
--disable-efa
--disable-mrail
--with-cuda=no
--disable-verbs > /dev/null
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/thirdparty/libfabric
)
51 changes: 34 additions & 17 deletions src/client/rpc_client.cc
Original file line number Diff line number Diff line change
@@ -94,16 +94,18 @@ Status RPCClient::Connect(const std::string& rpc_endpoint) {
Status RPCClient::Connect(const std::string& rpc_endpoint,
std::string const& username,
std::string const& password,
const std::string& rdma_endpoint) {
const std::string& rdma_endpoint,
std::string src_rdma_ednpoint) {
return this->Connect(rpc_endpoint, RootSessionID(), username, password,
rdma_endpoint);
rdma_endpoint, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& rpc_endpoint,
const SessionID session_id,
std::string const& username,
std::string const& password,
const std::string& rdma_endpoint) {
const std::string& rdma_endpoint,
std::string src_rdma_ednpoint) {
size_t pos = rpc_endpoint.find(":");
std::string host, port;
if (pos == std::string::npos) {
@@ -125,28 +127,32 @@ Status RPCClient::Connect(const std::string& rpc_endpoint,

return this->Connect(host, static_cast<uint32_t>(std::stoul(port)),
session_id, username, password, rdma_host,
static_cast<uint32_t>(std::stoul(rdma_port)));
static_cast<uint32_t>(std::stoul(rdma_port)),
src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
return this->Connect(host, port, RootSessionID(), "", "", rdma_host,
rdma_port);
rdma_port, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
std::string const& username,
std::string const& password,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
return this->Connect(host, port, RootSessionID(), username, password,
rdma_host, rdma_port);
rdma_host, rdma_port, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
const SessionID session_id,
std::string const& username,
std::string const& password,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);
std::string rpc_endpoint = host + ":" + std::to_string(port);
RETURN_ON_ASSERT(!connected_ || rpc_endpoint == rpc_endpoint_);
@@ -183,7 +189,8 @@ Status RPCClient::Connect(const std::string& host, uint32_t port,
instance_id_ = UnspecifiedInstanceID() - 1;

if (rdma_host.length() > 0) {
Status status = ConnectRDMA(rdma_host, rdma_port);
src_rdma_endpoint_ = src_rdma_ednpoint;
Status status = ConnectRDMA(rdma_host, rdma_port, src_rdma_ednpoint);
if (status.ok()) {
rdma_endpoint_ = rdma_host + ":" + std::to_string(rdma_port);
std::cout << "Connected to RPC server: " << rpc_endpoint
@@ -192,33 +199,38 @@ Status RPCClient::Connect(const std::string& host, uint32_t port,
} else {
std::cout << "Connect RDMA server failed! Fall back to RPC mode. Error:"
<< status.message() << std::endl;
std::cout << "Failed src_rdma_ednpoint: " << src_rdma_ednpoint
<< std::endl;
}
}

return Status::OK();
}

Status RPCClient::ConnectRDMA(const std::string& rdma_host,
uint32_t rdma_port) {
Status RPCClient::ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_endpoint) {
if (this->rdma_connected_) {
return Status::OK();
}

RETURN_ON_ERROR(RDMAClientCreator::Create(this->rdma_client_, rdma_host,
static_cast<int>(rdma_port)));
static_cast<int>(rdma_port),
src_rdma_endpoint));

int retry = 0;
do {
if (this->rdma_client_->Connect().ok()) {
Status status = this->rdma_client_->Connect();
if (status.ok()) {
break;
}
if (retry == 10) {
return Status::Invalid("Failed to connect to RDMA server.");
}
retry++;
usleep(300 * 1000);
std::cout << "Connect rdma server failed! retry: " << retry << " times."
<< std::endl;
std::cout << "Connect rdma server failed! Error:" + status.message() +
"retry: "
<< retry << " times." << std::endl;
} while (true);
this->rdma_connected_ = true;
return Status::OK();
@@ -272,6 +284,9 @@ Status RPCClient::RDMAReleaseMemInfo(RegisterMemInfo& remote_info) {

Status RPCClient::StopRDMA() {
if (!rdma_connected_) {
RETURN_ON_ERROR(
RDMAClientCreator::Release(RDMAClientCreator::buildConnectionKey(
rdma_endpoint_, src_rdma_endpoint_)));
return Status::OK();
}
rdma_connected_ = false;
@@ -285,7 +300,9 @@ Status RPCClient::StopRDMA() {

RETURN_ON_ERROR(rdma_client_->Stop());
RETURN_ON_ERROR(rdma_client_->Close());
RETURN_ON_ERROR(RDMAClientCreator::Release(rdma_endpoint_));
RETURN_ON_ERROR(
RDMAClientCreator::Release(RDMAClientCreator::buildConnectionKey(
rdma_endpoint_, src_rdma_endpoint_)));

return Status::OK();
}
19 changes: 13 additions & 6 deletions src/client/rpc_client.h
Original file line number Diff line number Diff line change
@@ -88,7 +88,8 @@ class RPCClient final : public ClientBase {
*/
Status Connect(const std::string& rpc_endpoint, std::string const& username,
std::string const& password,
const std::string& rdma_endpoint = "");
const std::string& rdma_endpoint = "",
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP endpoint `rpc_endpoint`.
@@ -104,7 +105,8 @@ class RPCClient final : public ClientBase {
Status Connect(const std::string& rpc_endpoint, const SessionID session_id,
std::string const& username = "",
std::string const& password = "",
const std::string& rdma_endpoint = "");
const std::string& rdma_endpoint = "",
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
@@ -117,7 +119,8 @@ class RPCClient final : public ClientBase {
* @return Status that indicates whether the connect has succeeded.
*/
Status Connect(const std::string& host, uint32_t port,
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
@@ -131,7 +134,8 @@ class RPCClient final : public ClientBase {
*/
Status Connect(const std::string& host, uint32_t port,
std::string const& username, std::string const& password,
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
@@ -147,7 +151,8 @@ class RPCClient final : public ClientBase {
Status Connect(const std::string& host, uint32_t port,
const SessionID session_id, std::string const& username = "",
std::string const& password = "",
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Create a new client using self endpoint.
@@ -436,7 +441,8 @@ class RPCClient final : public ClientBase {
const std::string rdma_endpoint() { return rdma_endpoint_; }

private:
Status ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port);
Status ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_endpoint = "");

Status StopRDMA();

@@ -479,6 +485,7 @@ class RPCClient final : public ClientBase {
std::string rdma_endpoint_;
std::shared_ptr<RDMAClient> rdma_client_;
mutable bool rdma_connected_ = false;
std::string src_rdma_endpoint_ = "";

friend class Client;
};
23 changes: 19 additions & 4 deletions src/common/rdma/rdma.cc
Original file line number Diff line number Diff line change
@@ -108,7 +108,7 @@ Status IRDMA::RegisterMemory(fid_mr** mr, fid_domain* domain, void* address,
return Status::IOError("Failed to register memory region:" +
std::to_string(ret));
}
CHECK_ERROR(!ret, "Failed to register memory region:" + std::to_string(ret));
CHECK_ERROR(ret, "Failed to register memory region:" + std::to_string(ret));

mr_desc = fi_mr_desc(*mr);

@@ -177,10 +177,25 @@ int IRDMA::GetCompletion(fid_cq* cq, int timeout, void** context) {
return ret < 0 ? ret : 0;
}

void IRDMA::FreeInfo(fi_info* info) {
if (info) {
fi_freeinfo(info);
/**
 * Release a libfabric `fi_info` object.
 *
 * A null `info` is ignored. When `is_hints` is true, the source and
 * destination address buffers are released with `free()` and their
 * pointer/length fields are reset before the structure itself is passed to
 * `fi_freeinfo()`.
 */
void IRDMA::FreeInfo(fi_info* info, bool is_hints) {
  if (info == nullptr) {
    return;
  }

  // Address buffers are only released here for hint structures.
  if (is_hints && info->src_addr != nullptr) {
    free(info->src_addr);
    info->src_addr = nullptr;
    info->src_addrlen = 0;
  }
  if (is_hints && info->dest_addr != nullptr) {
    free(info->dest_addr);
    info->dest_addr = nullptr;
    info->dest_addrlen = 0;
  }

  fi_freeinfo(info);
}

} // namespace vineyard
2 changes: 1 addition & 1 deletion src/common/rdma/rdma.h
Original file line number Diff line number Diff line change
@@ -55,7 +55,7 @@ class IRDMA {

static int GetCompletion(fid_cq* cq, int timeout, void** context);

static void FreeInfo(fi_info* info);
static void FreeInfo(fi_info* info, bool is_hints);

template <typename FIDType>
static Status CloseResource(FIDType* res, const char* resource_name) {
119 changes: 80 additions & 39 deletions src/common/rdma/rdma_client.cc
Original file line number Diff line number Diff line change
@@ -45,7 +45,7 @@ Status RDMAClient::Make(std::shared_ptr<RDMAClient>& ptr,

fi_eq_attr eq_attr = {0};
eq_attr.wait_obj = FI_WAIT_UNSPEC;
CHECK_ERROR(!fi_eq_open(ptr->fabric, &eq_attr, &ptr->eq, NULL),
CHECK_ERROR(fi_eq_open(ptr->fabric, &eq_attr, &ptr->eq, NULL),
"fi_eq_open failed.");

fi_cq_attr cq_attr = {0};
@@ -54,25 +54,25 @@ Status RDMAClient::Make(std::shared_ptr<RDMAClient>& ptr,
cq_attr.wait_obj = FI_WAIT_NONE;
cq_attr.wait_cond = FI_CQ_COND_NONE;
cq_attr.size = ptr->fi->rx_attr->size;
CHECK_ERROR(!fi_cq_open(ptr->domain, &cq_attr, &ptr->rxcq, NULL),
CHECK_ERROR(fi_cq_open(ptr->domain, &cq_attr, &ptr->rxcq, NULL),
"fi_cq_open failed.");

cq_attr.size = ptr->fi->tx_attr->size;
CHECK_ERROR(!fi_cq_open(ptr->domain, &cq_attr, &ptr->txcq, NULL),
CHECK_ERROR(fi_cq_open(ptr->domain, &cq_attr, &ptr->txcq, NULL),
"fi_cq_open failed.");

CHECK_ERROR(!fi_endpoint(ptr->domain, ptr->fi, &ptr->ep, NULL),
CHECK_ERROR(fi_endpoint(ptr->domain, ptr->fi, &ptr->ep, NULL),
"fi_endpoint failed.");

CHECK_ERROR(!fi_ep_bind(ptr->ep, &ptr->eq->fid, 0), "fi_ep_bind eq failed.");
CHECK_ERROR(fi_ep_bind(ptr->ep, &ptr->eq->fid, 0), "fi_ep_bind eq failed.");

CHECK_ERROR(!fi_ep_bind(ptr->ep, &ptr->rxcq->fid, FI_RECV),
CHECK_ERROR(fi_ep_bind(ptr->ep, &ptr->rxcq->fid, FI_RECV),
"fi_ep_bind rxcq failed.");

CHECK_ERROR(!fi_ep_bind(ptr->ep, &ptr->txcq->fid, FI_SEND),
CHECK_ERROR(fi_ep_bind(ptr->ep, &ptr->txcq->fid, FI_SEND),
"fi_ep_bind txcq failed.");

CHECK_ERROR(!fi_enable(ptr->ep), "fi_enable failed.");
CHECK_ERROR(fi_enable(ptr->ep), "fi_enable failed.");

ptr->rx_msg_buffer = new char[ptr->fi->rx_attr->size];
if (!ptr->rx_msg_buffer) {
@@ -93,13 +93,13 @@ Status RDMAClient::Make(std::shared_ptr<RDMAClient>& ptr,
}

Status RDMAClient::Connect() {
CHECK_ERROR(!fi_connect(ep, fi->dest_addr, NULL, 0), "fi_connect failed.");
CHECK_ERROR(fi_connect(ep, fi->dest_addr, NULL, 0), "fi_connect failed.");

fi_eq_cm_entry entry;
uint32_t event;

CHECK_ERROR(
fi_eq_sread(eq, &event, &entry, sizeof(entry), -1, 0) == sizeof(entry),
fi_eq_sread(eq, &event, &entry, sizeof(entry), -1, 0) != sizeof(entry),
"fi_eq_sread failed.");

if (event != FI_CONNECTED || entry.fid != &ep->fid) {
@@ -226,8 +226,14 @@ Status RDMAClient::Close() {
RETURN_ON_ERROR(CloseResource(rxcq, "receive comeple queue"));
RETURN_ON_ERROR(CloseResource(eq, "event queue"));

delete rx_msg_buffer;
delete tx_msg_buffer;
if (rx_msg_buffer) {
delete rx_msg_buffer;
rx_msg_buffer = nullptr;
}
if (tx_msg_buffer) {
delete tx_msg_buffer;
tx_msg_buffer = nullptr;
}

return Status::OK();
}
@@ -237,64 +243,93 @@ Status RDMAClientCreator::Create(std::shared_ptr<RDMAClient>& ptr,
int port) {
std::string server_endpoint = server_address + ":" + std::to_string(port);
std::lock_guard<std::mutex> lock(servers_mtx_);
if (servers_.find(server_endpoint) == servers_.end()) {
std::string connection_key = buildConnectionKey(server_endpoint);
if (servers_.find(connection_key) == servers_.end()) {
RDMARemoteNodeInfo node_info;
RETURN_ON_ERROR(
CreateRDMARemoteNodeInfo(node_info, hints, server_address, port));
RETURN_ON_ERROR(RDMAClient::Make(ptr, node_info));

servers_[server_endpoint] = node_info;
servers_[connection_key] = node_info;
} else {
RETURN_ON_ERROR(RDMAClient::Make(ptr, servers_[server_endpoint]));
RETURN_ON_ERROR(RDMAClient::Make(ptr, servers_[connection_key]));
}
return Status::OK();
}

Status RDMAClientCreator::Create(std::shared_ptr<RDMAClient>& ptr,
std::string server_address, int port) {
std::string server_address, int port,
std::string src_endpoint) {
std::string server_endpoint = server_address + ":" + std::to_string(port);
std::lock_guard<std::mutex> lock(servers_mtx_);
if (servers_.find(server_endpoint) == servers_.end()) {
std::string connection_key =
buildConnectionKey(server_endpoint, src_endpoint);
if (servers_.find(connection_key) == servers_.end()) {
RDMARemoteNodeInfo node_info;
RETURN_ON_ERROR(CreateRDMARemoteNodeInfo(node_info, server_address, port));
RETURN_ON_ERROR(CreateRDMARemoteNodeInfo(node_info, server_address, port,
src_endpoint));
RETURN_ON_ERROR(RDMAClient::Make(ptr, node_info));
node_info.refcnt++;

servers_[server_endpoint] = node_info;
servers_[connection_key] = node_info;
} else {
RETURN_ON_ERROR(RDMAClient::Make(ptr, servers_[server_endpoint]));
servers_[server_endpoint].refcnt++;
RETURN_ON_ERROR(RDMAClient::Make(ptr, servers_[connection_key]));
servers_[connection_key].refcnt++;
}
return Status::OK();
}

Status RDMAClientCreator::CreateRDMARemoteNodeInfo(RDMARemoteNodeInfo& info,
fi_info* hints,
std::string server_address,
int port) {
int port,
std::string src_endpoint) {
if (!hints) {
return Status::Invalid("Invalid fabric hints info.");
}

CHECK_ERROR(!fi_getinfo(VINEYARD_FIVERSION, server_address.c_str(),
std::to_string(port).c_str(), 0, hints,
reinterpret_cast<fi_info**>(&(info.fi))),
"fi_getinfo failed")
if (!src_endpoint.empty()) {
size_t pos = src_endpoint.find(':');
if (pos == std::string::npos) {
return Status::Invalid("Invalid source endpoint:" + src_endpoint);
}

std::string src_host = src_endpoint.substr(0, pos);
uint32_t src_port = std::stoi(src_endpoint.substr(pos + 1));
fi_info* src_fi;
CHECK_ERROR(fi_getinfo(VINEYARD_FIVERSION, src_host.c_str(),
std::to_string(src_port).c_str(), FI_SOURCE, hints,
reinterpret_cast<fi_info**>(&src_fi)),
"fi_getinfo failed with client src endpoint.");
hints->src_addrlen = src_fi->src_addrlen;
hints->src_addr = malloc(src_fi->src_addrlen);
memcpy(hints->src_addr, src_fi->src_addr, src_fi->src_addrlen);
IRDMA::FreeInfo(src_fi, false);
}
CHECK_ERROR(fi_getinfo(VINEYARD_FIVERSION, server_address.c_str(),
std::to_string(port).c_str(), 0, hints,
reinterpret_cast<fi_info**>(&(info.fi))),
"fi_getinfo failed");
fi_info* fi = reinterpret_cast<fi_info*>(info.fi);
if (fi != nullptr && fi->nic != nullptr) {
std::cout << "Open device name:" << fi->nic->device_attr->name << std::endl;
}

CHECK_ERROR(!fi_fabric(reinterpret_cast<fi_info*>(info.fi)->fabric_attr,
reinterpret_cast<fid_fabric**>(&info.fabric), NULL),
CHECK_ERROR(fi_fabric(reinterpret_cast<fi_info*>(info.fi)->fabric_attr,
reinterpret_cast<fid_fabric**>(&info.fabric), NULL),
"fi_fabric failed.");

CHECK_ERROR(!fi_domain(reinterpret_cast<fid_fabric*>(info.fabric),
reinterpret_cast<fi_info*>(info.fi),
reinterpret_cast<fid_domain**>(&info.domain), NULL),
CHECK_ERROR(fi_domain(reinterpret_cast<fid_fabric*>(info.fabric),
reinterpret_cast<fi_info*>(info.fi),
reinterpret_cast<fid_domain**>(&info.domain), NULL),
"fi_domain failed.");
return Status::OK();
}

Status RDMAClientCreator::CreateRDMARemoteNodeInfo(RDMARemoteNodeInfo& info,
std::string server_address,
int port) {
int port,
std::string src_endpoint) {
fi_info* hints = fi_allocinfo();
if (!hints) {
return Status::Invalid("Failed to allocate fabric info.");
@@ -311,11 +346,16 @@ Status RDMAClientCreator::CreateRDMARemoteNodeInfo(RDMARemoteNodeInfo& info,
hints->tx_attr->tclass = FI_TC_BULK_DATA;
hints->ep_attr->type = FI_EP_MSG;
hints->fabric_attr = new fi_fabric_attr;
hints->src_addr = NULL;
hints->src_addrlen = 0;
hints->dest_addr = NULL;
hints->dest_addrlen = 0;
memset(hints->fabric_attr, 0, sizeof *(hints->fabric_attr));
hints->fabric_attr->prov_name = strdup("verbs");

RETURN_ON_ERROR(CreateRDMARemoteNodeInfo(info, hints, server_address, port));
IRDMA::FreeInfo(hints);
RETURN_ON_ERROR(CreateRDMARemoteNodeInfo(info, hints, server_address, port,
src_endpoint));
IRDMA::FreeInfo(hints, true);
return Status::OK();
}

@@ -327,10 +367,11 @@ Status RDMAClientCreator::Clear() {
return Status::OK();
}

Status RDMAClientCreator::Release(std::string rdma_endpoint) {
Status RDMAClientCreator::Release(std::string connection_key) {
std::lock_guard<std::mutex> lock(servers_mtx_);
if (servers_.find(rdma_endpoint) != servers_.end()) {
RDMARemoteNodeInfo& info = servers_[rdma_endpoint];
if (servers_.find(connection_key) != servers_.end()) {
std::cout << "Release RDMA client:" << connection_key << std::endl;
RDMARemoteNodeInfo& info = servers_[connection_key];

info.refcnt--;
if (info.refcnt == 0) {
@@ -341,8 +382,8 @@ Status RDMAClientCreator::Release(std::string rdma_endpoint) {
reinterpret_cast<fid_domain*>(info.domain), "domain"));
RETURN_ON_ERROR(IRDMA::CloseResource(
reinterpret_cast<fid_fabric*>(info.fabric), "fabric"));
IRDMA::FreeInfo(reinterpret_cast<fi_info*>(info.fi));
servers_.erase(rdma_endpoint);
IRDMA::FreeInfo(reinterpret_cast<fi_info*>(info.fi), false);
servers_.erase(connection_key);
}
}

15 changes: 11 additions & 4 deletions src/common/rdma/rdma_client.h
Original file line number Diff line number Diff line change
@@ -105,13 +105,18 @@ class RDMAClient : public IRDMA {

class RDMAClientCreator {
public:
static Status Create(std::shared_ptr<RDMAClient>& ptr,
std::string server_address, int port);
static Status Create(std::shared_ptr<RDMAClient>& ptr, std::string dst_addr,
int port, std::string src_endpoint = "");

static Status Release(std::string rdma_endpoint);

static Status Clear();

// Builds the key that identifies one RDMA connection: the source endpoint,
// the literal "->", then the remote endpoint. With the default (empty)
// `src_endpoint` the key is "->" followed by `rdma_endpoint`.
static std::string buildConnectionKey(std::string rdma_endpoint,
                                      std::string src_endpoint = "") {
  src_endpoint += "->";
  src_endpoint += rdma_endpoint;
  return src_endpoint;
}

private:
#if defined(__linux__)
RDMAClientCreator() = delete;
@@ -121,10 +126,12 @@ class RDMAClientCreator {

static Status CreateRDMARemoteNodeInfo(RDMARemoteNodeInfo& info,
fi_info* hints,
std::string server_address, int port);
std::string server_address, int port,
std::string src_endpoint = "");

static Status CreateRDMARemoteNodeInfo(RDMARemoteNodeInfo& info,
std::string server_address, int port);
std::string server_address, int port,
std::string src_endpoint = "");

static std::map<std::string, RDMARemoteNodeInfo> servers_;
static std::mutex servers_mtx_;
65 changes: 40 additions & 25 deletions src/common/rdma/rdma_server.cc
Original file line number Diff line number Diff line change
@@ -31,7 +31,8 @@ limitations under the License.
namespace vineyard {

#if defined(__linux__)
Status RDMAServer::Make(std::shared_ptr<RDMAServer>& ptr, int port) {
Status RDMAServer::Make(std::shared_ptr<RDMAServer>& ptr, int port,
std::string host) {
fi_info* hints = fi_allocinfo();
if (!hints) {
return Status::Invalid("Failed to allocate fabric info.");
@@ -49,50 +50,66 @@ Status RDMAServer::Make(std::shared_ptr<RDMAServer>& ptr, int port) {
hints->ep_attr->type = FI_EP_MSG;
hints->fabric_attr = new fi_fabric_attr;
memset(hints->fabric_attr, 0, sizeof *(hints->fabric_attr));
hints->src_addr = nullptr;
hints->src_addrlen = 0;
hints->dest_addr = nullptr;
hints->dest_addrlen = 0;
hints->fabric_attr->prov_name = strdup("verbs");

return Make(ptr, hints, port);
Status status = Make(ptr, hints, port, host);
IRDMA::FreeInfo(hints, true);
return status;
}

Status RDMAServer::Make(std::shared_ptr<RDMAServer>& ptr, fi_info* hints,
int port) {
int port, std::string host) {
if (!hints) {
return Status::Invalid("Invalid fabric hints info.");
}

ptr = std::make_shared<RDMAServer>();

uint64_t flags = 0;
CHECK_ERROR(
!fi_getinfo(VINEYARD_FIVERSION, NULL, std::to_string(port).c_str(), flags,
hints, &(ptr->fi)),
"fi_getinfo failed.");
if (!host.empty()) {
uint64_t flags = FI_SOURCE;
CHECK_ERROR(
fi_getinfo(VINEYARD_FIVERSION, host.c_str(),
std::to_string(port).c_str(), flags, hints, &(ptr->fi)),
"fi_getinfo failed.");
} else {
CHECK_ERROR(fi_getinfo(VINEYARD_FIVERSION, NULL,
std::to_string(port).c_str(), 0, hints, &(ptr->fi)),
"fi_getinfo failed.");
}
if (ptr->fi != nullptr && ptr->fi->nic != nullptr) {
std::cout << "open device name:" << ptr->fi->nic->device_attr->name
<< std::endl;
}

CHECK_ERROR(!fi_fabric(ptr->fi->fabric_attr, &ptr->fabric, NULL),
CHECK_ERROR(fi_fabric(ptr->fi->fabric_attr, &ptr->fabric, NULL),
"fi_fabric failed.");

ptr->eq_attr.wait_obj = FI_WAIT_UNSPEC;
CHECK_ERROR(!fi_eq_open(ptr->fabric, &ptr->eq_attr, &ptr->eq, NULL),
CHECK_ERROR(fi_eq_open(ptr->fabric, &ptr->eq_attr, &ptr->eq, NULL),
"fi_eq_open failed.");

CHECK_ERROR(!fi_passive_ep(ptr->fabric, ptr->fi, &ptr->pep, NULL),
CHECK_ERROR(fi_passive_ep(ptr->fabric, ptr->fi, &ptr->pep, NULL),
"fi_passive_ep failed.");

CHECK_ERROR(!fi_pep_bind(ptr->pep, &ptr->eq->fid, 0), "fi_pep_bind failed.");
CHECK_ERROR(fi_pep_bind(ptr->pep, &ptr->eq->fid, 0), "fi_pep_bind failed.");

CHECK_ERROR(!fi_domain(ptr->fabric, ptr->fi, &ptr->domain, NULL),
CHECK_ERROR(fi_domain(ptr->fabric, ptr->fi, &ptr->domain, NULL),
"fi_domain failed.");

memset(&ptr->cq_attr, 0, sizeof cq_attr);
ptr->cq_attr.format = FI_CQ_FORMAT_CONTEXT;
ptr->cq_attr.wait_obj = FI_WAIT_NONE;
ptr->cq_attr.wait_cond = FI_CQ_COND_NONE;
ptr->cq_attr.size = ptr->fi->rx_attr->size;
CHECK_ERROR(!fi_cq_open(ptr->domain, &ptr->cq_attr, &ptr->rxcq, NULL),
CHECK_ERROR(fi_cq_open(ptr->domain, &ptr->cq_attr, &ptr->rxcq, NULL),
"fi_cq_open failed.");

ptr->cq_attr.size = ptr->fi->tx_attr->size;
CHECK_ERROR(!fi_cq_open(ptr->domain, &ptr->cq_attr, &ptr->txcq, NULL),
CHECK_ERROR(fi_cq_open(ptr->domain, &ptr->cq_attr, &ptr->txcq, NULL),
"fi_cq_open failed.");

ptr->rx_msg_buffer = new char[ptr->rx_msg_size];
@@ -126,10 +143,9 @@ Status RDMAServer::Make(std::shared_ptr<RDMAServer>& ptr, fi_info* hints,

ptr->port = port;

CHECK_ERROR(!fi_listen(ptr->pep), "fi_listen failed.");
CHECK_ERROR(fi_listen(ptr->pep), "fi_listen failed.");

ptr->state = READY;

return Status::OK();
}

@@ -167,7 +183,7 @@ Status RDMAServer::Close() {
delete[] rx_buffer_bitmaps;
delete[] tx_buffer_bitmaps;

FreeInfo(fi);
FreeInfo(fi, false);

return Status::OK();
}
@@ -176,18 +192,17 @@ Status RDMAServer::PrepareConnection(VineyardEventEntry vineyard_entry) {
// prepare new ep
fid_ep* ep = NULL;
fi_info* client_fi = reinterpret_cast<fi_info*>(vineyard_entry.fi);
CHECK_ERROR(!fi_endpoint(domain, client_fi, &ep, NULL),
"fi_endpoint failed.");
CHECK_ERROR(fi_endpoint(domain, client_fi, &ep, NULL), "fi_endpoint failed.");

CHECK_ERROR(!fi_ep_bind(ep, &eq->fid, 0), "fi_ep_bind eq failed.");
CHECK_ERROR(fi_ep_bind(ep, &eq->fid, 0), "fi_ep_bind eq failed.");

CHECK_ERROR(!fi_ep_bind(ep, &rxcq->fid, FI_RECV), "fi_ep_bind rxcq failed.");
CHECK_ERROR(fi_ep_bind(ep, &rxcq->fid, FI_RECV), "fi_ep_bind rxcq failed.");

CHECK_ERROR(!fi_ep_bind(ep, &txcq->fid, FI_SEND), "fi_ep_bind txcq failed.");
CHECK_ERROR(fi_ep_bind(ep, &txcq->fid, FI_SEND), "fi_ep_bind txcq failed.");

CHECK_ERROR(!fi_enable(ep), "fi_enable failed.");
CHECK_ERROR(fi_enable(ep), "fi_enable failed.");

CHECK_ERROR(!fi_accept(ep, NULL, 0), "fi_accept failed.");
CHECK_ERROR(fi_accept(ep, NULL, 0), "fi_accept failed.");

std::lock_guard<std::mutex> lock(wait_conn_ep_map_mutex_);
wait_conn_ep_map_[&ep->fid] = ep;
8 changes: 5 additions & 3 deletions src/common/rdma/rdma_server.h
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "common/rdma/rdma.h"
@@ -35,7 +36,8 @@ class RDMAServer : public IRDMA {

~RDMAServer() {}

static Status Make(std::shared_ptr<RDMAServer>& ptr, int port);
static Status Make(std::shared_ptr<RDMAServer>& ptr, int port,
std::string host = "");

Status Send(uint64_t clientID, void* buf, size_t size, void* ctx);

@@ -89,8 +91,8 @@ class RDMAServer : public IRDMA {

private:
#if defined(__linux__)
static Status Make(std::shared_ptr<RDMAServer>& ptr, fi_info* hints,
int port);
static Status Make(std::shared_ptr<RDMAServer>& ptr, fi_info* hints, int port,
std::string host);

Status RegisterMemory(fid_mr** mr, void* address, size_t size, uint64_t& rkey,
void*& mr_desc);
13 changes: 9 additions & 4 deletions src/common/rdma/util.h
Original file line number Diff line number Diff line change
@@ -26,10 +26,15 @@ limitations under the License.

namespace vineyard {

#define CHECK_ERROR(condition, message) \
if (!(condition)) { \
return Status::Invalid(message); \
}
// Evaluates `condition` once and stores it as an int status code. A non-zero
// code makes the enclosing function return Status::Invalid carrying `message`
// followed by "ret:<code>"; zero falls through. The do/while(0) wrapper lets
// the macro be used as a single statement terminated by ';'.
#define CHECK_ERROR(condition, message)                         \
  do {                                                          \
    const int check_error_rc = (condition);                     \
    if (check_error_rc != 0) {                                  \
      return Status::Invalid(std::string(message) + "ret:" +    \
                             std::to_string(check_error_rc));   \
    }                                                           \
  } while (0)

#if defined(__linux__)
#define POST(post_fn, op_str, ...) \
3 changes: 2 additions & 1 deletion src/server/async/rpc_server.cc
Original file line number Diff line number Diff line change
@@ -83,9 +83,10 @@ Status RPCServer::InitRDMA() {
if (pos == std::string::npos) {
return Status::Invalid("Invalid RDMA endpoint: " + rdma_endpoint);
}
std::string rdma_host = rdma_endpoint.substr(0, pos);
uint32_t rdma_port = std::stoi(rdma_endpoint.substr(pos + 1));

Status status = RDMAServer::Make(this->rdma_server_, rdma_port);
Status status = RDMAServer::Make(this->rdma_server_, rdma_port, rdma_host);
if (status.ok()) {
rdma_stop_ = false;
rdma_listen_thread_ = std::thread([this]() { this->doRDMAAccept(); });
3 changes: 2 additions & 1 deletion src/server/util/remote.cc
Original file line number Diff line number Diff line change
@@ -64,7 +64,8 @@ Status RemoteClient::StopRDMA() {

RETURN_ON_ERROR(rdma_client_->Stop());
RETURN_ON_ERROR(rdma_client_->Close());
RETURN_ON_ERROR(RDMAClientCreator::Release(rdma_endpoint_));
RETURN_ON_ERROR(RDMAClientCreator::Release(
RDMAClientCreator::buildConnectionKey(rdma_endpoint_)));
return Status::OK();
}

15 changes: 10 additions & 5 deletions test/rdma_blob_perf_test.cc
Original file line number Diff line number Diff line change
@@ -153,10 +153,11 @@ void CheckBlobValue(

// Test 512K~512M blob
int main(int argc, const char** argv) {
if (argc < 7) {
if (argc < 8) {
LOG(ERROR) << "usage: " << argv[0] << " <ipc_socket>"
<< " <rpc_endpoint>"
<< " <rdma_endpoint>"
<< " <rdma_src_endpoint>"
<< " <min_size>"
<< " <max_size>"
<< " <parallel>";
@@ -165,17 +166,21 @@ int main(int argc, const char** argv) {
std::string ipc_socket = std::string(argv[1]);
std::string rpc_endpoint = std::string(argv[2]);
std::string rdma_endpoint = std::string(argv[3]);
int parallel = std::stoi(argv[6]);
std::string rdma_src_endpoint = std::string(argv[4]);
std::cout << "rdma_src_endpoint: " << rdma_src_endpoint << std::endl;
int parallel = std::stoi(argv[7]);
std::vector<std::shared_ptr<RPCClient>> clients;
for (int i = 0; i < parallel; i++) {
clients.push_back(std::make_shared<RPCClient>());
VINEYARD_CHECK_OK(clients[i]->Connect(rpc_endpoint, "", "", rdma_endpoint));
VINEYARD_CHECK_OK(clients[i]->Connect(
rpc_endpoint, "", "", rdma_endpoint,
rdma_src_endpoint + ":" + std::to_string(i + 5111)));
}

uint64_t min_size = 1024 * 1024 * 2; // 512K
uint64_t max_size = 1024 * 1024 * 2; // 64M
min_size = std::stoull(argv[4]) * 1024 * 1024;
max_size = std::stoull(argv[5]) * 1024 * 1024;
min_size = std::stoull(argv[5]) * 1024 * 1024;
max_size = std::stoull(argv[6]) * 1024 * 1024;
if (min_size == 0) {
min_size = 1024 * 512;
}
3 changes: 2 additions & 1 deletion test/rdma_test.cc
Original file line number Diff line number Diff line change
@@ -209,7 +209,8 @@ void StartClient(std::string server_address) {
HelloToServer();
VINEYARD_CHECK_OK(client->Stop());
VINEYARD_CHECK_OK(client->Close());
RDMAClientCreator::Release(server_address + ":" + std::to_string(port));
RDMAClientCreator::Release(RDMAClientCreator::buildConnectionKey(
server_address + ":" + std::to_string(port)));
}

int main(int argc, char** argv) {
1 change: 1 addition & 0 deletions test/runner.py
Original file line number Diff line number Diff line change
@@ -519,6 +519,7 @@ def run_vineyard_cpp_tests(meta, allocator, endpoints, tests):
'rdma_blob_perf_test',
'127.0.0.1:%d' % rpc_socket_port,
'127.0.0.1:%d' % rdma_port,
"",
64,
64,
1,