How to Use submitit on Slurm Systems
submitit
is a Python package by Facebook/Meta for submitting one or more Python
jobs to Slurm clusters, like those at OSC or Facebook.
While it’s easy to use, there were a couple of gotchas and lessons I had to learn before I really appreciated how useful it is compared to writing Bash scripts. This article covers those, with plenty of examples and explanation.
I have used submitit for several projects, including BioBench, saev, frx
(unmaintained at the moment), and an I-JEPA
reimplementation (also unmaintained at the moment). These repos
contain complete examples of how I use submitit in a larger
context than a blog post can provide.
I recommend reading from top to bottom, as the tips are ordered by relative importance/utility.
Table of Contents
- Minimal Example
- Why Use submitit?
- GPUs and CPUs
- Multi-GPU Training in Torch
- DebugExecutor
- stderr_to_stdout=True
- Global Variables
- Environment Variables
- Submitting GPU Jax Jobs on CPU-Only Nodes
- Complete Examples
Minimal Example
Minimal example for submitting jobs on a cluster.
import submitit


def add(a, b):
    return a + b
executor = submitit.AutoExecutor(folder="./logs")
executor.update_parameters(
    timeout_min=1, slurm_partition="gpu"
)
job = executor.submit(add, 5, 7)
output = job.result()
assert output == 12

This is directly from the submitit README.
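The job object returned by submit() is useful on its own, not just for result(). A small sketch (these attributes come from submitit’s Job class; double-check the names against the docs):

# Continuing from the example above.
print(job.job_id)        # the Slurm job id, e.g. "12345678"
print(job.paths.stdout)  # path to the log file holding the job's stdout
print(job.done())        # non-blocking check; result() blocks until the job finishes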
Why Use submitit?
Understand the motivation for submitit.
I learned to interact with Slurm clusters through bash scripts and
sbatch. We used this a lot in BioCLIP, like this
launch script to run make_wds.py. I would submit it
with sbatch slurm/make-dataset-wds.sh.
This sucked for a couple reasons.
1. I had to write scripts in another language (bash) when I wanted to use Python for everything.
2. I had to edit the bash script to change the Python script arguments. I had a very nice argument parser in my Python script with argparse and help text, etc., but I couldn’t use it because I didn’t have an argument parser for make-dataset-wds.sh.
3. I couldn’t easily programmatically launch many jobs at once. Sometimes I have a config file that specifies a sweep of jobs, and I want to launch many jobs with one script. But because I can’t write bash very well, I’ve written Python scripts to parse the config files, then launch the jobs, with code like this:
# buggy and error-prone code; use submitit instead.
command = [
    "sbatch",
    f"--output=./logs/{job_name}-%j.log",
    f"--job-name={job_name}",
    f"--export=CONFIG_FILE={config_file}",
    template_file,
]
try:
    output = subprocess.run(
        command, check=True, capture_output=True
    )
    print(output.stdout.decode("utf-8"), end="")
except subprocess.CalledProcessError as e:
    print(e.stderr.decode("utf-8"), end="")
    print(e)

4. Re-launching jobs is also a hassle. If an experiment is checkpointable, I would like to restart jobs when they end without having to log back in.
5. Setting up your jobs to run on both Slurm clusters and local machines is challenging.
submitit solves all of these pain points. If you look at
projects like BioBench or saev, there are no bash
scripts whatsoever, but they run on Slurm clusters and local machines.
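For the config-sweep case above (problem #3), submitit lets you launch many jobs straight from Python. Here is a minimal sketch using executor.map_array; the train function and hyperparameter lists are hypothetical stand-ins:

import submitit


def train(learning_rate: float, seed: int):
    ...  # hypothetical training job


executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(time=60, partition="PARTITION", account="SLURM_ACCT")

learning_rates = [1e-4, 3e-4, 1e-3]
seeds = [0, 1, 2]

# Submits one job per element; arguments are paired element-wise.
jobs = executor.map_array(train, learning_rates, seeds)
results = [job.result() for job in jobs]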
GPUs and CPUs
Setting number of GPUs and CPUs in your jobs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=120,
    partition="gpu",
    account="ACCOUNT",
    # These args are important.
    gpus_per_node=4,
    ntasks_per_node=2,
    cpus_per_task=12,
)
# calls job_fn ntasks_per_node times in parallel.
executor.submit(job_fn)

gpus_per_node is GPUs per node,
ntasks_per_node is the number of processes that call your
function, and cpus_per_task is the number of CPUs available
per task. So if you want to run two tasks, each with two GPUs,
and a total of 24 CPUs, you need ntasks_per_node=2,
gpus_per_node=4, and cpus_per_task=12.
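Inside the job, each task can ask submitit which task it is. A short sketch, assuming a job_fn submitted with the parameters above (the attribute names come from submitit’s JobEnvironment class; check the docs if they’ve changed):

import submitit


def job_fn():
    env = submitit.JobEnvironment()
    # With ntasks_per_node=2, this function runs twice in parallel:
    # once with global_rank=0 and once with global_rank=1.
    print(f"task {env.global_rank} of {env.num_tasks} on node {env.node}")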
Multi-GPU Training in Torch
Setting up torch.distributed for multi-GPU training.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=12 * 60,
    partition="gpu",
    account="ACCOUNT",
    # These args are important.
    gpus_per_node=4,
    ntasks_per_node=4,
    cpus_per_task=12,
)
executor.submit(train).result()
def train():
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()
    torch.distributed.init_process_group(
        backend="nccl", world_size=dist_env.world_size
    )
    assert dist_env.rank == torch.distributed.get_rank()
    assert dist_env.world_size == torch.distributed.get_world_size()

When setting up your executor, be sure to set
ntasks_per_node to the same number as
gpus_per_node so that every GPU has a task.
The submitit.helpers.TorchDistributedEnvironment class
(via its export() method) sets the environment variables that PyTorch
expects (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, etc.) so that
torch.distributed can set up the process group correctly.
Note that if you don’t have a CUDA device available, you
cannot use the nccl backend in init_process_group, so you
probably want to handle that (see the sketch below).
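One way to handle the no-GPU case is to branch on torch.cuda.is_available() and either skip the process group or fall back to the gloo backend; a sketch:

import torch
import torch.distributed

import submitit


def train():
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()
    if torch.cuda.is_available():
        # nccl requires a GPU per process.
        torch.distributed.init_process_group(
            backend="nccl", world_size=dist_env.world_size
        )
    elif dist_env.world_size > 1:
        # gloo works on CPU-only nodes if you still want multiple processes.
        torch.distributed.init_process_group(
            backend="gloo", world_size=dist_env.world_size
        )
    # With world_size == 1 and no GPU, just run single-process.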
DebugExecutor
Run code in the current process rather than in a Slurm job.
if debug:
    executor = submitit.DebugExecutor(folder=args.log_to)
else:
    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )

# Use executor as normal.
executor.submit(job_fn, arg1, arg2).result()

If you are debugging jobs, you likely want to use pdb or
other interactive debuggers. You cannot use pdb in a
“headless” process like a Slurm job. However, the
submitit.DebugExecutor will run jobs in the same process
that you create the executor from. This is really useful for debugging
jobs, because DebugExecutor has the same API as
SlurmExecutor: you can branch on a debug flag when constructing
the executor and leave the rest of your job code unchanged.
This solves problem #5: setting up your jobs to run on both Slurm clusters and local machines.
stderr_to_stdout=True
Don’t split up stderr and stdout logs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=30,
    partition="PARTITION",
    account="SLURM_ACCT",
    stderr_to_stdout=True,  # <- This line
)

Most of the time I’m not interested in the distinction between stderr
and stdout because I just care about outputs. print() in
Python goes to stdout, logging.info() goes to stderr. If
you mix them, it can be irritating to try and understand how your
debugging statements are ordered (but you should also use pdb in a DebugExecutor instead of print
statements). Setting stderr_to_stdout=True in
executor.update_parameters() writes everything to the same
stream.
Global Variables
Global variables don’t work.
def main():
    if not use_ssl:
        import ssl

        # By default, do not verify HTTPS certificates.
        ssl._create_default_https_context = ssl._create_unverified_context

    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )
    executor.submit(job_fn)


def job_fn():
    import ssl

    print(ssl._create_default_https_context)
    # Will not be an unverified context.

If you want to set global variables after the program is running but
before you submit jobs, these variables will not persist in your jobs.
In the example above, I want to set Python’s ssl
module to ignore HTTPS certs by setting
ssl._create_default_https_context to
ssl._create_unverified_context. However, in the job
function job_fn,
ssl._create_default_https_context will not be set
correctly.
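The workaround is to pass the flag into the job function and set the global state inside it, since that code runs in the worker process on the cluster. A sketch using the hypothetical use_ssl flag and job_fn from above:

import submitit


def job_fn(use_ssl: bool):
    if not use_ssl:
        import ssl

        # Set the global state inside the job, in the process that
        # Slurm actually runs.
        ssl._create_default_https_context = ssl._create_unverified_context
    ...  # rest of the job


def main():
    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )
    executor.submit(job_fn, use_ssl=False)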
Environment Variables
How to set environment variables in Slurm jobs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=30, partition="PARTITION", account="SLURM_ACCT"
)
if not use_ssl:
    executor.update_parameters(setup=[
        "export DISABLE_SSL=1",
        "export HAVE_FUN=2",
    ])

If you want to set global variables, you might end up using
environment variables. While using environment variables to manage
program state is almost always a source of bugs, if you absolutely need
to, you can use the setup parameter to set environment
variables in the running Slurm jobs.
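Inside the job, you can read the variable back with os.environ; a small sketch using the hypothetical DISABLE_SSL variable from above:

import os


def job_fn():
    # The setup= lines run in the generated sbatch script before your
    # function, so the variable is visible in the job's environment.
    if os.environ.get("DISABLE_SSL") == "1":
        ...  # configure ssl accordingly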
Submitting GPU Jax Jobs on CPU-Only Nodes
Weird Jax issues with GPUs.
executor = submitit.SlurmExecutor(folder="logs")
executor.update_parameters(
    time=12 * 60,
    partition="PARTITION",
    account="ACCOUNT",
    setup=["export JAX_PLATFORMS=''"],
)

For whatever reason, we cannot import Jax without a GPU. That is, if
you run import jax with jax[cuda12] installed
in your environment, you get an exception about “No GPUs found” or
something like that. You can prefix the command with
JAX_PLATFORMS=cpu, as in JAX_PLATFORMS=cpu uv run python -c "import jax; print(jax)",
and it will work fine. But if you set JAX_PLATFORMS=cpu
when running the launcher script, the variable will also be set for the submitted
jobs. This means that your training jobs will run on the CPU instead of
the cluster GPUs.
The extra setup arg exports an updated JAX_PLATFORMS variable
in the cluster jobs, so Jax will find the GPUs for training.
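To confirm the override worked, the job itself can print the devices Jax sees; a small sketch, reusing the executor from the block above:

import jax


def check_devices():
    # On a GPU node, with JAX_PLATFORMS reset to '', this should list
    # CUDA devices rather than CPU devices.
    print(jax.devices())


executor.submit(check_devices).result()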
Complete Examples
Complete PyTorch and Jax example with all these tricks.
PyTorch:
# import ...
def train(device: str):
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()
    is_ddp = False
    global_rank = 0
    local_rank = 0
    is_master = False
    if device == "cuda":
        torch.distributed.init_process_group(
            backend="nccl", world_size=dist_env.world_size
        )
        assert dist_env.rank == torch.distributed.get_rank()
        assert dist_env.world_size == torch.distributed.get_world_size()
        is_ddp = True
        global_rank = dist_env.rank
        local_rank = dist_env.local_rank
        is_master = dist_env.rank == 0

    logger = logging.getLogger(f"Rank {global_rank}")

    model = ViT()
    if is_ddp:
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Train as normal


def main():
    # Sketch of a launcher using the settings from the earlier sections;
    # PARTITION and ACCOUNT are placeholders for your cluster.
    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=12 * 60,
        partition="PARTITION",
        account="ACCOUNT",
        gpus_per_node=4,
        ntasks_per_node=4,
        cpus_per_task=12,
        stderr_to_stdout=True,
    )
    executor.submit(train, "cuda").result()
Jax:
# TODO

Sam Stevens, 2024