How to Use submitit on Slurm Systems
submitit is a Python package by Facebook/Meta for submitting one or more Python jobs to Slurm clusters, like those at OSC or Facebook.
While it’s easy to use, there were a couple of gotchas and lessons I had to learn before I really appreciated how useful it is compared to writing Bash scripts. This article covers those, with plenty of examples and explanation.
I have used submitit for several projects, including BioBench, saev, frx
(unmaintained at the moment), and mindthegap. These
repos will contain complete examples of how I use submitit
in a larger context than a blog post can provide.
Table of Contents
- Minimal Example
- Why Use submitit?
- GPUs and CPUs
- Multi-GPU Training in Torch
- DebugExecutor
- stderr_to_stdout=True
- Global Variables
- Environment Variables
- Submitting GPU JAX Jobs on CPU-Only Nodes
- cloudpickle and Equinox
- Complete Examples
Minimal Example
Minimal example for submitting jobs on a cluster.
import submitit

def add(a, b):
    return a + b

executor = submitit.AutoExecutor(folder="./logs")
executor.update_parameters(
    timeout_min=1, slurm_partition="gpu"
)
job = executor.submit(add, 5, 7)
output = job.result()
assert output == 12
This is directly from the submitit README.
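The job object is also handy on its own. A quick sketch of the parts of the job API I use most (job IDs, completion checks, and logs); check the submitit docs for the full list:
job = executor.submit(add, 5, 7)
print(job.job_id)    # Slurm job ID
print(job.done())    # True once the job has finished
print(job.result())  # blocks until the job completes, then returns 12
print(job.stdout())  # contents of the job's stdout log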
Why Use submitit?
Understand the motivation for submitit.
I learned to interact with Slurm clusters through bash scripts and sbatch. We used this a lot in BioCLIP, like this launch script to run make_wds.py. I would submit it with sbatch slurm/make-dataset-wds.sh.
This sucked for a few reasons.
1. I had to write scripts in another language (bash) when I wanted to use Python for everything.
2. I had to edit the bash script to change the Python script arguments. I had a very nice argument parser in my Python script with argparse and help text, etc., but I couldn’t use it because I didn’t have an argument parser for make-dataset-wds.sh.
3. I couldn’t easily programmatically launch many jobs at once. Sometimes I have a config file that specifies a sweep of jobs, and I want to launch many jobs with one script. But because I can’t write bash very well, I’ve written Python scripts to parse the config files, then launch the jobs, with code like this (a submitit version of this sweep is sketched just after the list):
# buggy and error-prone code; use submitit instead.
import subprocess

command = [
    "sbatch",
    f"--output=./logs/{job_name}-%j.log",
    f"--job-name={job_name}",
    f"--export=CONFIG_FILE={config_file}",
    template_file,
]
try:
    output = subprocess.run(
        command, check=True, capture_output=True
    )
    print(output.stdout.decode("utf-8"), end="")
except subprocess.CalledProcessError as e:
    print(e.stderr.decode("utf-8"), end="")
    print(e)
4. Re-launching jobs is also a hassle. If an experiment is checkpointable, I would like to restart jobs when they end without having to log back in.
5. Setting up your jobs to run on both Slurm clusters and local machines is challenging.
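For contrast with point 3, here is roughly what a config-driven sweep looks like with submitit. This is a sketch, not code from the repos above: load_configs and run_experiment are hypothetical placeholders for your own config parsing and job function.
import submitit

configs = load_configs("sweep.toml")  # hypothetical: returns a list of config objects

executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(time=60, partition="PARTITION", account="SLURM_ACCT")

# One call submits the whole sweep as a Slurm job array.
jobs = executor.map_array(run_experiment, configs)
print([job.job_id for job in jobs])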
submitit solves all of these pain points. If you look at projects like BioBench or saev, there are no bash scripts whatsoever, but they run on Slurm clusters and local machines.
GPUs and CPUs
Setting number of GPUs and CPUs in your jobs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=120,
    partition="gpu",
    account="ACCOUNT",
    # These args are important.
    gpus_per_node=4,
    ntasks_per_node=2,
    cpus_per_task=12,
)
# calls job_fn ntasks_per_node times in parallel.
executor.submit(job_fn)

gpus_per_node is the number of GPUs per node, ntasks_per_node is the number of processes that call your function, and cpus_per_task is the number of CPUs available per task. So if you want to run two tasks, each with two GPUs, and a total of 24 CPUs, you need ntasks_per_node=2, gpus_per_node=4, and cpus_per_task=12.
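Inside the job function you can check what you actually got. A small sketch using submitit.JobEnvironment(), which reads Slurm's environment for the current task; double-check the attribute names against your submitit version:
def job_fn():
    env = submitit.JobEnvironment()
    # e.g. "job 12345: task 0/2 on node0123"
    print(f"job {env.job_id}: task {env.global_rank}/{env.num_tasks} on {env.hostname}")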
Multi-GPU Training in Torch
Multi-GPU training with torch.distributed.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=12 * 60,
    partition="gpu",
    account="ACCOUNT",
    # These args are important.
    gpus_per_node=4,
    ntasks_per_node=4,
    cpus_per_task=12,
)
executor.submit(train).result()

def train():
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()
    torch.distributed.init_process_group(
        backend="nccl", world_size=dist_env.world_size
    )
    assert dist_env.rank == torch.distributed.get_rank()
    assert dist_env.world_size == torch.distributed.get_world_size()
When setting up your executor, be sure to set
ntasks_per_node to the same number as
gpus_per_node so that every GPU has a task.
The submitit.helpers.TorchDistributedEnvironment class exports the environment variables that torch.distributed reads (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and so on) so that PyTorch can set up the distributed environment correctly.
Note that if you don’t have a CUDA device available, you cannot initialize the nccl backend with init_process_group, so you probably want to handle that.
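One way to handle it is to guard the distributed setup on CUDA being available and fall back to a single CPU process otherwise. This is a sketch of one possible structure, not something submitit itself requires:
import submitit
import torch
import torch.distributed

def train():
    if torch.cuda.is_available():
        dist_env = submitit.helpers.TorchDistributedEnvironment().export()
        torch.distributed.init_process_group(
            backend="nccl", world_size=dist_env.world_size
        )
        device = torch.device("cuda", dist_env.local_rank)
    else:
        # No GPU (e.g. running under DebugExecutor on a login node): skip distributed setup.
        device = torch.device("cpu")
    # ... train on `device` as normal ...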
DebugExecutor
Run code in the current process rather than in a Slurm job.
if debug:
    executor = submitit.DebugExecutor(folder=args.log_to)
else:
    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )

# Use executor as normal.
executor.submit(job_fn, arg1, arg2).result()
If you are debugging jobs, you likely want to use pdb or
other interactive debuggers. You cannot use pdb in a
“headless” process like a Slurm job. However, the
submitit.DebugExecutor will run jobs in the same process
that you create the executor from. This is really useful for debugging
jobs, because DebugExecutor has the same API as
SlurmExecutor so you can split up your executor
construction code and then debug your jobs.
This solves problem #5: setting up your jobs to run on both Slurm clusters and local machines.
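In practice that looks something like this; the breakpoint() call is ordinary Python, and it works here only because DebugExecutor runs the function in your terminal's process:
def job_fn(arg1, arg2):
    breakpoint()  # drops into pdb when run via DebugExecutor
    return arg1 + arg2

executor = submitit.DebugExecutor(folder="./logs")
executor.submit(job_fn, 1, 2).result()  # the pdb prompt appears in this terminal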
stderr_to_stdout=True
Don’t split up stderr and stdout logs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=30,
    partition="PARTITION",
    account="SLURM_ACCT",
    stderr_to_stdout=True,  # <- This line
)
Most of the time I’m not interested in the distinction between stderr
and stdout because I just care about outputs. print() in
Python goes to stdout, logging.info() goes to stderr. If your output is split across the two streams, it can be irritating to figure out how your debugging statements are ordered (but you should also use pdb in a DebugExecutor instead of print statements). Setting stderr_to_stdout=True in executor.update_parameters() writes everything to the same stream.
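With this set, everything the job prints lands in a single log, which you can also read back from the launcher via job.stdout():
job = executor.submit(job_fn)
job.result()
print(job.stdout())  # everything the job printed, in one stream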
Global Variables
Global variables don’t work.
def main():
    if not use_ssl:
        import ssl
        # Do not verify HTTPS certificates (they are verified by default).
        ssl._create_default_https_context = ssl._create_unverified_context
    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )
    executor.submit(job_fn)

def job_fn():
    import ssl
    print(ssl._create_default_https_context)
    # Will not be an unverified context.
If you want to set global variables after the program is running but
before you submit jobs, those variables will not persist in your jobs. In the example above, I want Python’s ssl module to ignore HTTPS certificates by setting ssl._create_default_https_context to ssl._create_unverified_context. However, in the job function job_fn, ssl._create_default_https_context will not be set correctly: the job runs in a fresh Python process on a cluster node, and submitit pickles your function and its arguments, not the interpreter state of the launcher.
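The fix is to set the state inside the job function itself, for example by passing the flag as an argument. A sketch, reusing the use_ssl flag from above:
def job_fn(use_ssl: bool):
    if not use_ssl:
        import ssl
        # Runs inside the job's own process, so the patch actually takes effect.
        ssl._create_default_https_context = ssl._create_unverified_context
    # ... rest of the job ...

executor.submit(job_fn, use_ssl)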
Environment Variables
How to set environment variables in Slurm jobs.
executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=30, partition="PARTITION", account="SLURM_ACCT"
)
if not use_ssl:
    executor.update_parameters(setup=[
        "export DISABLE_SSL=1",
        "export HAVE_FUN=2",
    ])
If you want to set global variables, you might end up using
environment variables. While using environment variables to manage
program state is almost always a source of bugs, if you absolutely need
to, you can use the setup parameter to set environment
variables in the running Slurm jobs.
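Inside the job, read those variables back with os.environ; for example, with the DISABLE_SSL variable from the snippet above:
import os

def job_fn():
    if os.environ.get("DISABLE_SSL") == "1":
        # React to the flag set via the executor's setup commands.
        ...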
Submitting GPU JAX Jobs on CPU-Only Nodes
Weird JAX issues with GPUs.
executor = submitit.SlurmExecutor(folder="logs")
executor.update_parameters(
time=12 * 60,
partition="PARTITION",
account="ACCOUNT",
setup=["export JAX_PLATFORMS=''"],
)For whatever reason, we cannot import JAX without a GPU. That is, if
you run import jax with jax[cuda12] installed
in your environment, you get an exception about “No GPUs found” or
something like that. But you can use JAX_PLATFORMS=cpu
before uv run python -c "import jax; print(jax)" and it
will work fine. But if you set JAX_PLATFORMS=cpu when running the launcher script, the submitted jobs inherit it, which means your training jobs will run on the CPU instead of the cluster GPUs.
The setup argument above re-exports JAX_PLATFORMS inside the cluster jobs (here set to the empty string, i.e. back to JAX’s default platform detection), so training finds the GPUs.
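Concretely, the workflow ends up looking something like this (launch.py is a placeholder name for your launcher script):
# On the CPU-only login node: force the CPU backend so the launcher can import jax.
JAX_PLATFORMS=cpu uv run python launch.py
# Inside the Slurm job, setup=["export JAX_PLATFORMS=''"] resets the variable so JAX sees the GPUs.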
cloudpickle and Equinox
When using Equinox
modules with submitit, which uses cloudpickle,
JIT tracing fails if the module class is defined in
__main__.
Define an eqx.Module in your main script and submit it
via submitit. JIT tracing will fail.
# broken.py - THIS FAILS
import equinox as eqx
import jax
import jax.numpy as jnp
import submitit

class Model(eqx.Module):  # Defined in __main__
    linear: eqx.nn.Linear

    def __init__(self, key):
        self.linear = eqx.nn.Linear(2, 2, key=key)

    def __call__(self, x):
        return self.linear(x)

def run():
    model = Model(jax.random.PRNGKey(0))
    print("Eager:", model(jnp.ones(2)))  # Works

    @eqx.filter_jit
    def forward(model, x):
        return model(x)

    print("JIT:", forward(model, jnp.ones(2)))  # FAILS

executor = submitit.SlurmExecutor(folder="logs")
executor.update_parameters(gpus_per_node=1, ...)
job = executor.submit(run)
job.result()
Error:
AttributeError: 'Model' object has no attribute 'linear'
Eager execution works; only JIT tracing fails. This happens because
cloudpickle serializes __main__ classes by
inlining their definition, which breaks Equinox’s attribute mechanism
during tracing.
Fix: Move your eqx.Module classes to a
separate importable file:
# model.py
import equinox as eqx

class Model(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, key):
        self.linear = eqx.nn.Linear(2, 2, key=key)

    def __call__(self, x):
        return self.linear(x)

# main.py - THIS WORKS
import equinox as eqx
import jax
import jax.numpy as jnp
import submitit

from model import Model  # Import instead of define

def run():
    model = Model(jax.random.PRNGKey(0))

    @eqx.filter_jit
    def forward(model, x):
        return model(x)

    print("JIT:", forward(model, jnp.ones(2)))  # Works

executor = submitit.SlurmExecutor(folder="logs")
executor.update_parameters(gpus_per_node=1, ...)
job = executor.submit(run)
job.result()
cloudpickle handles classes differently based on where
they’re defined:
- __main__ classes: serialized with the full class definition inlined
- Imported classes: serialized as module references (e.g., model.Model)
The latter preserves the class identity correctly, allowing Equinox’s
attribute mechanism to work during JIT tracing in the subprocess. Thus,
always define eqx.Module classes in importable files, not
in __main__, when using submitit or other
cloudpickle-based job submission.
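A quick sanity check before submitting (plain Python, not a submitit feature): inspect the class's __module__ attribute.
from model import Model

# "__main__" means cloudpickle will inline the class definition;
# "model" means it will be pickled by reference, which is what you want.
print(Model.__module__)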
Complete Examples
Complete PyTorch and JAX examples with all these tricks.
PyTorch:
import logging

import submitit
import torch
import torch.distributed

def train(device: str):
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()
    is_ddp = False
    global_rank = 0
    local_rank = 0
    is_master = False
    if device == "cuda":
        torch.distributed.init_process_group(
            backend="nccl", world_size=dist_env.world_size
        )
        assert dist_env.rank == torch.distributed.get_rank()
        assert dist_env.world_size == torch.distributed.get_world_size()
        is_ddp = True
        global_rank = dist_env.rank
        local_rank = dist_env.local_rank
        is_master = dist_env.rank == 0

    logger = logging.getLogger(f"rank{global_rank}")

    model = ViT()  # placeholder for your model
    if is_ddp:
        model = torch.nn.parallel.DistributedDataParallel(model)
    # Train as normal.

def main():
    # Same executor setup as in the Multi-GPU Training in Torch section.
    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=12 * 60,
        partition="PARTITION",
        account="ACCOUNT",
        gpus_per_node=4,
        ntasks_per_node=4,
        cpus_per_task=12,
        stderr_to_stdout=True,
    )
    executor.submit(train, "cuda").result()
JAX:
JAX multi-GPU training with submitit requires a different approach
than PyTorch. While you can use single-process multi-GPU (one
process sees all GPUs via mesh sharding), this causes NaN/gradient
corruption on some clusters. The stable approach mirrors PyTorch: one
process per GPU with jax.distributed.
First, define your model in a separate file to avoid cloudpickle issues:
# jax_model.py
import equinox as eqx
import jax

class MLP(eqx.Module):
    layers: list

    def __init__(self, in_dim=128, hidden_dim=256, out_dim=64, *, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(in_dim, hidden_dim, key=k1),
            eqx.nn.Linear(hidden_dim, hidden_dim, key=k2),
            eqx.nn.Linear(hidden_dim, out_dim, key=k3),
        ]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x)
Then your training script:
# jax_submit.py
import jax
import jax.sharding
import jax.numpy as jnp
import equinox as eqx
import optax
import submitit

from jax_model import MLP

def worker_fn():
    # Initialize multi-process JAX (like torch.distributed.init_process_group).
    jax.distributed.initialize()
    if jax.process_index() == 0:
        print(f"{jax.process_count()} processes, {jax.device_count()} devices")

    # Mesh spans all devices across all processes.
    mesh = jax.make_mesh((jax.device_count(),), ("batch",))
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("batch"))

    model = MLP(key=jax.random.PRNGKey(0))
    model = eqx.filter_shard(model, replicated)
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    def loss_fn(model, x, y):
        return jnp.mean((jax.vmap(model)(x) - y) ** 2)

    @eqx.filter_jit
    def step(model, opt_state, x, y):
        loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
        updates, opt_state = optimizer.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    key = jax.random.PRNGKey(0)
    for i in range(100):
        key, k1, k2 = jax.random.split(key, 3)
        x = jax.device_put(jax.random.normal(k1, (256, 128)), sharded)
        y = jax.device_put(jax.random.normal(k2, (256, 64)), sharded)
        model, opt_state, loss = step(model, opt_state, x, y)
        if i % 20 == 0 and jax.process_index() == 0:
            print(f"Step {i}, Loss: {float(loss):.4f}")

    if jax.process_index() == 0:
        return float(loss)

def main():
    executor = submitit.SlurmExecutor(folder="logs")
    executor.update_parameters(
        account="ACCOUNT",
        partition="PARTITION",
        nodes=1,
        gpus_per_node=2,
        ntasks_per_node=2,  # One process per GPU
        cpus_per_task=8,
    )
    job = executor.submit(worker_fn)
    print(f"Submitted: {job.job_id}")
    print(f"Final loss: {job.results()[0]}")

if __name__ == "__main__":
    main()
The jax.distributed.initialize() call reads SLURM
environment variables to coordinate processes, similar to how
TorchDistributedEnvironment works.
Why not single-process multi-GPU? JAX can run with
ntasks_per_node=1 and see all GPUs from one process. This
should work but causes gradient corruption (NaN, exploding
loss) on some clusters due to GPU-to-GPU communication issues. Using one
process per GPU with jax.distributed routes communication
through NCCL properly and is 100% stable.
Sam Stevens, 2024