From ea96cdb5e4ef9c1105ad9d81f66b38be31795ce8 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 5 Jun 2024 10:12:52 +0800 Subject: [PATCH 1/4] added warmup lr scheduler (#121) --- opensora/utils/lr_scheduler.py | 22 ++++++++++++++++++++++ scripts/train.py | 14 +++++++++++++- tests/test_lr_scheduler.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 opensora/utils/lr_scheduler.py create mode 100644 tests/test_lr_scheduler.py diff --git a/opensora/utils/lr_scheduler.py b/opensora/utils/lr_scheduler.py new file mode 100644 index 0000000..e0f75f5 --- /dev/null +++ b/opensora/utils/lr_scheduler.py @@ -0,0 +1,22 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class LinearWarmupLR(_LRScheduler): + """Linearly warm up the learning rate, then keep it constant. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + last_epoch (int, optional): The index of the last step, defaults to -1. When last_epoch=-1, + the schedule starts from the beginning and the initial lr is set to lr. 
+ """ + + def __init__(self, optimizer, warmup_steps: int = 0, last_epoch: int = -1): + self.warmup_steps = warmup_steps + super().__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + if self.last_epoch < self.warmup_steps: + return [(self.last_epoch + 1) / (self.warmup_steps + 1) * lr for lr in self.base_lrs] + else: + return self.base_lrs diff --git a/scripts/train.py b/scripts/train.py index 08681dc..73d1fbd 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -17,6 +17,7 @@ from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.datasets.dataloader import prepare_dataloader from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module +from opensora.utils.lr_scheduler import LinearWarmupLR from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config from opensora.utils.misc import ( @@ -169,7 +170,13 @@ def main(): weight_decay=cfg.get("weight_decay", 0), eps=cfg.get("adam_eps", 1e-8), ) - lr_scheduler = None + + warmup_steps = cfg.get("warmup_steps", None) + + if warmup_steps is None: + lr_scheduler = None + else: + lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=cfg.get("warmup_steps")) # == additional preparation == if cfg.get("grad_checkpoint", False): @@ -288,6 +295,10 @@ def main(): booster.backward(loss=loss, optimizer=optimizer) optimizer.step() optimizer.zero_grad() + + # update learning rate + if lr_scheduler is not None: + lr_scheduler.step() coordinator.block_all() timer_list.append(backward_t) @@ -323,6 +334,7 @@ def main(): "loss": loss.item(), "avg_loss": avg_loss, "acc_step": acc_step, + "lr": optimizer.param_groups[0]["lr"], "move_data_time": move_data_t.elapsed_time, "encode_time": encode_t.elapsed_time, "mask_time": mask_t.elapsed_time, diff --git 
a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py new file mode 100644 index 0000000..23b5127 --- /dev/null +++ b/tests/test_lr_scheduler.py @@ -0,0 +1,30 @@ +import torch +from torch.optim import Adam +from torchvision.models import resnet50 + +from opensora.utils.lr_scheduler import LinearWarmupLR + + +def test_lr_scheduler(): + model = resnet50().cuda() + optimizer = Adam(model.parameters(), lr=0.01) + scheduler = LinearWarmupLR(optimizer, warmup_steps=10) + current_lr = scheduler.get_lr()[0] + data = torch.rand(128, 3, 224, 224).cuda() + + for i in range(100): + out = model(data) + out.mean().backward() + + optimizer.step() + scheduler.step() + + if i >= 10: + assert scheduler.get_lr()[0] == 0.01 + else: + assert scheduler.get_lr()[0] > current_lr, f"{scheduler.get_lr()[0]} <= {current_lr}" + current_lr = scheduler.get_lr()[0] + + +if __name__ == "__main__": + test_lr_scheduler() From c04f157fec7984f8adf2b2e854ecf2bd5794248e Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Wed, 5 Jun 2024 07:29:45 +0000 Subject: [PATCH 2/4] [config] stage2 training --- configs/opensora-v1-2/train/stage2.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/configs/opensora-v1-2/train/stage2.py b/configs/opensora-v1-2/train/stage2.py index 4ad4dce..6dc2299 100644 --- a/configs/opensora-v1-2/train/stage2.py +++ b/configs/opensora-v1-2/train/stage2.py @@ -8,20 +8,20 @@ dataset = dict( bucket_config = { # 12s/it "144p": {1: (1.0, 475), 51: (1.0, 51), 102: ((1.0, 0.33), 27), 204: ((1.0, 0.1), 13), 408: ((1.0, 0.1), 6)}, # --- - "256": {1: (0.4, 297), 51: (0.5, 20), 102: ((0.5, 0.33), 10), 204: ((0.5, 0.1), 5), 408: ((0.5, 0.1), 2)}, - "240p": {1: (0.3, 297), 51: (0.4, 20), 102: ((0.4, 0.33), 10), 204: ((0.4, 0.1), 5), 408: ((0.4, 0.1), 2)}, + "256": {1: (0.4, 297), 51: (0.5, 20), 102: ((0.5, 0.33), 10), 204: ((0.5, 1.0), 5), 408: ((0.5, 1.0), 2)}, + "240p": {1: (0.3, 297), 51: (0.4, 20), 102: ((0.4, 0.33), 10), 204: ((0.4, 1.0), 5), 408: 
((0.4, 1.0), 2)}, # --- - "360p": {1: (0.2, 141), 51: (0.15, 8), 102: ((0.15, 0.33), 4), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)}, - "512": {1: (0.2, 141), 51: (0.15, 8), 102: ((0.15, 0.33), 4), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)}, + "360p": {1: (0.5, 141), 51: (0.15, 8), 102: ((0.3, 0.5), 4), 204: ((0.3, 1.0), 2), 408: ((0.5, 0.5), 1)}, + "512": {1: (0.4, 141), 51: (0.15, 8), 102: ((0.2, 0.4), 4), 204: ((0.2, 1.0), 2), 408: ((0.4, 0.5), 1)}, # --- - "480p": {1: (0.1, 89), 51: (0.1, 5), 102: (0.1, 2), 204: (0.1, 1)}, + "480p": {1: (0.5, 89), 51: (0.1, 5), 102: (0.1, 2), 204: (0.05, 1)}, # --- - "720p": {1: (0.05, 36), 51: (0.1, 1)}, - "1024": {1: (0.05, 36), 51: (0.1, 1)}, + "720p": {1: (0.1, 36), 51: (0.03, 1)}, + "1024": {1: (0.1, 36), 51: (0.02, 1)}, # --- - "1080p": {1: (0.1, 5)}, + "1080p": {1: (0.01, 5)}, # --- - "2048": {1: (0.1, 5)}, + "2048": {1: (0.01, 5)}, } grad_checkpoint = True @@ -88,3 +88,5 @@ grad_clip = 1.0 lr = 1e-4 ema_decay = 0.99 adam_eps = 1e-15 +warmup_steps = 1000 + From d028417a610177902b528bc06b78a736d618d16d Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Wed, 5 Jun 2024 07:42:45 +0000 Subject: [PATCH 3/4] [config] update training config --- configs/opensora-v1-2/train/stage2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/opensora-v1-2/train/stage2.py b/configs/opensora-v1-2/train/stage2.py index 6dc2299..a1d5b67 100644 --- a/configs/opensora-v1-2/train/stage2.py +++ b/configs/opensora-v1-2/train/stage2.py @@ -14,7 +14,7 @@ bucket_config = { # 12s/it "360p": {1: (0.5, 141), 51: (0.15, 8), 102: ((0.3, 0.5), 4), 204: ((0.3, 1.0), 2), 408: ((0.5, 0.5), 1)}, "512": {1: (0.4, 141), 51: (0.15, 8), 102: ((0.2, 0.4), 4), 204: ((0.2, 1.0), 2), 408: ((0.4, 0.5), 1)}, # --- - "480p": {1: (0.5, 89), 51: (0.1, 5), 102: (0.1, 2), 204: (0.05, 1)}, + "480p": {1: (0.5, 89), 51: (0.2, 5), 102: (0.2, 2), 204: (0.1, 1)}, # --- "720p": {1: (0.1, 36), 51: (0.03, 1)}, "1024": {1: (0.1, 36), 51: (0.02, 1)}, From 
5f9e6422789d2861f744f9d01d2b2a4451988b67 Mon Sep 17 00:00:00 2001 From: "Zheng Zangwei (Alex Zheng)" Date: Wed, 5 Jun 2024 16:41:00 +0800 Subject: [PATCH 4/4] [wip] fix scoring (#120) * [wip] fix scoring * minor update --- requirements/requirements-data.txt | 1 + tools/datasets/utils.py | 1 + tools/scoring/aesthetic/inference.py | 54 +++++++++++++++---------- tools/scoring/optical_flow/inference.py | 50 ++++++++++++++--------- 4 files changed, 64 insertions(+), 42 deletions(-) diff --git a/requirements/requirements-data.txt b/requirements/requirements-data.txt index 008b90f..7a2d38d 100644 --- a/requirements/requirements-data.txt +++ b/requirements/requirements-data.txt @@ -20,6 +20,7 @@ lingua-language-detector==2.0.2 imageio>=2.34.1 # [aesthetic] +setuptools==68.2.2 clip @ git+https://github.com/openai/CLIP.git # [ocr] diff --git a/tools/datasets/utils.py b/tools/datasets/utils.py index fe723bf..ec7e1b9 100644 --- a/tools/datasets/utils.py +++ b/tools/datasets/utils.py @@ -1,6 +1,7 @@ import os import cv2 +import numpy as np from PIL import Image IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") diff --git a/tools/scoring/aesthetic/inference.py b/tools/scoring/aesthetic/inference.py index a291914..5d791c1 100644 --- a/tools/scoring/aesthetic/inference.py +++ b/tools/scoring/aesthetic/inference.py @@ -1,5 +1,8 @@ # adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py +import cv2 # isort:skip + import argparse +import gc import os from datetime import timedelta @@ -11,7 +14,6 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from PIL import Image from torch.utils.data import DataLoader, DistributedSampler from torchvision.datasets.folder import pil_loader from tqdm import tqdm @@ -24,6 +26,7 @@ NUM_FRAMES_POINTS = { 3: (0.1, 0.5, 0.9), } + def merge_scores(gathered_list: list, meta: pd.DataFrame, 
column): # reorder indices_list = list(map(lambda x: x[0], gathered_list)) @@ -41,32 +44,36 @@ def merge_scores(gathered_list: list, meta: pd.DataFrame, column): # filter duplicates unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) meta.loc[unique_indices, column] = flat_scores[unique_indices_idx] - - # jun 3 quickfix - # lose indices in meta not in unique_indices + + # drop indices in meta not in unique_indices meta = meta.loc[unique_indices] return meta class VideoTextDataset(torch.utils.data.Dataset): - def __init__(self, csv_path, transform=None, num_frames=3): - self.csv_path = csv_path - self.meta = pd.read_csv(csv_path) + def __init__(self, meta_path, transform=None, num_frames=3): + self.meta_path = meta_path + self.meta = pd.read_csv(meta_path) self.transform = transform self.points = NUM_FRAMES_POINTS[num_frames] def __getitem__(self, index): sample = self.meta.iloc[index] path = sample["path"] + + # extract frames if not is_video(path): images = [pil_loader(path)] else: - num_frames = None - if "num_frames" in sample: - num_frames = sample["num_frames"] + num_frames = sample["num_frames"] if "num_frames" in sample else None images = extract_frames(sample["path"], points=self.points, backend="opencv", num_frames=num_frames) + + # transform images = [self.transform(img) for img in images] + + # stack images = torch.stack(images) + ret = dict(index=index, images=images) return ret @@ -97,7 +104,6 @@ class AestheticScorer(nn.Module): def __init__(self, input_size, device): super().__init__() self.mlp = MLP(input_size) - self.mlp.load_state_dict(torch.load("pretrained_models/aesthetic.pth")) self.clip, self.preprocess = clip.load("ViT-L/14", device=device) self.eval() @@ -122,6 +128,7 @@ def main(): # build model device = "cuda" if torch.cuda.is_available() else "cpu" model = AestheticScorer(768, device) + model.mlp.load_state_dict(torch.load("pretrained_models/aesthetic.pth", map_location=device)) preprocess = model.preprocess # 
build dataset @@ -138,7 +145,7 @@ def main(): drop_last=False, ), ) - + # compute aesthetic scores indices_list = [] scores_list = [] @@ -153,26 +160,28 @@ def main(): # compute score with torch.no_grad(): scores = model(images) + scores = rearrange(scores, "(B N) 1 -> B N", B=B) scores = scores.mean(dim=1) scores_np = scores.to(torch.float32).cpu().numpy() - indices_list.extend(indices) - scores_list.extend(scores_np) - - # jun 3 quickfix - meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column='aes') - out_path_local = out_path.replace('.csv', f'_part_{dist.get_rank()}.csv') + indices_list.extend(indices.tolist()) + scores_list.extend(scores_np.tolist()) + + # save local results + meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column="aes") + out_path_local = out_path.replace(".csv", f"_part_{dist.get_rank()}.csv") meta_local.to_csv(out_path_local, index=False) # wait for all ranks to finish data processing - dist.barrier() + dist.barrier() + torch.cuda.empty_cache() + gc.collect() gathered_list = [None] * dist.get_world_size() - breakpoint() dist.all_gather_object(gathered_list, (indices_list, scores_list)) if dist.get_rank() == 0: - meta_new = merge_scores(gathered_list, dataset.meta, column='aes') + meta_new = merge_scores(gathered_list, dataset.meta, column="aes") meta_new.to_csv(out_path, index=False) print(f"New meta with aesthetic scores saved to '{out_path}'.") @@ -182,11 +191,12 @@ def parse_args(): parser.add_argument("meta_path", type=str, help="Path to the input CSV file") parser.add_argument("--bs", type=int, default=1024, help="Batch size") parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") - parser.add_argument("--prefetch_factor", type=int, default=2, help="Prefetch factor") + parser.add_argument("--prefetch_factor", type=int, default=3, help="Prefetch factor") parser.add_argument("--num_frames", type=int, default=3, help="Number of frames to extract") args = 
parser.parse_args() return args + if __name__ == "__main__": main() diff --git a/tools/scoring/optical_flow/inference.py b/tools/scoring/optical_flow/inference.py index 3f7d85f..2ea0d7a 100644 --- a/tools/scoring/optical_flow/inference.py +++ b/tools/scoring/optical_flow/inference.py @@ -1,6 +1,7 @@ import cv2 # isort:skip import argparse +import gc import os from datetime import timedelta @@ -38,8 +39,7 @@ def merge_scores(gathered_list: list, meta: pd.DataFrame, column): unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) meta.loc[unique_indices, column] = flat_scores[unique_indices_idx] - # jun 3 quickfix - # lose indices in meta not in unique_indices + # drop indices in meta not in unique_indices meta = meta.loc[unique_indices] return meta @@ -51,32 +51,30 @@ class VideoTextDataset(torch.utils.data.Dataset): self.frame_inds = frame_inds def __getitem__(self, index): - row = self.meta.iloc[index] - images = extract_frames(row["path"], frame_inds=self.frame_inds, backend="opencv") + sample = self.meta.iloc[index] + path = sample["path"] + + # extract frames + images = extract_frames(path, frame_inds=self.frame_inds, backend="opencv") # transform - images = torch.stack([pil_to_tensor(x) for x in images]) # shape: [N, C, H, W]; dtype: torch.uint8 + images = torch.stack([pil_to_tensor(x) for x in images]) + + # stack + # shape: [N, C, H, W]; dtype: torch.uint8 images = images.float() H, W = images.shape[-2:] if H > W: images = rearrange(images, "N C H W -> N C W H") images = F.interpolate(images, size=(320, 576), mode="bilinear", align_corners=True) - return images, index + ret = dict(index=index, images=images) + return ret def __len__(self): return len(self.meta) -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("meta_path", type=str, help="Path to the input CSV file") - parser.add_argument("--bs", type=int, default=4, help="Batch size") # don't use too large bs for unimatch - 
parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") - args = parser.parse_args() - return args - - def main(): args = parse_args() @@ -124,10 +122,11 @@ def main(): indices_list = [] scores_list = [] model.eval() - for images, indices in tqdm(dataloader, disable=dist.get_rank() != 0): - images = images.to(device) - B = images.shape[0] + for batch in tqdm(dataloader, disable=dist.get_rank() != 0): + indices = batch["index"] + images = batch["images"].to(device, non_blocking=True) + B = images.shape[0] batch_0 = rearrange(images[:, :-1], "B N C H W -> (B N) C H W").contiguous() batch_1 = rearrange(images[:, 1:], "B N C H W -> (B N) C H W").contiguous() @@ -148,10 +147,10 @@ def main(): flow_scores = flow_maps.abs().mean(dim=[1, 2, 3, 4]) flow_scores = flow_scores.tolist() - indices_list.extend(indices) + indices_list.extend(indices.tolist()) scores_list.extend(flow_scores) - # jun 3 quickfix + # save local results meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column="flow") out_path_local = out_path.replace(".csv", f"_part_{dist.get_rank()}.csv") meta_local.to_csv(out_path_local, index=False) @@ -159,6 +158,8 @@ def main(): # wait for all ranks to finish data processing dist.barrier() + torch.cuda.empty_cache() + gc.collect() gathered_list = [None] * dist.get_world_size() dist.all_gather_object(gathered_list, (indices_list, scores_list)) if dist.get_rank() == 0: @@ -167,5 +168,14 @@ def main(): print(f"New meta with optical flow scores saved to '{out_path}'.") +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("meta_path", type=str, help="Path to the input CSV file") + parser.add_argument("--bs", type=int, default=4, help="Batch size") # don't use too large bs for unimatch + parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") + args = parser.parse_args() + return args + + if __name__ == "__main__": main()