Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions easybuild/tools/toolchain/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
_log = fancylogger.getLogger('tools.toolchain.mpi', fname=False)


def get_mpi_cmd_template(mpi_family, params, mpi_version=None):
def get_mpi_cmd_template(mpi_family, params, mpi_version=None, oversubscribe=False):
"""
Return template for MPI command, for specified MPI family.

Expand Down Expand Up @@ -123,6 +123,50 @@ def get_mpi_cmd_template(mpi_family, params, mpi_version=None):
else:
raise EasyBuildError("Don't know which template MPI command to use for MPI family '%s'", mpi_family)

if oversubscribe:
osub_cmd = ''
if mpi_family in [toolchain.OPENMPI]:
if mpi_version is None:
raise EasyBuildError("OpenMPI version unknown, can't determine how to handle oversubscription!")
if LooseVersion(mpi_version) < '5':
varname = 'OMPI_MCA_rmaps_base_oversubscribe'
varvalue = os.getenv(varname)
if varvalue and varvalue != '1':
_log.warning("Overwriting existing %s=%s with %s=1", varname, varvalue, varname)
osub_cmd = f'{varname}=1'
else:
varname = 'PRTE_MCA_rmaps_default_mapping_policy'
varvalue = os.getenv(varname)

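# Intended handling of the existing mapping policy value:
# - variable not set -> 'core:oversubscribe'
# - value without `:`, e.g. 'package' -> 'package:oversubscribe'
# - value with `:` modifiers, e.g. 'slot:span' -> 'slot:span,oversubscribe'
# - value starting with 'ppr', e.g. 'ppr:4:numa' -> not supported yet, replaced by 'core:oversubscribe'
# - any of the above with 'oversubscribe' already among the modifiers -> no duplicate is added
# NOTE(review): os.getenv() returns None when the variable is not set, so the
# .startswith() call below needs a default (e.g. os.getenv(varname, '')) for the first case.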
flags = ''
if varvalue.startswith('ppr'):
_log.warning("Can't handle ppr mapping with oversubscription yet, overwriting unit with 'core'")
unit = 'core'
flags = 'oversubscribe'
else:
unit, flags = (varvalue.rsplit(':', maxsplit=1) + [''])[:2]
unit = unit or 'core'
flags = list(filter(None, flags.split(',')))
if 'oversubscribe' not in flags:
flags.append('oversubscribe')
newvalue = f"{unit}:{','.join(flags)}"

osub_cmd = f'{varname}={newvalue}'
elif mpi_family in [toolchain.INTELMPI]:
_log.info("INTELMPI always oversubscribe by default, nothing to do...")
elif mpi_family in [toolchain.MVAPICH2, toolchain.MPICH, toolchain.MPICH2]:
_log.info("MPICH always oversubscribe by default, nothing to do...")
else:
raise EasyBuildError("Oversubscribe not supported for MPI family '%s'", mpi_family)

mpi_cmd_template = f'%(oversubscribe)s {mpi_cmd_template}'
params.update({'oversubscribe': osub_cmd}) # just a placeholder

missing = []
for key in sorted(params.keys()):
tmpl = '%(' + key + ')s'
Expand Down Expand Up @@ -270,7 +314,7 @@ def mpi_cmd_prefix(self, nr_ranks=1):

return result

def mpi_cmd_for(self, cmd, nr_ranks):
def mpi_cmd_for(self, cmd, nr_ranks, oversubscribe=False):
"""Construct an MPI command for the given command and number of ranks."""

# parameter values for mpirun command
Expand All @@ -281,20 +325,20 @@ def mpi_cmd_for(self, cmd, nr_ranks):

mpi_family = self.mpi_family()

mpi_version = None
# this fails when it's done too early (before modules for toolchain/dependencies are loaded),
# but it's safe to ignore this
mpi_version = self.get_software_version(self.MPI_MODULE_NAME, required=False)[0]

if mpi_family == toolchain.INTELMPI:
# for Intel MPI, try to determine impi version
# this fails when it's done too early (before modules for toolchain/dependencies are loaded),
# but it's safe to ignore this
mpi_version = self.get_software_version(self.MPI_MODULE_NAME, required=False)[0]
if not mpi_version:
self.log.debug("Ignoring error when trying to determine %s version", self.MPI_MODULE_NAME)
# impi version is required to determine correct MPI command template,
# so we have to return early if we couldn't determine the impi version...
return None

mpi_cmd_template, params = get_mpi_cmd_template(mpi_family, params, mpi_version=mpi_version)
mpi_cmd_template, params = get_mpi_cmd_template(
mpi_family, params, mpi_version=mpi_version, oversubscribe=oversubscribe
)
self.log.info("Using MPI command template '%s' (params: %s)", mpi_cmd_template, params)

try:
Expand Down
Loading