Skip to content

Commit

Permalink
TL/MLX5: generate schedule for zcopy allgather
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Feb 4, 2025
1 parent 73651ea commit 1c7ff90
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm_init_spec {
int max_eager;
int cuda_mem_enabled;
int one_sided_reliability_enable;
int truly_zero_copy_allgather_enabled;
int mcast_prepost_bucket_size;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

Expand Down Expand Up @@ -276,6 +278,8 @@ typedef struct ucc_tl_mlx5_mcast_allgather_comm {
uint32_t coll_counter;
uint32_t max_num_packets;
uint32_t max_push_send;
uint8_t truly_zero_copy_allgather_enabled;
uint32_t mcast_prepost_bucket_size;
} ucc_tl_mlx5_mcast_allgather_comm_t;

typedef struct ucc_tl_mlx5_mcast_bcast_comm {
Expand Down Expand Up @@ -434,6 +438,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
ucc_memory_type_t buf_mem_type;
enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme one_sided_reliability_scheme;
uint32_t ag_counter;
int concurrency_level;
int mcast_prepost_bucket_size;
int state;
ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *ag_schedule;
int total_steps;
Expand Down
174 changes: 174 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,169 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
}
}

/*
 * Validates that the pipelined ("truly zero copy") mcast allgather schedule
 * can be built for this communicator/request combination.
 *
 * @param comm  mcast communicator (provides commsize, max_per_packet, rx_depth)
 * @param req   collective request (provides concurrency_level, num_packets,
 *              mcast_prepost_bucket_size, length)
 *
 * @return UCC_OK when all constraints hold, UCC_ERR_NOT_SUPPORTED (after
 *         logging a warning) when any divisibility or resource constraint
 *         is violated so the caller can fall back to another algorithm.
 */
static inline ucc_status_t
ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                                      ucc_tl_mlx5_mcast_coll_req_t *req)
{
    /* with an even concurrency level the per-rank packet stream must split
     * evenly into prepost buckets, otherwise stage boundaries drift */
    if (req->concurrency_level % 2 == 0 && req->num_packets % req->mcast_prepost_bucket_size != 0) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "num_packets (%d) must be a multiple of mcast_prepost_bucket_size (%d) "
                "when concurrency_level (%d) is even.",
                req->num_packets, req->mcast_prepost_bucket_size, req->concurrency_level);
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* each stage serves concurrency_level roots, so the roster of roots must
     * partition evenly across stages */
    if (comm->commsize % req->concurrency_level != 0) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "team size (%d) must be a multiple of concurrency_level (%d).",
                comm->commsize, req->concurrency_level);
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* per-rank payload must slice into whole mcast packets */
    if (req->length % comm->max_per_packet != 0) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "length (%ld) must be a multiple of max_per_packet (%d).",
                req->length, comm->max_per_packet);
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* both pipeline halves can have recvs outstanding at once, so up to
     * prepost_bucket_size * concurrency_level * 2 recv WRs must fit in the
     * receive queue.
     * NOTE: original message claimed the rejected case was the supported
     * one; the supported case is <= rx_depth, which is what we say now. */
    if (req->mcast_prepost_bucket_size * req->concurrency_level * 2 > comm->params.rx_depth) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "prepost_bucket_size * concurrency_level * 2 must not exceed rx_depth, "
                "but got: prepost_bucket_size=%d, concurrency_level=%d, "
                "rx_depth=%d",
                req->mcast_prepost_bucket_size, req->concurrency_level,
                comm->params.rx_depth);
        return UCC_ERR_NOT_SUPPORTED;
    }

    return UCC_OK;
}


/*
 * Builds the per-step schedule for the pipelined ("truly zero copy")
 * multicast allgather and registers the receive buffer.
 *
 * At each stage half of the mcast groups are ready for receiving mcast
 * packets while the other half are getting prepared by preposting recv
 * buffers. Step i preposts a bucket of recvs on one half of the groups;
 * step i+1 multicasts into the buffers that step i preposted.
 *
 * @param comm  mcast communicator
 * @param req   collective request; on success req->ag_schedule,
 *              req->total_steps, req->recv_rreg and req->recv_mr are set.
 *              Ownership of the schedule allocation transfers to req.
 *
 * @return UCC_OK on success, UCC_ERR_NOT_SUPPORTED when the pipeline
 *         constraints cannot be met, UCC_ERR_NO_MEMORY or a registration
 *         error otherwise.
 */
