mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
parent
338bd83f24
commit
dc7b7fd64a
|
|
@ -249,7 +249,7 @@ def load(
|
|||
load_dir: str,
|
||||
sampler=None,
|
||||
) -> Tuple[int, int, int]:
|
||||
booster.load_model(model, os.path.join(load_dir, "model"))
|
||||
booster.load_model(model, os.path.join(load_dir, model_name))
|
||||
# ema is not boosted, so we don't use booster.load_model
|
||||
ema.load_state_dict(
|
||||
torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")),
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ def main():
|
|||
)
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
|
||||
logger.info("Total batch size: %s", total_batch_size)
|
||||
else:
|
||||
dataloader = prepare_variable_dataloader(
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
import random
|
||||
from datetime import timedelta
|
||||
from pprint import pprint
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device, set_seed
|
||||
|
|
@ -15,7 +14,7 @@ from tqdm import tqdm
|
|||
|
||||
import wandb
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader
|
||||
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
|
|
@ -27,68 +26,64 @@ from opensora.utils.config_utils import (
|
|||
save_training_config,
|
||||
)
|
||||
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, to_torch_dtype
|
||||
from opensora.utils.train_utils import create_colossalai_plugin
|
||||
|
||||
DEFAULT_DATASET_NAME = "VideoTextDataset"
|
||||
|
||||
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. args & cfg
|
||||
# 1. configs & runtime variables & colossalai launch
|
||||
# ======================================================
|
||||
cfg = parse_configs(training=True)
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables & colossalai launch
|
||||
# ======================================================
|
||||
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
||||
assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
|
||||
cfg_dtype = cfg.get("dtype", "bf16")
|
||||
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
|
||||
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
||||
|
||||
# 2.1. colossalai init distributed training
|
||||
# we set a very large timeout to avoid some processes exit early
|
||||
# == colossalai init distributed training ==
|
||||
# NOTE: A very large timeout is set to avoid some processes exit early
|
||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||
set_seed(1024)
|
||||
coordinator = DistCoordinator()
|
||||
device = get_current_device()
|
||||
dtype = to_torch_dtype(cfg.dtype)
|
||||
|
||||
# 2.2. init exp_dir, logger, tensorboard & wandb
|
||||
# == init exp_dir ==
|
||||
exp_name, exp_dir = define_experiment_workspace(cfg)
|
||||
coordinator.block_all()
|
||||
if coordinator.is_master():
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
save_training_config(cfg._cfg_dict, exp_dir)
|
||||
save_training_config(cfg.to_dict(), exp_dir)
|
||||
coordinator.block_all()
|
||||
|
||||
if not coordinator.is_master():
|
||||
logger = create_logger(None)
|
||||
else:
|
||||
print("Training configuration:")
|
||||
pprint(cfg._cfg_dict)
|
||||
# == init logger, tensorboard & wandb ==
|
||||
if coordinator.is_master():
|
||||
logger = create_logger(exp_dir)
|
||||
logger.info(f"Experiment directory created at {exp_dir}")
|
||||
|
||||
writer = create_tensorboard_writer(exp_dir)
|
||||
if cfg.wandb:
|
||||
wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)
|
||||
|
||||
# 2.3. initialize ColossalAI booster
|
||||
if cfg.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=cfg.dtype,
|
||||
initial_scale=2**16,
|
||||
max_norm=cfg.grad_clip,
|
||||
)
|
||||
set_data_parallel_group(dist.group.WORLD)
|
||||
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
|
||||
tb_writer = create_tensorboard_writer(exp_dir)
|
||||
if cfg.get("wandb", False):
|
||||
wandb.init(project="minisora", name=exp_name, config=cfg.to_dict())
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {cfg.plugin}")
|
||||
logger = create_logger(None)
|
||||
|
||||
# == init ColossalAI booster ==
|
||||
plugin = create_colossalai_plugin(
|
||||
plugin=cfg.get("plugin", "zero2"),
|
||||
dtype=cfg_dtype,
|
||||
grad_clip=cfg.get("grad_clip", 0),
|
||||
sp_size=cfg.get("sp_size", 1),
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# 3. build dataset and dataloader
|
||||
# 2. build dataset and dataloader
|
||||
# ======================================================
|
||||
assert cfg.dataset.type == "VideoTextDataset", "Only support VideoTextDataset for now"
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
logger.info(f"Dataset contains {len(dataset)} samples.")
|
||||
logger.info("Dataset contains %s samples.", len(dataset))
|
||||
# == build dataloader ==
|
||||
dataloader_args = dict(
|
||||
dataset=dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
|
|
@ -99,76 +94,93 @@ def main():
|
|||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
# TODO: use plugin's prepare dataloader
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size()
|
||||
logger.info(f"Total batch size: {total_batch_size}")
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
|
||||
logger.info("Total batch size: %s", total_batch_size)
|
||||
else:
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=cfg.bucket_config,
|
||||
num_bucket_build_workers=cfg.num_bucket_build_workers,
|
||||
**dataloader_args,
|
||||
)
|
||||
|
||||
# ======================================================
|
||||
# 4. build model
|
||||
# 3. build model
|
||||
# ======================================================
|
||||
# 4.1. build model
|
||||
model = build_module(cfg.model, MODELS)
|
||||
model.to(device, dtype)
|
||||
# == build vae model ==
|
||||
model = build_module(cfg.model, MODELS).to(device, dtype)
|
||||
model.train()
|
||||
model_numel, model_numel_trainable = get_model_numel(model)
|
||||
logger.info(
|
||||
f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
|
||||
"Trainable model params: %s, Total model params: %s",
|
||||
format_numel_str(model_numel_trainable),
|
||||
format_numel_str(model_numel),
|
||||
)
|
||||
# == build discriminator model ==
|
||||
if cfg.get("discriminator", False) != False:
|
||||
discriminator = build_module(cfg.discriminator, MODELS).to(device, dtype)
|
||||
discriminator.train()
|
||||
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
|
||||
logger.info(
|
||||
"Trainable model params: %s, Total model params: %s",
|
||||
format_numel_str(discriminator_numel_trainable),
|
||||
format_numel_str(discriminator_numel),
|
||||
)
|
||||
|
||||
# 4.4 loss functions
|
||||
# == setup loss functions ==
|
||||
vae_loss_fn = VAELoss(
|
||||
logvar_init=cfg.get("logvar_init", 0.0),
|
||||
perceptual_loss_weight=cfg.perceptual_loss_weight,
|
||||
kl_loss_weight=cfg.kl_loss_weight,
|
||||
perceptual_loss_weight=cfg.get("perceptual_loss_weight", 0.1),
|
||||
kl_loss_weight=cfg.get("kl_loss_weight", 1e-6),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if cfg.get("discriminator", False) != False:
|
||||
adversarial_loss_fn = AdversarialLoss(
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
generator_factor=cfg.generator_factor,
|
||||
generator_loss_type=cfg.generator_loss_type,
|
||||
discriminator_factor=cfg.get("discriminator_factor", 1),
|
||||
discriminator_start=cfg.get("discriminator_start", -1),
|
||||
generator_factor=cfg.get("generator_factor", 0.5),
|
||||
generator_loss_type=cfg.get("generator_loss_type", "hinge"),
|
||||
)
|
||||
|
||||
disc_loss_fn = DiscriminatorLoss(
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
discriminator_loss_type=cfg.discriminator_loss_type,
|
||||
lecam_loss_weight=cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
|
||||
discriminator_factor=cfg.get("discriminator_factor", 1),
|
||||
discriminator_start=cfg.get("discriminator_start", -1),
|
||||
discriminator_loss_type=cfg.get("discriminator_loss_type", "hinge"),
|
||||
lecam_loss_weight=cfg.get("lecam_loss_weight", None),
|
||||
gradient_penalty_loss_weight=cfg.get("gradient_penalty_loss_weight", None),
|
||||
)
|
||||
|
||||
# 4.5. setup optimizer
|
||||
# vae optimizer
|
||||
# == setup vae optimizer ==
|
||||
optimizer = HybridAdam(
|
||||
filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
adamw_mode=True,
|
||||
lr=cfg.get("lr", 1e-5),
|
||||
weight_decay=cfg.get("weight_decay", 0),
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
# 4.6. prepare for training
|
||||
if cfg.grad_checkpoint:
|
||||
set_grad_checkpoint(model)
|
||||
model.train()
|
||||
|
||||
# 4.7 add discriminator if specified in config
|
||||
# == setup discriminator optimizer ==
|
||||
if cfg.get("discriminator", False) != False:
|
||||
discriminator = build_module(cfg.discriminator, MODELS)
|
||||
discriminator.to(device, dtype)
|
||||
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
|
||||
logger.info(
|
||||
f"Trainable model params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
|
||||
)
|
||||
disc_optimizer = HybridAdam(
|
||||
filter(lambda p: p.requires_grad, discriminator.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
|
||||
filter(lambda p: p.requires_grad, discriminator.parameters()),
|
||||
adamw_mode=True,
|
||||
lr=cfg.get("lr", 1e-5),
|
||||
weight_decay=cfg.get("weight_decay", 0),
|
||||
)
|
||||
disc_lr_scheduler = None
|
||||
discriminator.train()
|
||||
|
||||
# == additional preparation ==
|
||||
if cfg.get("grad_checkpoint", False):
|
||||
set_grad_checkpoint(model)
|
||||
|
||||
# =======================================================
|
||||
# 5. boost model for distributed training with colossalai
|
||||
# 4. distributed training preparation with colossalai
|
||||
# =======================================================
|
||||
# == boosting ==
|
||||
# NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
|
||||
torch.set_default_dtype(dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
|
|
@ -182,49 +194,53 @@ def main():
|
|||
optimizer=disc_optimizer,
|
||||
lr_scheduler=disc_lr_scheduler,
|
||||
)
|
||||
|
||||
torch.set_default_dtype(torch.float)
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
logger.info("Boost model for distributed training")
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
logger.info("Boosting model for distributed training")
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
else:
|
||||
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
|
||||
None if cfg.get("start_from_scratch ", False) else dataloader.batch_sampler
|
||||
|
||||
# =======================================================
|
||||
# 6. training loop
|
||||
# =======================================================
|
||||
start_epoch = start_step = log_step = sampler_start_idx = 0
|
||||
acc_step = 0
|
||||
cfg_epochs = cfg.get("epochs", 1000)
|
||||
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
|
||||
|
||||
# == global variables ==
|
||||
start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
|
||||
running_loss = 0.0
|
||||
running_disc_loss = 0.0
|
||||
# 6.1. resume training
|
||||
if cfg.load is not None:
|
||||
|
||||
# == resume ==
|
||||
if cfg.get("load", None) is not None:
|
||||
logger.info("Loading checkpoint")
|
||||
booster.load_model(model, os.path.join(cfg.load, "model"))
|
||||
booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
|
||||
|
||||
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
||||
|
||||
if cfg.get("discriminator", False) != False and os.path.exists(os.path.join(cfg.load, "discriminator")):
|
||||
booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
|
||||
booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
|
||||
|
||||
dist.barrier()
|
||||
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
||||
start_epoch, start_step, sampler_start_idx = (
|
||||
running_states["epoch"],
|
||||
running_states["step"],
|
||||
running_states["sample_start_index"],
|
||||
)
|
||||
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
|
||||
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
|
||||
dist.barrier()
|
||||
|
||||
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
|
||||
# 6.3. training loop
|
||||
for epoch in range(start_epoch, cfg.epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
# =======================================================
|
||||
# 5. training loop
|
||||
# =======================================================
|
||||
for epoch in range(start_epoch, cfg_epochs):
|
||||
# == set dataloader to new epoch ==
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
logger.info(f"Beginning epoch {epoch}...")
|
||||
logger.info("Beginning epoch %s...", epoch)
|
||||
|
||||
# == training loop in an epoch ==
|
||||
with tqdm(
|
||||
enumerate(dataloader_iter, start=start_step),
|
||||
desc=f"Epoch {epoch}",
|
||||
|
|
@ -239,35 +255,36 @@ def main():
|
|||
length = random.randint(1, x.size(2))
|
||||
x = x[:, :, :length, :, :]
|
||||
|
||||
# ===== VAE =====
|
||||
# == vae encoding ===
|
||||
x_rec, x_z_rec, z, posterior, x_z = model(x)
|
||||
|
||||
# ====== Generator Loss ======
|
||||
# == loss initialization ==
|
||||
vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
|
||||
disc_loss = torch.tensor(0.0, device=device, dtype=dtype)
|
||||
|
||||
log_dict = {}
|
||||
|
||||
# real image reconstruction loss
|
||||
# == real image reconstruction loss computation ==
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
|
||||
log_dict["kl_loss"] = weighted_kl_loss.item()
|
||||
log_dict["nll_loss"] = weighted_nll_loss.item()
|
||||
if cfg.get("use_real_rec_loss", False):
|
||||
vae_loss += weighted_nll_loss + weighted_kl_loss
|
||||
|
||||
# == temporal vae reconstruction loss computation ==
|
||||
_, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True)
|
||||
log_dict["z_nll_loss"] = weighted_z_nll_loss.item()
|
||||
# z reconstruction loss
|
||||
|
||||
# == use z reconstruction loss ==
|
||||
if cfg.get("use_z_rec_loss", False):
|
||||
vae_loss += weighted_z_nll_loss
|
||||
|
||||
# only for image
|
||||
# == use image loss only ==
|
||||
if cfg.get("use_image_identity_loss", False) and x.size(2) == 1:
|
||||
_, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior, no_perceptual=True)
|
||||
vae_loss += image_identity_loss
|
||||
log_dict["image_identity_loss"] = image_identity_loss.item()
|
||||
|
||||
# Adversarial Generator Loss
|
||||
# == generator adversarial loss ==
|
||||
if cfg.get("discriminator", False) != False:
|
||||
recon_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous()
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
|
|
@ -281,12 +298,12 @@ def main():
|
|||
)
|
||||
vae_loss += adversarial_loss
|
||||
|
||||
# Backward & update
|
||||
# == generator backward & update ==
|
||||
optimizer.zero_grad()
|
||||
booster.backward(loss=vae_loss, optimizer=optimizer)
|
||||
optimizer.step()
|
||||
|
||||
# Adversarial Discriminator loss
|
||||
# == discriminator adversarial loss ==
|
||||
if cfg.get("discriminator", False) != False:
|
||||
real_video = rearrange(x, "b c t h w -> (b t) c h w").contiguous()
|
||||
fake_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous()
|
||||
|
|
@ -298,31 +315,30 @@ def main():
|
|||
global_step,
|
||||
)
|
||||
disc_loss = weighted_d_adversarial_loss
|
||||
# Backward & update
|
||||
|
||||
# == discriminator backward & update ==
|
||||
disc_optimizer.zero_grad()
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
disc_optimizer.step()
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
|
||||
# Log loss values:
|
||||
# == update log info ==
|
||||
all_reduce_mean(vae_loss)
|
||||
running_loss += vae_loss.item()
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
log_step += 1
|
||||
acc_step += 1
|
||||
|
||||
# Log to tensorboard
|
||||
# == logging ==
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
avg_disc_loss = running_disc_loss / log_step
|
||||
pbar.set_postfix(
|
||||
{"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}
|
||||
)
|
||||
running_loss = 0
|
||||
log_step = 0
|
||||
running_disc_loss = 0
|
||||
writer.add_scalar("loss", vae_loss.item(), global_step)
|
||||
# tensorboard
|
||||
tb_writer.add_scalar("loss", vae_loss.item(), global_step)
|
||||
if cfg.wandb:
|
||||
wandb.log(
|
||||
{
|
||||
|
|
@ -335,8 +351,11 @@ def main():
|
|||
},
|
||||
step=global_step,
|
||||
)
|
||||
running_loss = 0
|
||||
running_disc_loss = 0
|
||||
log_step = 0
|
||||
|
||||
# Save checkpoint
|
||||
# == checkpoint saving ==
|
||||
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
|
||||
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
|
||||
|
|
@ -356,15 +375,25 @@ def main():
|
|||
"global_step": global_step + 1,
|
||||
"sample_start_index": (step + 1) * cfg.batch_size,
|
||||
}
|
||||
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
dist.barrier()
|
||||
|
||||
logger.info(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
"Saved checkpoint at epoch %s step %s global_step %s to %s",
|
||||
epoch,
|
||||
step + 1,
|
||||
global_step + 1,
|
||||
exp_dir,
|
||||
)
|
||||
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(0)
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader.sampler.set_start_index(0)
|
||||
else:
|
||||
dataloader.batch_sampler.set_epoch(epoch + 1)
|
||||
logger.info("Epoch done, recomputing batch sampler")
|
||||
start_step = 0
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue