-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
75 lines (56 loc) · 2.4 KB
/
run.py
File metadata and controls
75 lines (56 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from __future__ import annotations
import sys
import os
import argparse
import subprocess
from pathlib import Path
def main() -> None:
    """Run training, launching one process per GPU via ``accelerate`` if asked.

    Control flow:

    1. If ``LOCAL_RANK`` is already set (accelerate sets LOCAL_RANK, RANK,
       etc. in the processes it spawns), this is a worker process: run the
       training entry point directly.
    2. Otherwise parse the launcher-only options ``--gpus`` / ``--num-gpus``.
       If neither is given, run training in this process; the training
       entry point parses the unchanged ``sys.argv`` itself.
    3. Otherwise re-run this file under ``accelerate launch`` with
       ``--num_processes`` set to the requested GPU count, forwarding every
       argument except the launcher-only ones.

    Exits with the child's return code if ``accelerate launch`` fails, with
    an error message if the ``accelerate`` executable cannot be found, and
    with an argparse usage error for an empty ``--gpus`` list or a negative
    ``--num-gpus``.
    """
    # 1. Already inside a distributed worker process: just train.
    if "LOCAL_RANK" in os.environ:
        _run_training_entrypoint()
        return

    # 2. Parse launcher-only arguments.
    #    add_help=False lets --help fall through to the training script.
    #    allow_abbrev=False stops argparse from treating script options such
    #    as --num or --gpu as abbreviations of --num-gpus / --gpus.
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    parser.add_argument("--gpus", type=str, help="Comma-separated list of GPU IDs to use (e.g. '0,1')")
    parser.add_argument("--num-gpus", type=int, help="Number of GPUs to use")
    # parse_known_args leaves unknown options (e.g. --config, --mode) for the script.
    launcher_args, script_args = parser.parse_known_args()

    # 3. No multi-GPU intent: plain single-process run with the original argv.
    if notLauncherArgsArePresent(launcher_args):
        _run_training_entrypoint()
        return

    # 4. Determine the process count and visible devices for the launch.
    env = os.environ.copy()
    num_processes = 1
    if launcher_args.gpus:
        gpu_ids = _parse_gpu_ids(launcher_args.gpus)
        if not gpu_ids:
            parser.error(f"--gpus must name at least one GPU ID (got {launcher_args.gpus!r})")
        # Normalised list, so "0, 1," yields CUDA_VISIBLE_DEVICES="0,1" and 2 processes.
        env["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
        num_processes = len(gpu_ids)
        # An explicit GPU list wins over a conflicting --num-gpus.
        if launcher_args.num_gpus and launcher_args.num_gpus != num_processes:
            print(
                f"Warning: --num-gpus ({launcher_args.num_gpus}) ignored in favor of --gpus count ({num_processes})",
                file=sys.stderr,
            )
    elif launcher_args.num_gpus:
        if launcher_args.num_gpus < 0:
            parser.error(f"--num-gpus must be positive (got {launcher_args.num_gpus})")
        num_processes = launcher_args.num_gpus

    # Re-invoke this very file under accelerate; the spawned workers take the
    # LOCAL_RANK branch above. Launcher-only options are not forwarded.
    cmd = [
        "accelerate",
        "launch",
        "--num_processes",
        str(num_processes),
        str(Path(__file__).resolve()),
        *script_args,
    ]

    # 5. Execute and propagate the child's failure status.
    try:
        subprocess.run(cmd, env=env, check=True)
    except FileNotFoundError:
        sys.exit("Error: could not find the 'accelerate' executable on PATH; is accelerate installed?")
    except subprocess.CalledProcessError as e:
        sys.exit(e.returncode)


def _parse_gpu_ids(spec: str) -> list[str]:
    """Split a comma-separated GPU ID list, dropping whitespace and empty items.

    Example: ``"0, 1,"`` -> ``["0", "1"]``.
    """
    return [tok.strip() for tok in spec.split(",") if tok.strip()]
def _run_training_entrypoint() -> None:
    """Make ``<this file's directory>/src`` importable, then run training.

    The ``src`` directory is prepended to ``sys.path`` so that
    ``aligndiff.train`` can be imported from the source tree, and its
    ``main`` function is then called with no arguments.
    """
    src_dir = Path(__file__).resolve().parent / "src"
    sys.path.insert(0, str(src_dir))

    # Deferred import: only resolvable after src_dir is on sys.path.
    from aligndiff.train import main as train_main

    train_main()
def notLauncherArgsArePresent(args: argparse.Namespace) -> bool:
    """Return True when neither launcher option (``--gpus``/``--num-gpus``) is set.

    Both attributes are tested for truthiness, so an empty ``--gpus`` string
    or ``--num-gpus 0`` also counts as "not present".
    """
    has_gpu_list = bool(args.gpus)
    has_gpu_count = bool(args.num_gpus)
    return not has_gpu_list and not has_gpu_count
# Script entry point. The `accelerate launch` command built in main() passes
# this same file as its script, so spawned workers also enter here.
if __name__ == "__main__":
    main()