added seed to dataloader args (#52)

This commit is contained in:
Frank Lee 2024-04-16 17:45:53 +08:00 committed by GitHub
parent a72e59610f
commit 79dabf8bdf

View file

@ -4,7 +4,6 @@ from pprint import pprint
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
@ -12,6 +11,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
get_data_parallel_group,
@ -99,6 +99,7 @@ def main():
dataset=dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
seed=cfg.seed,
shuffle=True,
drop_last=True,
pin_memory=True,