diff --git a/tools/caption/utils.py b/tools/caption/utils.py
index 978b96c..eb43566 100644
--- a/tools/caption/utils.py
+++ b/tools/caption/utils.py
@@ -9,6 +9,8 @@ import torchvision.transforms as transforms
 from PIL import Image
 from torchvision.datasets.folder import pil_loader
 
+from tools.datasets.transform import extract_frames_new
+
 PROMPTS = {
     "image": {
         "text": "Describe this image and its style to generate a succinct yet informative description. Pay attention to all objects in the image. The description should be useful for AI to re-generate the image. The description should be no more than five sentences. Remember do not exceed 5 sentences.",
@@ -101,7 +103,9 @@ class VideoTextDataset(torch.utils.data.Dataset):
             images = [pil_loader(path)]
             length = 1
         else:
-            images, length = extract_frames(sample["path"], points=self.points)
+            images, length = extract_frames_new(
+                sample["path"], points=self.points, backend="opencv", return_length=True
+            )
         if self.resize_size is not None:
             images_r = []
             for img in images:
diff --git a/tools/datasets/transform.py b/tools/datasets/transform.py
index 87673ad..e59c316 100644
--- a/tools/datasets/transform.py
+++ b/tools/datasets/transform.py
@@ -101,6 +101,98 @@ def extract_frames(video_path, input_dir, output, point):
     return path_new
 
 
+def extract_frames_new(
+    video_path,
+    frame_inds=None,
+    points=None,
+    backend='opencv',
+    return_length=False,
+):
+    """Extract frames from a video with a selectable decoding backend.
+
+    Args:
+        video_path (str): path to the video
+        frame_inds (List[int]): absolute indices of the frames to extract
+        points (List[float]): values within [0, 1); multiplied by the
+            total frame count to get frame indices
+        backend (str): one of 'av', 'opencv', 'decord'
+        return_length (bool): if True, also return the total frame count
+    Return:
+        List[PIL.Image], plus the total frame count if return_length
+    """
+    assert backend in ['av', 'opencv', 'decord']
+    # exactly one of the two frame-selection modes must be provided
+    assert (frame_inds is None) != (points is None)
+
+    if backend == 'av':
+        import av
+        container = av.open(video_path)
+        total_frames = container.streams.video[0].frames
+
+        if points is not None:
+            frame_inds = [int(p * total_frames) for p in points]
+
+        frames = []
+        for idx in frame_inds:
+            if idx >= total_frames:
+                idx = total_frames - 1
+            target_timestamp = int(
+                idx * av.time_base / container.streams.video[0].average_rate
+            )
+            container.seek(target_timestamp)
+            frame = next(container.decode(video=0)).to_image()
+            frames.append(frame)
+
+        if return_length:
+            return frames, total_frames
+        return frames
+
+    elif backend == 'decord':
+        import decord
+        container = decord.VideoReader(video_path, num_threads=1)
+        total_frames = len(container)
+        # avg_fps = container.get_avg_fps()
+
+        if points is not None:
+            frame_inds = [int(p * total_frames) for p in points]
+
+        frame_inds = np.array(frame_inds).astype(np.int32)
+        frame_inds[frame_inds >= total_frames] = total_frames - 1
+        frames = container.get_batch(frame_inds).asnumpy()  # [N, H, W, C]
+        frames = [Image.fromarray(x) for x in frames]
+
+        if return_length:
+            return frames, total_frames
+        return frames
+
+    elif backend == 'opencv':
+        import cv2
+        cap = cv2.VideoCapture(video_path)
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        if points is not None:
+            frame_inds = [int(p * total_frames) for p in points]
+
+        frames = []
+        for idx in frame_inds:
+            if idx >= total_frames:
+                idx = total_frames - 1
+
+            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+            ret, frame = cap.read()
+            assert ret, f"failed to read frame {idx} of '{video_path}'"
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+            frames.append(frame)
+        cap.release()
+
+        if return_length:
+            return frames, total_frames
+        return frames
+    else:
+        raise ValueError(f"unsupported backend: {backend}")
+
+
 def main(args):
     data = pd.read_csv(args.input)
     if args.method == "img_rand_crop":
@@ -145,3 +237,14 @@ if __name__ == "__main__":
     if args.disable_parallel:
         pandas_has_parallel = False
     main(args)
+    exit()
+
+    # debug snippet (dead code behind exit()): dump a few frames for visual inspection
+    ret = extract_frames_new(
+        'E:/data/video/pexels_new/8974385_scene-0.mp4',
+        frame_inds=[0, 50, 100, 150],
+        backend='opencv')
+    for idx, img in enumerate(ret):
+        save_path = f'./checkpoints/vis/{idx}.png'
+        img.save(save_path)
+    exit()
diff --git a/tools/scoring/aesthetic/inference.py b/tools/scoring/aesthetic/inference.py
index 6b98ad4..dcbf9a1 100644
--- a/tools/scoring/aesthetic/inference.py
+++ b/tools/scoring/aesthetic/inference.py
@@ -4,7 +4,6 @@ import os
 from datetime import timedelta
 
 import clip
-import decord
 import numpy as np
 import pandas as pd
 import torch
@@ -17,6 +16,8 @@ from PIL import Image
 from torchvision.datasets.folder import pil_loader
 from tqdm import tqdm
 
+from tools.datasets.transform import extract_frames_new
+
 try:
     from torchvision.transforms import InterpolationMode
@@ -34,16 +35,6 @@ def is_video(filename):
     return ext in VID_EXTENSIONS
 
 
-def extract_frames(video_path, points=(0.1, 0.5, 0.9)):
-    container = decord.VideoReader(video_path, num_threads=1)
-    total_frames = len(container)
-    frame_inds = (np.array(points) * total_frames).astype(np.int32)
-    frame_inds[frame_inds >= total_frames] = total_frames - 1
-    frames = container.get_batch(frame_inds).asnumpy()  # [N, H, W, C]
-    frames_pil = [Image.fromarray(frame) for frame in frames]
-    return frames_pil
-
-
 class VideoTextDataset(torch.utils.data.Dataset):
     def __init__(self, csv_path, transform=None, points=(0.1, 0.5, 0.9)):
         self.csv_path = csv_path
@@ -57,7 +48,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
         if not is_video(path):
             images = [pil_loader(path)]
         else:
-            images = extract_frames(sample["path"], points=self.points)
+            images = extract_frames_new(sample["path"], points=self.points, backend="opencv")
         images = [self.transform(img) for img in images]
         images = torch.stack(images)
         return dict(index=index, images=images)
diff --git a/tools/scoring/matching/inference_parallel.py b/tools/scoring/matching/inference_parallel.py
index e8a1704..23ddddc 100644
--- a/tools/scoring/matching/inference_parallel.py
+++ b/tools/scoring/matching/inference_parallel.py
@@ -13,6 +13,8 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torchvision.datasets.folder import pil_loader
 from tqdm import tqdm
 
+from tools.datasets.transform import extract_frames_new
+
 IMG_EXTENSIONS = (
     ".jpg",
     ".jpeg",
@@ -32,21 +34,6 @@ def is_video(filename):
     return ext in VID_EXTENSIONS
 
 
-def extract_frames(video_path, points=[0.5]):
-    container = av.open(video_path)
-    total_frames = container.streams.video[0].frames
-    frames = []
-    for point in points:
-        target_frame = total_frames * point
-        target_timestamp = int(
-            (target_frame * av.time_base) / container.streams.video[0].average_rate
-        )
-        container.seek(target_timestamp)
-        frame = next(container.decode(video=0)).to_image()
-        frames.append(frame)
-    return frames
-
-
 class VideoTextDataset(torch.utils.data.Dataset):
     def __init__(self, meta_path, transform):
         self.meta_path = meta_path
@@ -58,7 +45,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
         path = row["path"]
 
         if is_video(path):
-            img = extract_frames(path, points=[0.5])[0]
+            img = extract_frames_new(path, points=[0.5], backend="opencv")[0]
         else:
             img = pil_loader(path)
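For reference, the consolidated helper takes exactly one of `frame_inds` (absolute frame indices) or `points` (relative positions within [0, 1)). A minimal usage sketch, assuming the import path from the hunks above; the video path is a placeholder:

    from tools.datasets.transform import extract_frames_new

    # relative sampling: frames at 10%/50%/90% of the video, plus the total frame count
    frames, total = extract_frames_new(
        "video.mp4", points=[0.1, 0.5, 0.9], backend="opencv", return_length=True
    )

    # absolute sampling, e.g. with the decord backend
    frames = extract_frames_new("video.mp4", frame_inds=[0, 10, 20, 30], backend="decord")

Both calls return a list of PIL.Image objects, which is what the datasets above feed into their torchvision transforms.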
diff --git a/tools/scoring/optical_flow/inference.py b/tools/scoring/optical_flow/inference.py
deleted file mode 100644
index db62b35..0000000
--- a/tools/scoring/optical_flow/inference.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import argparse
-import os
-
-import av
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-from tqdm import tqdm
-
-from .unimatch import UniMatch
-
-import decord  # isort: skip
-
-
-def extract_frames_av(video_path, frame_inds=[0, 10, 20, 30]):
-    container = av.open(video_path)
-    total_frames = container.streams.video[0].frames
-    frames = []
-    for idx in frame_inds:
-        if idx >= total_frames:
-            idx = total_frames - 1
-        target_timestamp = int(
-            idx * av.time_base / container.streams.video[0].average_rate
-        )
-        container.seek(target_timestamp)
-        frame = next(container.decode(video=0)).to_image()
-        frames.append(frame)
-    return frames
-
-
-def extract_frames(video_path, frame_inds=[0, 10, 20, 30]):
-    container = decord.VideoReader(video_path, num_threads=1)
-    total_frames = len(container)
-    # avg_fps = container.get_avg_fps()
-
-    frame_inds = np.array(frame_inds).astype(np.int32)
-    frame_inds[frame_inds >= total_frames] = total_frames - 1
-    frames = container.get_batch(frame_inds).asnumpy()  # [N, H, W, C]
-    return frames
-
-
-class VideoTextDataset(torch.utils.data.Dataset):
-    def __init__(self, meta_path, frame_inds=[0, 10, 20, 30]):
-        self.meta_path = meta_path
-        self.meta = pd.read_csv(meta_path)
-        self.frame_inds = frame_inds
-
-    def __getitem__(self, index):
-        row = self.meta.iloc[index]
-        images = extract_frames(row["path"], frame_inds=self.frame_inds)
-        # images = [pil_to_tensor(x) for x in images]  # [C, H, W]
-
-        # transform
-        images = torch.from_numpy(images).float()
-        images = rearrange(images, "N H W C -> N C H W")
-        H, W = images.shape[-2:]
-        if H > W:
-            images = rearrange(images, "N C H W -> N C W H")
-        images = F.interpolate(
-            images, size=(320, 576), mode="bilinear", align_corners=True
-        )
-
-        return images
-
-    def __len__(self):
-        return len(self.meta)
-
-
-def main():
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-    parser = argparse.ArgumentParser()
-    parser.add_argument("meta_path", type=str, help="Path to the input CSV file")
-    parser.add_argument("--bs", type=int, default=4, help="Batch size")
-    parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
-    args = parser.parse_args()
-
-    meta_path = args.meta_path
-    wo_ext, ext = os.path.splitext(meta_path)
-    out_path = f"{wo_ext}_flow{ext}"
-
-    # build model
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = UniMatch(
-        feature_channels=128,
-        num_scales=2,
-        upsample_factor=4,
-        num_head=1,
-        ffn_dim_expansion=4,
-        num_transformer_layers=6,
-        reg_refine=True,
-        task="flow",
-    ).eval()
-    ckpt = torch.load(
-        "./pretrained_models/unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth"
-    )
-    model.load_state_dict(ckpt["model"])
-    model = model.to(device)
-    # model = torch.nn.DataParallel(model)
-
-    # build dataset
-    dataset = VideoTextDataset(meta_path=meta_path, frame_inds=[0, 10, 20, 30])
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=args.bs,
-        num_workers=args.num_workers,
-        shuffle=False,
-    )
-
-    # compute optical flow scores
-    dataset.meta["flow"] = np.nan
-    index = 0
-    for images in tqdm(dataloader):
-        images = images.to(device)
-        B = images.shape[0]
-
-        batch_0 = rearrange(images[:, :-1], "B N C H W -> (B N) C H W").contiguous()
-        batch_1 = rearrange(images[:, 1:], "B N C H W -> (B N) C H W").contiguous()
-
-        with torch.no_grad():
-            res = model(
-                batch_0,
-                batch_1,
-                attn_type="swin",
-                attn_splits_list=[2, 8],
-                corr_radius_list=[-1, 4],
-                prop_radius_list=[-1, 1],
-                num_reg_refine=6,
-                task="flow",
-                pred_bidir_flow=False,
-            )
-        flow_maps = res["flow_preds"][-1].cpu()  # [B * (N-1), 2, H, W]
-        flow_maps = rearrange(flow_maps, "(B N) C H W -> B N H W C", B=B)
-        flow_scores = flow_maps.abs().mean(dim=[1, 2, 3, 4])
-        flow_scores_np = flow_scores.numpy()
-
-        dataset.meta.loc[index : index + B - 1, "flow"] = flow_scores_np
-        index += B
-
-    dataset.meta.to_csv(out_path, index=False)
-    print(f"New meta with optical flow scores saved to '{out_path}'.")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/tools/scoring/optical_flow/inference_parallel.py b/tools/scoring/optical_flow/inference_parallel.py
index c64b840..a821beb 100644
--- a/tools/scoring/optical_flow/inference_parallel.py
+++ b/tools/scoring/optical_flow/inference_parallel.py
@@ -10,38 +10,11 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
 from torch.utils.data import DataLoader, DistributedSampler
+from torchvision.transforms.functional import pil_to_tensor
 from tqdm import tqdm
 
 from .unimatch import UniMatch
-
-import decord  # isort: skip
-
-
-def extract_frames_av(video_path, frame_inds=[0, 10, 20, 30]):
-    container = av.open(video_path)
-    total_frames = container.streams.video[0].frames
-    frames = []
-    for idx in frame_inds:
-        if idx >= total_frames:
-            idx = total_frames - 1
-        target_timestamp = int(
-            idx * av.time_base / container.streams.video[0].average_rate
-        )
-        container.seek(target_timestamp)
-        frame = next(container.decode(video=0)).to_image()
-        frames.append(frame)
-    return frames
-
-
-def extract_frames(video_path, frame_inds=[0, 10, 20, 30]):
-    container = decord.VideoReader(video_path, num_threads=1)
-    total_frames = len(container)
-    # avg_fps = container.get_avg_fps()
-
-    frame_inds = np.array(frame_inds).astype(np.int32)
-    frame_inds[frame_inds >= total_frames] = total_frames - 1
-    frames = container.get_batch(frame_inds).asnumpy()  # [N, H, W, C]
-    return frames
+from tools.datasets.transform import extract_frames_new
 
 
 def merge_scores(gathered_list: list, meta: pd.DataFrame):
@@ -69,12 +42,11 @@ class VideoTextDataset(torch.utils.data.Dataset):
 
     def __getitem__(self, index):
         row = self.meta.iloc[index]
-        images = extract_frames(row["path"], frame_inds=self.frame_inds)
-        # images = [pil_to_tensor(x) for x in images]  # [C, H, W]
+        images = extract_frames_new(row["path"], frame_inds=self.frame_inds, backend="opencv")
 
         # transform
-        images = torch.from_numpy(images).float()
-        images = rearrange(images, "N H W C -> N C H W")
+        images = torch.stack([pil_to_tensor(x) for x in images])  # shape: [N, C, H, W]; dtype: torch.uint8
+        images = images.float()
         H, W = images.shape[-2:]
         if H > W:
             images = rearrange(images, "N C H W -> N C W H")