mirror of https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
parent d53715c0c2
commit db94454645
.gitignore (vendored)
@@ -177,3 +177,4 @@ pretrained_models/
 # Secret files
 hostfile
 gradio_cached_examples/
+wandb/
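The new wandb/ entry keeps Weights & Biases run directories out of version control: each call to wandb.init() creates a ./wandb/<run>/ folder of logs and metadata next to the script. A minimal illustration of that behavior (the project name is made up, and mode="offline" is used only so the snippet runs without an account):

    import wandb

    # Creates ./wandb/<run-dir>/ on disk, which is why the
    # directory is now git-ignored alongside other local artifacts.
    run = wandb.init(project="demo", mode="offline")
    run.log({"loss": 0.0})
    run.finish()

The hunks that follow are from the Python training script.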
@@ -1,7 +1,7 @@
 from copy import deepcopy
+from datetime import timedelta
 from pprint import pprint
 
-import colossalai
 import torch
 import torch.distributed as dist
 import wandb
@@ -9,7 +9,7 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, set_seed
 from tqdm import tqdm
 
 from opensora.acceleration.checkpoint import set_grad_checkpoint
@@ -47,7 +47,10 @@ def main():
     assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
 
     # 2.1. colossalai init distributed training
-    colossalai.launch_from_torch({})
+    # we set a very large timeout to avoid some processes exit early
+    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
+    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+    set_seed(1024)
     coordinator = DistCoordinator()
     device = get_current_device()
     dtype = to_torch_dtype(cfg.dtype)
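Taken together, the script-side change swaps colossalai.launch_from_torch({}) for a direct dist.init_process_group call; per the new comment, the point is the very large NCCL timeout, so ranks that sit at a collective while a peer does slow work (checkpoint I/O, data preprocessing) are not aborted by the watchdog. A minimal standalone sketch of the same pattern, assuming it is launched with torchrun --nproc_per_node=<N> and using plain torch.manual_seed in place of ColossalAI's set_seed (the helper name init_distributed is made up):

    from datetime import timedelta

    import torch
    import torch.distributed as dist


    def init_distributed(seed: int = 1024) -> torch.device:
        # A very large timeout keeps the NCCL watchdog from killing ranks
        # that wait a long time at a collective for a slow peer.
        dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
        # With torchrun, ranks on one node are numbered consecutively,
        # so rank % device_count selects each process's local GPU.
        device_id = dist.get_rank() % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        # The same seed on every rank keeps weight init and dropout in sync.
        torch.manual_seed(seed)
        return torch.device("cuda", device_id)


    if __name__ == "__main__":
        device = init_distributed()
        x = torch.ones(1, device=device)
        dist.all_reduce(x)  # sanity check: sums to the world size
        print(f"rank {dist.get_rank()}: world sum = {x.item()}")
        dist.destroy_process_group()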