From db9445464505010165ffa343b11c844bb6ad3c91 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Mon, 15 Apr 2024 23:55:58 +0800
Subject: [PATCH] Feature/timeout (#50)

* added large nccl timeout

* polish

* polish
---
 .gitignore       | 1 +
 scripts/train.py | 9 ++++++---
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index 55daee5..8797ae7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -177,3 +177,4 @@ pretrained_models/
 # Secret files
 hostfile
 gradio_cached_examples/
+wandb/
diff --git a/scripts/train.py b/scripts/train.py
index d25c4a4..b097573 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,7 +1,7 @@
 from copy import deepcopy
+from datetime import timedelta
 from pprint import pprint

-import colossalai
 import torch
 import torch.distributed as dist
 import wandb
@@ -9,7 +9,7 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, set_seed
 from tqdm import tqdm

 from opensora.acceleration.checkpoint import set_grad_checkpoint
@@ -47,7 +47,10 @@ def main():
     assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"

     # 2.1. colossalai init distributed training
-    colossalai.launch_from_torch({})
+    # we set a very large timeout to avoid some processes exit early
+    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
+    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+    set_seed(1024)
     coordinator = DistCoordinator()
     device = get_current_device()
     dtype = to_torch_dtype(cfg.dtype)
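
Note (not part of the patch): a minimal standalone sketch of the init sequence this patch introduces, assuming a torchrun-style launch that sets RANK, WORLD_SIZE, LOCAL_RANK and MASTER_ADDR/MASTER_PORT. The helper name init_distributed is illustrative only; the 24-hour timeout mirrors the value chosen in the patch.

    # Illustrative sketch; see scripts/train.py in the patch for the actual change.
    from datetime import timedelta

    import torch
    import torch.distributed as dist


    def init_distributed(timeout_hours: int = 24) -> None:
        # A very large NCCL timeout keeps slow ranks (long data loading,
        # checkpointing, validation, etc.) from tripping the default
        # collective timeout (30 minutes in recent PyTorch releases).
        dist.init_process_group(backend="nccl", timeout=timedelta(hours=timeout_hours))
        # Bind each process to one GPU based on its global rank.
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())


    if __name__ == "__main__":
        init_distributed()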