mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 21:42:26 +02:00
added seed to dataloader args (#52)
This commit is contained in:
parent
a72e59610f
commit
79dabf8bdf
|
|
@ -4,7 +4,6 @@ from pprint import pprint
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
|
@ -12,6 +11,7 @@ from colossalai.nn.optimizer import HybridAdam
|
|||
from colossalai.utils import get_current_device, set_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import (
|
||||
get_data_parallel_group,
|
||||
|
|
@ -99,6 +99,7 @@ def main():
|
|||
dataset=dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
seed=cfg.seed,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
|
|
|
|||
Loading…
Reference in a new issue