From 5142705390142040867819e1ce396ad855f59d60 Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Thu, 4 Jan 2024 10:43:23 +0400 Subject: [PATCH] TL/NCCL: lazy init nccl comm (#851) * TL/NCCL: lazy init nccl comm * REVIEW: fix review comments --- src/components/tl/nccl/tl_nccl.c | 11 +- src/components/tl/nccl/tl_nccl.h | 14 +- src/components/tl/nccl/tl_nccl_coll.c | 10 +- src/components/tl/nccl/tl_nccl_team.c | 186 ++++++++++++++++---------- src/components/tl/ucc_tl.c | 5 + src/components/tl/ucc_tl.h | 10 ++ 6 files changed, 157 insertions(+), 79 deletions(-) diff --git a/src/components/tl/nccl/tl_nccl.c b/src/components/tl/nccl/tl_nccl.c index 8e71cdc1e2..46fdcff8e3 100644 --- a/src/components/tl/nccl/tl_nccl.c +++ b/src/components/tl/nccl/tl_nccl.c @@ -39,12 +39,17 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = { UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names) }, - {"BLOCKING", "1", - "If set to 0 will use non-blocking mode communicator behavior, " - "if set to 1 will use blocking mode", + {"BLOCKING", "yes", + "If set to no will use non-blocking mode communicator behavior, " + "if set to yes will use blocking mode", ucs_offsetof(ucc_tl_nccl_context_config_t, nccl_cfg_blocking), UCS_CONFIG_TYPE_BOOL}, + {"LAZY_INIT", "yes", + "Initialize NCCL communicator on first collective", + ucc_offsetof(ucc_tl_nccl_context_config_t, nccl_lazy_init), + UCC_CONFIG_TYPE_BOOL}, + {NULL}}; UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_nccl_lib_t, ucc_base_lib_t, diff --git a/src/components/tl/nccl/tl_nccl.h b/src/components/tl/nccl/tl_nccl.h index 06f32c0371..b922601812 100644 --- a/src/components/tl/nccl/tl_nccl.h +++ b/src/components/tl/nccl/tl_nccl.h @@ -45,6 +45,15 @@ #define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3) #define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB +enum { + TL_NCCL_COMM_STATE_ERROR, + TL_NCCL_COMM_STATE_OOB, + TL_NCCL_COMM_STATE_INIT_TEAM, + TL_NCCL_COMM_STATE_INIT_COMM, + TL_NCCL_COMM_STATE_DESTROY_COMM, + TL_NCCL_COMM_STATE_READY, +}; + typedef struct ucc_tl_nccl_iface { ucc_tl_iface_t super; } ucc_tl_nccl_iface_t; @@ -66,6 +75,7 @@ typedef struct ucc_tl_nccl_context_config { ucc_tl_context_config_t super; ucc_tl_nccl_completion_sync_type_t sync_type; int nccl_cfg_blocking; + int nccl_lazy_init; } ucc_tl_nccl_context_config_t; typedef struct ucc_tl_nccl_lib { @@ -85,7 +95,7 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *, typedef struct ucc_tl_nccl_team { ucc_tl_team_t super; - ucc_status_t comm_state; + int comm_state; ncclUniqueId *unique_id; void *oob_req; ncclComm_t nccl_comm; @@ -146,6 +156,8 @@ static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NO return UCC_OK; } +ucc_status_t ucc_tl_nccl_comm_init(ucc_tl_nccl_team_t *team); + #define NCCLCHECK_GOTO(_cmd, _label, _st, _lib, _task_st, _comm, _check_nb) \ do { \ ncclResult_t e = _cmd; \ diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c index 8a225c268b..ee3d523b0b 100644 --- a/src/components/tl/nccl/tl_nccl_coll.c +++ b/src/components/tl/nccl/tl_nccl_coll.c @@ -131,6 +131,7 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_tl_nccl_task_t **coll_task) { + ucc_tl_nccl_team_t *nccl_team = ucc_derived_of(team, ucc_tl_nccl_team_t); ucc_tl_nccl_context_t *nccl_ctx = ucc_derived_of(team->context, ucc_tl_nccl_context_t); ucc_tl_nccl_task_t *task; @@ -143,6 +144,13 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args, return UCC_ERR_NOT_SUPPORTED; } + if (ucc_unlikely(nccl_team->comm_state != TL_NCCL_COMM_STATE_READY)) { + status = ucc_tl_nccl_comm_init(nccl_team); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + } + task = ucc_mpool_get(&nccl_ctx->req_mp); if (ucc_unlikely(!task)) { tl_error(team->context->lib, "failed to get task from mpool"); @@ -206,7 +214,7 @@ ucc_status_t ucc_tl_nccl_coll_finalize(ucc_coll_task_t *coll_task) ucc_status_t status = UCC_OK; if (ucc_unlikely(task->super.super.status != UCC_OK)) { - team->comm_state = task->super.super.status; + team->comm_state = TL_NCCL_COMM_STATE_ERROR; } tl_debug(UCC_TASK_LIB(task), "finalizing coll task %p", task); ucc_tl_nccl_free_task(task); diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c index af2aff2ac6..bf8caf7e53 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -15,14 +15,17 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context, const ucc_base_team_params_t *params) { - ucc_tl_nccl_context_t *ctx = - ucc_derived_of(tl_context, ucc_tl_nccl_context_t); + ucc_tl_nccl_context_t *ctx = ucc_derived_of(tl_context, + ucc_tl_nccl_context_t); + ucc_team_oob_coll_t *oob; ucc_status_t status; ucc_rank_t size; - UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params); + UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params); + oob = &(UCC_TL_TEAM_OOB(self)); size = UCC_TL_TEAM_SIZE(self); - self->comm_state = UCC_OK; + self->stream = NULL; + self->nccl_comm = NULL; self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1), "tl_nccl_unique_id"); if (!self->unique_id) { @@ -31,6 +34,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context, sizeof(ncclUniqueId) * (size + 1)); return UCC_ERR_NO_MEMORY; } + if (UCC_TL_TEAM_RANK(self) == 0) { ncclResult_t st; st = ncclGetUniqueId(&self->unique_id[size]); @@ -39,14 +43,16 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context, memset(&self->unique_id[size], 0, sizeof(ncclUniqueId)); } } - status = UCC_TL_TEAM_OOB(self).allgather( - &self->unique_id[size], self->unique_id, - sizeof(ncclUniqueId), UCC_TL_TEAM_OOB(self).coll_info, - &self->oob_req); + + status = oob->allgather(&self->unique_id[size], + self->unique_id, sizeof(ncclUniqueId), + oob->coll_info, &self->oob_req); if (status != UCC_OK) { tl_error(ctx->super.super.lib, "failed to start oob allgather"); goto free_unique_id; } + self->comm_state = TL_NCCL_COMM_STATE_OOB; + return UCC_OK; free_unique_id: @@ -69,15 +75,17 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team) #if NCCL_USE_NON_BLOCKING ncclResult_t nccl_status, st; - if (team->nccl_comm && team->comm_state == UCC_INPROGRESS) { + if (team->comm_state == TL_NCCL_COMM_STATE_DESTROY_COMM) { goto check_finalize; } #endif + if (team->stream) { + cudaStreamDestroy(team->stream); + team->stream = NULL; + } if (team->nccl_comm) { - if (team->comm_state != UCC_OK && team->comm_state != UCC_INPROGRESS) { - /* if communication error was detected ncclCommAbort should be used - since ncclCommDestroy could block */ + if (team->comm_state == TL_NCCL_COMM_STATE_ERROR) { ncclCommAbort(team->nccl_comm); } else { #if NCCL_USE_NON_BLOCKING @@ -91,7 +99,7 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team) ncclCommAbort(team->nccl_comm); return UCC_ERR_NO_MESSAGE; } else if (nccl_status == ncclInProgress) { - team->comm_state = UCC_INPROGRESS; + team->comm_state = TL_NCCL_COMM_STATE_DESTROY_COMM; return UCC_INPROGRESS; } else { ncclCommDestroy(team->nccl_comm); @@ -101,95 +109,125 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team) ncclCommDestroy(team->nccl_comm); #endif } - cudaStreamDestroy(team->stream); } UCC_CLASS_DELETE_FUNC_NAME(ucc_tl_nccl_team_t)(tl_team); return UCC_OK; } -ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) +ucc_status_t ucc_tl_nccl_comm_init(ucc_tl_nccl_team_t *team) { - ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); ucc_status_t status; ncclResult_t nccl_status; - ncclUniqueId errorid; - #if NCCL_USE_NON_BLOCKING ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER; - ncclResult_t st; - - if (team->comm_state == UCC_INPROGRESS) { - goto ncclInitStage; - } + ncclResult_t async_status; #endif - status = UCC_TL_TEAM_OOB(team).req_test(team->oob_req); - if (status == UCC_INPROGRESS) { - return UCC_INPROGRESS; - } - if (status != UCC_OK) { - UCC_TL_TEAM_OOB(team).req_free(team->oob_req); - tl_error(tl_team->context->lib, "oob req test failed"); - goto free_unique_id; - } - status = UCC_TL_TEAM_OOB(team).req_free(team->oob_req); - if (status != UCC_OK) { - tl_error(tl_team->context->lib, "oob req free failed"); - goto free_unique_id; - } - /* check unique id is valid */ - memset(&errorid, 0, sizeof(errorid)); - if (!memcmp(&errorid, team->unique_id, sizeof(errorid))) { - tl_error(tl_team->context->lib, "incorrect unique id"); - goto free_unique_id; + if (team->comm_state == TL_NCCL_COMM_STATE_READY) { + return UCC_OK; + } else if (team->comm_state == TL_NCCL_COMM_STATE_ERROR) { + return UCC_ERR_NOT_SUPPORTED; + } else if (team->comm_state == TL_NCCL_COMM_STATE_INIT_COMM) { +#if NCCL_USE_NON_BLOCKING + goto nccl_async_init; +#else + ucc_assert_always(0); +#endif } CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream, - cudaStreamNonBlocking), free_unique_id, status); + cudaStreamNonBlocking), + exit_err, status); #if NCCL_USE_NON_BLOCKING - nccl_cfg.blocking = UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_cfg_blocking; - nccl_status = ncclCommInitRankConfig(&team->nccl_comm, - UCC_TL_TEAM_SIZE(team), - team->unique_id[0], - UCC_TL_TEAM_RANK(team), - &nccl_cfg); - if (nccl_status != ncclInProgress && nccl_status != ncclSuccess) { - goto free_stream; + /* + * if NCCL comm initialized during first call to collective init a.k.a lazy init + * we need to use blocking init to correctly fallback to other TL in case of error + */ + nccl_cfg.blocking = (UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_cfg_blocking || + UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_lazy_init) ? 1: 0; + + nccl_status = ncclCommInitRankConfig(&team->nccl_comm, tsize, + team->unique_id[0], trank, &nccl_cfg); + if ((nccl_status != ncclInProgress) && (nccl_status != ncclSuccess)) { + goto nccl_comm_init_err; } -ncclInitStage: - st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status); - if (st != ncclSuccess) { - nccl_status = st; +nccl_async_init: + nccl_status = ncclCommGetAsyncError(team->nccl_comm, &async_status); + if (nccl_status != ncclSuccess) { + goto nccl_comm_init_err; } - if (nccl_status == ncclInProgress){ - team->comm_state = UCC_INPROGRESS; - return UCC_INPROGRESS; + if (async_status == ncclInProgress) { + team->comm_state = TL_NCCL_COMM_STATE_INIT_COMM; } #else - nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team), - team->unique_id[0], UCC_TL_TEAM_RANK(team)); -#endif + nccl_status = ncclCommInitRank(&team->nccl_comm, tsize, team->unique_id[0], + trank); if (nccl_status != ncclSuccess) { - goto free_stream; + goto nccl_comm_init_err; } - ucc_free(team->unique_id); - tl_debug(tl_team->context->lib, "initialized tl team: %p", team); +#endif + + team->comm_state = TL_NCCL_COMM_STATE_READY; return UCC_OK; -free_stream: - tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status, - ncclGetErrorString(nccl_status)); - status = UCC_ERR_NO_MESSAGE; -#if NCCL_USE_NON_BLOCKING - ncclCommAbort(team->nccl_comm); -#endif - cudaStreamDestroy(team->stream); -free_unique_id: - ucc_free(team->unique_id); +nccl_comm_init_err: + tl_debug(team->super.super.context->lib, "NCCL error %d %s", + nccl_status, ncclGetErrorString(nccl_status)); + if (nccl_status == ncclInvalidUsage) { + /* + * handles the case when trying to inititize multiple ranks + * on the same GPU. Return "not supported" and fallback to other TL + */ + status = UCC_ERR_NOT_SUPPORTED; + } else { + status = UCC_ERR_NO_RESOURCE; + } + team->comm_state = TL_NCCL_COMM_STATE_ERROR; + +exit_err: return status; } +ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) +{ + ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t); + ucc_team_oob_coll_t *oob = &(UCC_TL_TEAM_OOB(team)); + ncclUniqueId errorid; + ucc_status_t status; + + + if (team->comm_state == TL_NCCL_COMM_STATE_OOB) { + status = oob->req_test(team->oob_req); + if (status == UCC_INPROGRESS) { + return UCC_INPROGRESS; + } + + oob->req_free(team->oob_req); + if (status != UCC_OK) { + tl_error(tl_team->context->lib, "oob req test failed"); + return status; + } + + /* check unique id is valid */ + memset(&errorid, 0, sizeof(errorid)); + if (!memcmp(&errorid, team->unique_id, sizeof(errorid))) { + tl_error(tl_team->context->lib, "incorrect unique id"); + return status; + } + + team->comm_state = TL_NCCL_COMM_STATE_INIT_TEAM; + } + + if (UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_lazy_init) { + return UCC_OK; + } + + return ucc_tl_nccl_comm_init(team); +} + ucc_status_t ucc_tl_nccl_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h) diff --git a/src/components/tl/ucc_tl.c b/src/components/tl/ucc_tl.c index dcbb2b6d71..3134c9fd14 100644 --- a/src/components/tl/ucc_tl.c +++ b/src/components/tl/ucc_tl.c @@ -242,6 +242,11 @@ ucc_status_t ucc_tl_team_create_multiple(ucc_team_multiple_req_t *req) } req->descs[*id].status = UCC_TL_CTX_IFACE(req->descs[*id].ctx) ->team.create_test(&req->descs[*id].team->super); + if (req->descs[*id].status < 0) { + /* if team create failed in team create test need to cleanup resources */ + UCC_TL_CTX_IFACE(req->descs[*id].ctx)->team.destroy( + &req->descs[*id].team->super); + } return UCC_INPROGRESS; } diff --git a/src/components/tl/ucc_tl.h b/src/components/tl/ucc_tl.h index 53e62052dc..75a5e3e1a0 100644 --- a/src/components/tl/ucc_tl.h +++ b/src/components/tl/ucc_tl.h @@ -138,8 +138,18 @@ typedef struct ucc_tl_lib_attr { #define UCC_TL_TEAM_IFACE(_tl_team) \ (ucc_derived_of((_tl_team)->super.context->lib, ucc_tl_lib_t))->iface +/** + * Get TL team lib + * @param [in] _tl_team pointer to TL team object + * @return pointer to TL lib object + */ #define UCC_TL_TEAM_LIB(_tl_team) (_tl_team)->super.super.context->lib +/** + * Get TL team context + * @param [in] _tl_team pointer to TL team object + * @return pointer to TL context object + */ #define UCC_TL_TEAM_CTX(_tl_team) (_tl_team)->super.super.context #define UCC_TL_CORE_CTX(_tl_team) ((_tl_team)->super.super.context->ucc_context)