diff --git a/scripts/train.py b/scripts/train.py index b097573..9d76bf6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -4,7 +4,6 @@ from pprint import pprint import torch import torch.distributed as dist -import wandb from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import DistCoordinator @@ -12,6 +11,7 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device, set_seed from tqdm import tqdm +import wandb from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import ( get_data_parallel_group, @@ -99,6 +99,7 @@ def main(): dataset=dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, + seed=cfg.seed, shuffle=True, drop_last=True, pin_memory=True,