static inline ucc_status_t
ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                              ucc_tl_mlx5_mcast_coll_req_t *req)
{
    ucc_tl_mlx5_mcast_reg_t *reg = NULL;
    ucc_rank_t               root = 0;
    int                      offset = 0;
    ucc_status_t             status;
    ucc_rank_t               j, i;
    int                      total_steps;
    ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *schedule;

    ucc_assert(comm->allgather_comm.truly_zero_copy_allgather_enabled);

    /* use half of the mcast groups per stage so the other half can prepost;
     * clamp to the design limit and the team size */
    req->concurrency_level = comm->mcast_group_count / 2;
    req->concurrency_level = ucc_min(req->concurrency_level, ONE_SIDED_MAX_CONCURRENT_LEVEL);
    req->concurrency_level = ucc_min(req->concurrency_level, comm->commsize);

    if (req->concurrency_level == 0) {
        tl_warn(comm->lib, "not enough concurrency level to enable zcopy pipeline allgather");
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* never prepost more buffers than there are packets per rank */
    req->mcast_prepost_bucket_size =
        ucc_min(req->num_packets, comm->allgather_comm.mcast_prepost_bucket_size);

    status = ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(comm, req);
    if (status != UCC_OK) {
        return status;
    }

    /* calculate the schedule and details of what we should
     * mcast and prepost to which mcast group at each stage; the extra step
     * drains the last prepost round (it only multicasts, preposts nothing) */
    total_steps = req->num_packets * (comm->commsize / req->concurrency_level)
                  / req->mcast_prepost_bucket_size + 1;

    schedule = ucc_calloc(1,
                          sizeof(ucc_tl_mlx5_mcast_pipelined_ag_schedule_t) *
                          total_steps, "sched");
    if (!schedule) {
        tl_warn(comm->lib, "cannot allocate memory for schedule list");
        return UCC_ERR_NO_MEMORY;
    }

    /* generate schedule */
    for (i = 0; i < total_steps; i++) {
        /* every step except the last preposts the next bucket; groups
         * alternate halves via (i % 2) */
        if (i < total_steps - 1) {
            for (j = 0; j < req->concurrency_level; j++) {
                schedule[i].prepost_buf_op[j].group_id =
                    j + req->concurrency_level * (i % 2);
                schedule[i].prepost_buf_op[j].offset =
                    offset * comm->max_per_packet;
                schedule[i].prepost_buf_op[j].root = root + j;
                schedule[i].prepost_buf_op[j].count =
                    req->mcast_prepost_bucket_size;
            }
        } else {
            schedule[i].prepost_buf_op_done = 1;
        }

        /* every step except the first multicasts into the buffers that the
         * previous step preposted */
        if (i > 0) {
            for (j = 0; j < req->concurrency_level; j++) {
                schedule[i].multicast_op[j].group_id =
                    schedule[i - 1].prepost_buf_op[j].group_id;
                schedule[i].multicast_op[j].offset =
                    schedule[i - 1].prepost_buf_op[j].offset;
                schedule[i].multicast_op[j].offset_left =
                    schedule[i - 1].prepost_buf_op[j].offset;
                schedule[i].multicast_op[j].root =
                    schedule[i - 1].prepost_buf_op[j].root;
                schedule[i].multicast_op[j].to_send_left =
                    schedule[i - 1].prepost_buf_op[j].count;
                schedule[i].multicast_op[j].to_recv =
                    schedule[i - 1].prepost_buf_op[j].count;
                schedule[i].to_recv += schedule[i].multicast_op[j].to_recv;
                /* this rank only sends in the rounds where it is the root */
                if (schedule[i].multicast_op[j].root == comm->rank) {
                    schedule[i].to_send += schedule[i].multicast_op[j].to_send_left;
                }
            }
        }

        if (!schedule[i].to_send || !schedule[i].to_recv) {
            schedule[i].multicast_op_done = 1;
        }

        /* advance within the current set of roots; when their packets are
         * exhausted, move to the next group of concurrency_level roots */
        offset += req->mcast_prepost_bucket_size;

        if (offset == req->num_packets) {
            offset = 0;
            root = (root + req->concurrency_level) % comm->commsize;
        }
    }

    tl_trace(comm->lib,
             "generated the schedule for pipelined zero copy allgather with total_steps %d",
             total_steps);
    schedule->total_steps = total_steps;
    req->total_steps = total_steps;
    req->ag_schedule = schedule;
    tl_trace(comm->lib, "registering recv buf of size %ld", req->length * comm->commsize);
    ucc_assert(req->recv_rreg == NULL);

    status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->rptr, req->length *
                                            comm->commsize, &reg);
    if (UCC_OK != status) {
        tl_warn(comm->lib, "unable to register receive buffer %p of size %ld",
                req->rptr, req->length * comm->commsize);
        ucc_free(schedule);
        return status;
    }

    req->recv_rreg = reg;
    req->recv_mr = reg->mr;

    return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
{
ucc_coll_task_t *coll_task = &(task->super);
Expand Down Expand Up @@ -357,6 +520,13 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
req->to_send = req->num_packets;
req->to_recv = comm->commsize * req->num_packets;

if (comm->allgather_comm.truly_zero_copy_allgather_enabled) {
status = ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(comm, req);
if (UCC_OK != status) {
goto failed;
}
}

comm->allgather_comm.coll_counter++;

task->coll_mcast.req_handle = req;
Expand All @@ -367,6 +537,10 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)

failed:
tl_warn(UCC_TASK_LIB(task), "mcast init allgather failed:%d", status);
if (req->rreg) {
ucc_tl_mlx5_mcast_mem_deregister(comm->ctx, req->rreg);
req->rreg = NULL;
}
if (req) {
ucc_mpool_put(req);
}
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,

memcpy(&comm->params, conf_params, sizeof(*conf_params));

comm->allgather_comm.mcast_prepost_bucket_size
= conf_params->mcast_prepost_bucket_size;
comm->allgather_comm.truly_zero_copy_allgather_enabled
= conf_params->truly_zero_copy_allgather_enabled;
comm->one_sided.reliability_enabled = conf_params->one_sided_reliability_enable;
comm->bcast_comm.wsize = conf_params->wsize;
comm->allgather_comm.max_push_send = conf_params->max_push_send;
Expand Down
9 changes: 9 additions & 0 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
UCC_CONFIG_TYPE_BOOL},

{"MCAST_ZERO_COPY_ALLGATHER_ENABLE", "1", "Enable truly zero copy allgather design for mcast",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.truly_zero_copy_allgather_enabled),
UCC_CONFIG_TYPE_BOOL},

{"MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE", "16", "Number of posted recvs during each stage of the pipeline"
" in truly zero copy mcast allgather design",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.mcast_prepost_bucket_size),
UCC_CONFIG_TYPE_INT},

{NULL}};

static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
Expand Down

0 comments on commit 1c7ff90

Please sign in to comment.