From 79dabf8bdf863225bb93da64a73efa3564be7dd7 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 16 Apr 2024 17:45:53 +0800 Subject: [PATCH] added seed to dataloader args (#52) --- scripts/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index b097573..9d76bf6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -4,7 +4,6 @@ from pprint import pprint import torch import torch.distributed as dist -import wandb from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import DistCoordinator @@ -12,6 +11,7 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device, set_seed from tqdm import tqdm +import wandb from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import ( get_data_parallel_group, @@ -99,6 +99,7 @@ def main(): dataset=dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, + seed=cfg.seed, shuffle=True, drop_last=True, pin_memory=True,