Skip to content

Commit

Permalink
FEAT-modin-project#6914: Add a config for setting a number of threads per Dask worker
Browse files Browse the repository at this point in the history

Signed-off-by: Igoshev, Iaroslav <[email protected]>
  • Loading branch information
YarShev committed Feb 6, 2024
1 parent 807298d commit ac6155e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
3 changes: 3 additions & 0 deletions modin/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CIAWSAccessKeyID,
CIAWSSecretAccessKey,
CpuCount,
DaskThreadsPerWorker,
DoUseCalcite,
Engine,
EnvironmentVariable,
Expand Down Expand Up @@ -73,6 +74,8 @@
"RayRedisPassword",
"TestRayClient",
"LazyExecution",
# Dask specific
"DaskThreadsPerWorker",
# Partitioning
"NPartitions",
"MinPartitionSize",
Expand Down
7 changes: 7 additions & 0 deletions modin/config/envvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,13 @@ class LazyExecution(EnvironmentVariable, type=bool):
default = False


class DaskThreadsPerWorker(EnvironmentVariable, type=int):
    """
    Number of threads per Dask worker.

    The value is passed as ``threads_per_worker`` when Modin creates
    its Dask ``Client``.
    """

    # One thread per worker unless the setting is overridden.
    default = 1
    # Name of the variable this setting is associated with.
    varname = "MODIN_DASK_THREADS_PER_WORKER"


def _check_vars() -> None:
"""
Check validity of environment variables.
Expand Down
8 changes: 7 additions & 1 deletion modin/core/execution/dask/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GithubCI,
Memory,
NPartitions,
DaskThreadsPerWorker,
)
from modin.core.execution.utils import set_env
from modin.error_message import ErrorMessage
Expand Down Expand Up @@ -54,12 +55,17 @@ def _disable_warnings():
""",
)
num_cpus = CpuCount.get()
threads_per_worker = DaskThreadsPerWorker.get()
memory_limit = Memory.get()
worker_memory_limit = memory_limit // num_cpus if memory_limit else "auto"

# when the client is initialized, environment variables are inherited
with set_env(PYTHONWARNINGS="ignore::FutureWarning"):
client = Client(n_workers=num_cpus, memory_limit=worker_memory_limit)
client = Client(
n_workers=num_cpus,
threads_per_worker=threads_per_worker,
memory_limit=worker_memory_limit,
)

if GithubCI.get():
# set these keys to run tests that write to the mock s3 service. this seems
Expand Down

0 comments on commit ac6155e

Please sign in to comment.