
How to Use submitit on Slurm Systems.

submitit is a Python package by Facebook/Meta for submitting one or more Python jobs to Slurm clusters, like those at OSC or Facebook.

While it’s easy to use, there were a couple of gotchas and lessons I had to learn before I really appreciated how much better it is than writing Bash scripts. This article covers those lessons, with plenty of examples and explanation.

I have used submitit for several projects, including BioBench, saev, frx (unmaintained at the moment), and an I-JEPA reimplementation (also unmaintained at the moment). These repos contain complete examples of how I use submitit in a larger context than a blog post can provide.

I recommend reading from top to bottom, as the tips are ordered by relative importance/utility.

Table of Contents

  1. Minimal Example
  2. Why Use submitit?
  3. GPUs and CPUs
  4. Multi-GPU Training in Torch
  5. DebugExecutor
  6. stderr_to_stdout=True
  7. Global Variables
  8. Environment Variables
  9. Submitting GPU Jax Jobs on CPU-Only Nodes
  10. Complete Examples

Minimal Example

Minimal example for submitting jobs on a cluster.

import submitit


def add(a, b):
    return a + b

executor = submitit.AutoExecutor(folder="./logs")
executor.update_parameters(
    timeout_min=1, slurm_partition="gpu"
)
job = executor.submit(add, 5, 7)

output = job.result()
assert output == 12

This is directly from the submitit README.
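
Two things that aren’t obvious from the snippet: executor.submit returns a job handle immediately, and job.result() blocks until the Slurm job finishes (and re-raises any exception the job hit). A quick sketch with the same add job:

job = executor.submit(add, 5, 7)

print(job.job_id)  # the Slurm job ID, as you'd see in squeue
print(job.done())  # False until the job finishes

output = job.result()  # blocks until the job finishes; raises if the job failed
assert output == 12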

Why Use submitit?

Understand the motivation for submitit.

I learned to interact with Slurm clusters through bash scripts and sbatch. We used this a lot in BioCLIP, like this launch script to run make_wds.py. I would submit it with sbatch slurm/make-dataset-wds.sh.

This sucked for a couple reasons.

  1. I had to write scripts in another language (bash) when I wanted to use Python for everything.
  2. I had to edit the bash script to change the Python script arguments. I had a very nice argument parser in my Python script with argparse and help text, etc., but I couldn’t use it because I didn’t have an argument parser for make-dataset-wds.sh.
  3. I couldn’t easily programmatically launch many jobs at once. Sometimes I have a config file that specifies a sweep of jobs, and I want to launch many jobs with one script. But because I can’t write bash very well, I’ve written Python scripts to parse the config files, then launch the jobs, with code like this:
# buggy and error-prone code; use submitit instead.
import subprocess

command = [
    "sbatch",
    f"--output=./logs/{job_name}-%j.log",
    f"--job-name={job_name}",
    f"--export=CONFIG_FILE={config_file}",
    template_file,
]
try:
    output = subprocess.run(
      command, check=True, capture_output=True
    )
    print(output.stdout.decode("utf-8"), end="")
except subprocess.CalledProcessError as e:
    print(e.stderr.decode("utf-8"), end="")
    print(e)
  4. Re-launching jobs is also a hassle. If an experiment is checkpointable, I would like to restart jobs when they end without having to log back in.
  5. Setting up your jobs to run on both Slurm clusters and local machines is challenging.

submitit solves all of these pain points. If you look at projects like BioBench or saev, there are no Bash scripts whatsoever, but they run on both Slurm clusters and local machines.
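
For example, the sweep-launching problem above shrinks to a few lines of Python. Here is a sketch (run_one and the learning rates are hypothetical stand-ins for your own experiment); executor.map_array submits the whole sweep as a single Slurm job array.

import submitit


def run_one(lr: float) -> float:
    # Stand-in for your real experiment; returns some metric.
    return lr


executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(time=60, partition="PARTITION", account="SLURM_ACCT")

lrs = [1e-4, 3e-4, 1e-3]
jobs = executor.map_array(run_one, lrs)  # one Slurm job array, one task per lr
results = [job.result() for job in jobs]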

GPUs and CPUs

Setting number of GPUs and CPUs in your jobs.

executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=120,
    partition="gpu",
    account="ACCOUNT",
    # These args are important.
    gpus_per_node=4,
    ntasks_per_node=2,
    cpus_per_task=12,
)
# calls job_fn ntasks_per_node times in parallel.
executor.submit(job_fn)  

gpus_per_node is GPUs per node, ntasks_per_node is the number of processes that call your function, and cpus_per_task is the number of CPUs available per task. So if you want to run two tasks, each with two GPUs, and a total of 24 CPUs, you need ntasks_per_node=2, gpus_per_node=4, and cpus_per_task=12.
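
Inside the job, each task can check which task it is via submitit.JobEnvironment. A sketch (I believe the attribute names below are right, but double-check them against the submitit docs):

import submitit


def job_fn():
    env = submitit.JobEnvironment()
    # With ntasks_per_node=2, this prints twice: once per task.
    print(f"job {env.job_id}: task {env.global_rank}/{env.num_tasks} on {env.hostname}")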

Multi-GPU Training in Torch

Setting up torch.distributed training with submitit.

import submitit
import torch
import torch.distributed


def train():
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()

    torch.distributed.init_process_group(
        backend="nccl", world_size=dist_env.world_size
    )
    assert dist_env.rank == torch.distributed.get_rank()
    assert dist_env.world_size == torch.distributed.get_world_size()

    # ...training loop...


executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=12 * 60,
    partition="gpu",
    account="ACCOUNT",
    # These args are important.
    gpus_per_node=4,
    ntasks_per_node=4,
    cpus_per_task=12,
)
executor.submit(train).result()

When setting up your executor, be sure to set ntasks_per_node to the same number as gpus_per_node so that every GPU has a task.

The submitit.helpers.TorchDistributedEnvironment class sets the environment variables that torch.distributed expects (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and friends) so that PyTorch can set up the distributed environment correctly.

Note that if you don’t have a CUDA device available, you cannot initialize the process group with the nccl backend, so you probably want to handle that case; one way is sketched below.
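
A sketch of one way to handle it: pick the backend based on whether CUDA is available, falling back to gloo, which works on CPU-only machines. (The complete example at the bottom instead skips process-group setup entirely when there is no GPU.)

import submitit
import torch
import torch.distributed


def train():
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()

    # nccl needs GPUs; gloo works on CPU-only machines (e.g. under DebugExecutor).
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    torch.distributed.init_process_group(
        backend=backend, world_size=dist_env.world_size
    )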

DebugExecutor

Run code in the current process rather than in a Slurm job.

if debug:
    executor = submitit.DebugExecutor(folder=args.log_to)
else:
    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )

# Use executor as normal.
executor.submit(job_fn, arg1, arg2).result()

If you are debugging jobs, you likely want to use pdb or another interactive debugger. You cannot use pdb in a “headless” process like a Slurm job. However, submitit.DebugExecutor runs jobs in the same process that you create the executor from. Because DebugExecutor has the same API as SlurmExecutor, you only have to branch on the executor construction (as above) and the rest of your code stays the same, which makes debugging jobs much easier.

This solves problem #5.

  5. Setting up your jobs to run on both Slurm clusters and local machines is challenging.
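
For example, a sketch: drop a breakpoint() into the job function and submit it with the DebugExecutor; because the “job” runs in your current process, you land in pdb with your terminal attached.

import submitit


def job_fn(a, b):
    breakpoint()  # drops into pdb, since the job runs in this very process
    return a + b


executor = submitit.DebugExecutor(folder="./logs")
job = executor.submit(job_fn, 1, 2)
print(job.result())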

stderr_to_stdout=True

Don’t split up stderr and stdout logs.

executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=30,
    partition="PARTITION",
    account="SLURM_ACCT",
    stderr_to_stdout=True,  # <- This line
)

Most of the time I’m not interested in the distinction between stderr and stdout because I just care about outputs. print() in Python goes to stdout, while logging.info() goes to stderr by default. If the two are split into separate log files, it can be irritating to figure out how your debugging statements are ordered (but you should also use pdb in a DebugExecutor instead of print statements). Setting stderr_to_stdout=True in executor.update_parameters() writes everything to the same stream.

Global Variables

Global variables don’t work.

def main():
    if not use_ssl:
        import ssl
        # By default do not use HTTPS
        ssl._create_default_https_context = ssl._create_unverified_context

    executor = submitit.SlurmExecutor(folder="./logs")
    executor.update_parameters(
        time=30, partition="PARTITION", account="SLURM_ACCT"
    )
    executor.submit(job_fn)


def job_fn():
    import ssl

    print(ssl._create_default_https_context)
    # Will not be the unverified context.

If you set global variables after the program is running but before you submit jobs, those variables will not persist in your jobs: submitit pickles your function and its arguments and runs them in a fresh process, typically on a different node, so in-process state set in the launcher does not carry over. In the example above, I want to tell Python’s ssl module to ignore HTTPS certs by setting ssl._create_default_https_context to ssl._create_unverified_context. However, in the job function job_fn, ssl._create_default_https_context will not be set correctly.
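
One fix, as a sketch: pass the flag as an argument and do the global setup inside the job function, since that code runs in the job’s own process.

import submitit


def job_fn(use_ssl: bool):
    if not use_ssl:
        import ssl

        # Runs inside the Slurm job, so it actually takes effect there.
        ssl._create_default_https_context = ssl._create_unverified_context
    # ... rest of the job ...


executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(time=30, partition="PARTITION", account="SLURM_ACCT")
executor.submit(job_fn, False)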

Environment Variables

How to set environment variables in Slurm jobs.

executor = submitit.SlurmExecutor(folder="./logs")
executor.update_parameters(
    time=30, partition="PARTITION", account="SLURM_ACCT"
)

if not use_ssl:
    executor.update_parameters(setup=[
        "export DISABLE_SSL=1",
        "export HAVE_FUN=2",
    ])

If you want to set global variables, you might end up using environment variables. While using environment variables to manage program state is almost always a source of bugs, if you absolutely need to, you can use the setup parameter, whose lines are added to the generated sbatch script before your job starts, to set environment variables in the running Slurm jobs.
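
Inside the job, they show up as ordinary environment variables. A sketch:

import os


def job_fn():
    # The `setup` lines run in the sbatch script before the job starts,
    # so the exported variables are visible here.
    if os.environ.get("DISABLE_SSL") == "1":
        ...  # disable SSL verification, etc.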

Submitting GPU Jax Jobs on CPU-Only Nodes

Weird Jax issues with GPUs.

executor = submitit.SlurmExecutor(folder="logs")
executor.update_parameters(
    time=12 * 60,
    partition="PARTITION",
    account="ACCOUNT",
    setup=["export JAX_PLATFORMS=''"],
)

For whatever reason, you cannot import Jax on a machine without a GPU when jax[cuda12] is installed in your environment: import jax raises an exception about “No GPUs found” or something like that. You can work around this by setting JAX_PLATFORMS=cpu, e.g. JAX_PLATFORMS=cpu uv run python -c "import jax; print(jax)" works fine. But if you set JAX_PLATFORMS=cpu to run your launcher script, the variable is inherited by the submitted jobs, which means your training jobs will run on the CPU instead of the cluster GPUs.

The setup line exports an updated JAX_PLATFORMS variable inside the Slurm job, so Jax will find the GPUs for training.
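
Putting the two halves together, here is a sketch (launch.py and gpus_per_node=1 are assumptions; adjust for your own script and cluster): run the launcher under JAX_PLATFORMS=cpu, and let the setup line reset the variable inside the job, where jax.devices() should then report GPUs.

import submitit


def train():
    import jax

    # JAX_PLATFORMS was reset by the `setup` line, so this should list GPU devices.
    print(jax.devices())


executor = submitit.SlurmExecutor(folder="logs")
executor.update_parameters(
    time=12 * 60,
    partition="PARTITION",
    account="ACCOUNT",
    gpus_per_node=1,
    setup=["export JAX_PLATFORMS=''"],
)
job = executor.submit(train)

# From the CPU-only login node:
#   JAX_PLATFORMS=cpu python launch.py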

Complete Examples

Complete PyTorch and Jax example with all these tricks.

PyTorch:

import logging

import submitit
import torch
import torch.distributed

def train(device: str):
    dist_env = submitit.helpers.TorchDistributedEnvironment().export()

    is_ddp = False
    global_rank = 0
    local_rank = 0
    is_master = False
    if device == "cuda":
        torch.distributed.init_process_group(
            backend="nccl", world_size=dist_env.world_size
        )
        assert dist_env.rank == torch.distributed.get_rank()
        assert dist_env.world_size == torch.distributed.get_world_size()

        is_ddp = True
        global_rank = dist_env.rank
        local_rank = dist_env.local_rank
        is_master = dist_env.rank == 0

    logger = logging.getLogger(f"Rank {global_rank}")

    model = ViT()

    if is_ddp:
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Train as normal

def main():
    # A sketch of the launcher, pulling together the tips above; `debug`
    # would come from your argument parser.
    if debug:
        executor = submitit.DebugExecutor(folder="./logs")
        device = "cpu"
    else:
        executor = submitit.SlurmExecutor(folder="./logs")
        executor.update_parameters(
            time=12 * 60,
            partition="PARTITION",
            account="ACCOUNT",
            gpus_per_node=4,
            ntasks_per_node=4,  # one task per GPU
            cpus_per_task=12,
            stderr_to_stdout=True,
        )
        device = "cuda"

    executor.submit(train, device).result()

Jax:

# TODO


Sam Stevens, 2024