How to Use submitit on Slurm Systems
submitit is a Python package by Facebook/Meta for submitting one or more Python jobs to Slurm clusters, like those at OSC or Facebook.
While it’s easy to use, there were a couple of gotchas and lessons I had to learn before I really appreciated how useful it is compared to writing Bash scripts. This article covers those, with plenty of examples and explanation.
I have used submitit for several projects, including BioBench, saev, frx (unmaintained at the moment), and an I-JEPA reimplementation (also unmaintained at the moment). These repos contain complete examples of how I use submitit in a larger context than a blog post can provide.
I recommend reading from top to bottom, as the tips are ordered by relative importance/utility.
Table of Contents
- Minimal Example
- Why Use submitit?
- GPUs and CPUs
- Multi-GPU Training in Torch
- DebugExecutor
- stderr_to_stdout=True
- Global Variables
- Environment Variables
- Submitting GPU Jax Jobs on CPU-Only Nodes
- Complete Examples
Minimal Example
Minimal example for submitting jobs on a cluster.
import submitit


def add(a, b):
    return a + b


executor = submitit.AutoExecutor(folder="./logs")
executor.update_parameters(timeout_min=1, slurm_partition="gpu")
job = executor.submit(add, 5, 7)
output = job.result()
assert output == 12
This is directly from the submitit README.
Why Use submitit?
Understand the motivation for submitit.
I learned to interact with Slurm clusters through bash scripts and sbatch. We used this a lot in BioCLIP, like this launch script to run make_wds.py. I would submit it with sbatch slurm/make-dataset-wds.sh.
This sucked for a few reasons.
1. I had to write scripts in another language (bash) when I wanted to use Python for everything.
2. I had to edit the bash script to change the Python script arguments. I had a very nice argument parser in my Python script with argparse and help text, etc., but I couldn’t use it because I didn’t have an argument parser for make-dataset-wds.sh.
3. I couldn’t easily programmatically launch many jobs at once. Sometimes I have a config file that specifies a sweep of jobs, and I want to launch many jobs with one script. But because I can’t write bash very well, I’ve written Python scripts to parse the config files, then launch the jobs, with code like this:
# buggy and error-prone code; use submitit instead.
command = [
    "sbatch",
    f"--output=./logs/{job_name}-%j.log",
    f"--job-name={job_name}",
    f"--export=CONFIG_FILE={config_file}",
    template_file,
]
try:
    output = subprocess.run(command, check=True, capture_output=True)
    print(output.stdout.decode("utf-8"), end="")
except subprocess.CalledProcessError as e:
    print(e.stderr.decode("utf-8"), end="")
    print(e)
4. Re-launching jobs is also a hassle. If an experiment is checkpointable, I would like to restart jobs when they end without having to log back in.
5. Setting up your jobs to run on both Slurm clusters and local machines is challenging.
submitit solves all of these pain points. If you look at projects like BioBench or saev, there are no bash scripts whatsoever, but they run on Slurm clusters and local machines alike.
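For instance, here is a minimal sketch of launching a whole sweep from one Python script with executor.map_array; the train_one function and the learning-rate values are made up for illustration.

import submitit


def train_one(lr: float) -> float:
    # Real training code would go here; return a metric for this run.
    return lr


executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(time=60, partition="PARTITION", account="SLURM_ACCT")

# One Slurm array job, one task per learning rate.
lrs = [1e-4, 3e-4, 1e-3]
jobs = executor.map_array(train_one, lrs)
results = [job.result() for job in jobs]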
GPUs and CPUs
Setting number of GPUs and CPUs in your jobs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=120,
    partition="gpu",
    account="ACCOUNT",
    # These args are important.
    gpus_per_node=4,
    ntasks_per_node=2,
    cpus_per_task=12,
)
# calls job_fn ntasks_per_node times in parallel.
executor.submit(job_fn)
gpus_per_node is the number of GPUs per node, ntasks_per_node is the number of processes that call your function, and cpus_per_task is the number of CPUs available per task. So if you want to run two tasks, each with two GPUs, and a total of 24 CPUs, you need ntasks_per_node=2, gpus_per_node=4, and cpus_per_task=12.
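As a concrete sketch of that last configuration (the partition, account, and job_fn names are placeholders):

executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=120,
    partition="PARTITION",
    account="ACCOUNT",
    gpus_per_node=4,    # 2 tasks x 2 GPUs each = 4 GPUs on the node
    ntasks_per_node=2,  # job_fn is called twice in parallel
    cpus_per_task=12,   # 2 tasks x 12 CPUs = 24 CPUs total
)
executor.submit(job_fn)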
Multi-GPU Training in Torch
Multi-GPU training with torch.distributed.
def train():
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()
    torch.distributed.init_process_group(
        backend="nccl", world_size=dist_env.world_size
    )
    assert dist_env.rank == torch.distributed.get_rank()
    assert dist_env.world_size == torch.distributed.get_world_size()


executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=12 * 60,
    partition="gpu",
    account="ACCOUNT",
    # These args are important.
    gpus_per_node=4,
    ntasks_per_node=4,
    cpus_per_task=12,
)
executor.submit(train).result()
When setting up your executor, be sure to set ntasks_per_node to the same number as gpus_per_node so that every GPU has a task. The submitit.helpers.TorchDistributedEnvironment class somehow handles the environment variables so that PyTorch can set up the distributed environment correctly. Note that if you don’t have a CUDA device available, you cannot call init_process_group (at least not with the "nccl" backend), so you probably want to handle that case.
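A minimal sketch of that guard, assuming you fall back to single-process CPU training when no GPU is present:

def train():
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()

    if torch.cuda.is_available():
        torch.distributed.init_process_group(
            backend="nccl", world_size=dist_env.world_size
        )
        device = torch.device("cuda", dist_env.local_rank)
    else:
        # No CUDA device: skip init_process_group and train on CPU.
        device = torch.device("cpu")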
DebugExecutor
Run code in the current process rather than in a Slurm job.
if debug:
    executor = submitit.DebugExecutor(folder=args.log_to)
else:
    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )

# Use executor as normal.
executor.submit(job_fn, arg1, arg2).result()
If you are debugging jobs, you likely want to use pdb or another interactive debugger. You cannot use pdb in a “headless” process like a Slurm job. However, the submitit.DebugExecutor runs jobs in the same process that you create the executor from. This is really useful for debugging jobs, because DebugExecutor has the same API as SlurmExecutor, so you can split up your executor construction code and then debug your jobs.
This solves problem #5: setting up your jobs to run on both Slurm clusters and local machines is challenging.
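For example, a sketch of what this makes possible; the breakpoint only works because the DebugExecutor runs the job in your terminal’s process:

def job_fn(a, b):
    result = a + b
    breakpoint()  # drops into pdb; only works under DebugExecutor
    return result


executor = submitit.DebugExecutor(folder="./logs")
job = executor.submit(job_fn, 1, 2)
job.result()  # the function (and the breakpoint) runs here, in this process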
stderr_to_stdout=True
Don’t split up stderr and stdout logs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=30,
    partition="PARTITION",
    account="SLURM_ACCT",
    stderr_to_stdout=True,  # <- This line
)
Most of the time I’m not interested in the distinction between stderr and stdout because I just care about outputs. print() in Python goes to stdout, while logging.info() goes to stderr. If they end up in separate files, it can be irritating to try to understand how your debugging statements are ordered (but you should also use pdb in a DebugExecutor instead of print statements). Setting stderr_to_stdout=True in executor.update_parameters() writes everything to the same stream.
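With that set, a job’s print() and logging output land in a single log file, which you can also read back through the job object (a small sketch):

job = executor.submit(job_fn)
job.result()
# With stderr_to_stdout=True, print() and logging output share one stream.
print(job.stdout())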
Global Variables
Global variables don’t work.
def main():
    if not use_ssl:
        import ssl

        # By default do not use HTTPS
        ssl._create_default_https_context = ssl._create_unverified_context

    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )
    executor.submit(job_fn)


def job_fn():
    import ssl

    print(ssl._create_default_https_context)
    # Will not be an unverified context
If you want to set global variables after the program is running but before you submit jobs, these variables will not persist in your jobs. In the example above, I want to tell Python’s ssl module to ignore HTTPS certs by setting ssl._create_default_https_context to ssl._create_unverified_context. However, in the job function job_fn, ssl._create_default_https_context will not be set correctly.
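A minimal sketch of one workaround: pass the flag as an argument and set the global inside the job function, where the job process actually runs.

def job_fn(use_ssl: bool):
    if not use_ssl:
        import ssl

        # Set the global in the job process itself, so it takes effect there.
        ssl._create_default_https_context = ssl._create_unverified_context
    # ... rest of the job ...


executor.submit(job_fn, use_ssl)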
Environment Variables
How to set environment variables in Slurm jobs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=30, partition="PARTITION", account="SLURM_ACCT"
)

if not use_ssl:
    executor.update_parameters(setup=[
        "export DISABLE_SSL=1",
        "export HAVE_FUN=2",
    ])
If you want to set global variables, you might end up using environment variables. While using environment variables to manage program state is almost always a source of bugs, if you absolutely need to, you can use the setup parameter to set environment variables in the running Slurm jobs.
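The setup lines run in the batch script before your function starts, so inside the job you can read the variables back with os.environ (a small sketch):

import os


def job_fn():
    disable_ssl = os.environ.get("DISABLE_SSL", "0") == "1"
    if disable_ssl:
        # ... configure the job accordingly ...
        pass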
Submitting GPU Jax Jobs on CPU-Only Nodes
Weird Jax issues with GPUs.
executor = submitit.SlurmExecutor(folder="logs")
executor.update_parameters(
    time=12 * 60,
    partition="PARTITION",
    account="ACCOUNT",
    setup=["export JAX_PLATFORMS=''"],
)
For whatever reason, we cannot import Jax without a GPU. That is, if you run import jax with jax[cuda12] installed in your environment, you get an exception about “No GPUs found” or something like that. But you can set JAX_PLATFORMS=cpu before uv run python -c "import jax; print(jax)" and it will work fine. However, if you set JAX_PLATFORMS=cpu to run this launcher script, then it will also be set for the submitted jobs. This means that your training jobs will run on the CPU instead of the cluster GPUs. The extra setup arg exports an updated JAX_PLATFORMS variable for the cluster jobs, so they will find the GPUs for training.
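Putting it together, a sketch of the pattern: run the launcher with JAX_PLATFORMS=cpu, and let the setup line above reset the variable inside the Slurm job. The device check is just for illustration.

# Launch with: JAX_PLATFORMS=cpu uv run python launcher.py
def train():
    import jax

    print(jax.devices())  # should list GPU devices on the cluster node


executor.submit(train)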
Complete Examples
Complete PyTorch and Jax example with all these tricks.
PyTorch:
# import ...


def train(device: str):
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()

    is_ddp = False
    global_rank = 0
    local_rank = 0
    is_master = False
    if device == "cuda":
        torch.distributed.init_process_group(
            backend="nccl", world_size=dist_env.world_size
        )
        assert dist_env.rank == torch.distributed.get_rank()
        assert dist_env.world_size == torch.distributed.get_world_size()
        is_ddp = True
        global_rank = dist_env.rank
        local_rank = dist_env.local_rank
        is_master = dist_env.rank == 0

    logger = logging.getLogger(f"Rank {global_rank}")

    model = ViT()

    if is_ddp:
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Train as normal


def main():
    ...
Jax:
# TODO
Sam Stevens, 2024