mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-16 04:54:29 +02:00
merge
This commit is contained in:
commit
f5e347e08e
|
|
@ -9,6 +9,8 @@ import torchvision.transforms as transforms
|
|||
from PIL import Image
|
||||
from torchvision.datasets.folder import pil_loader
|
||||
|
||||
from tools.datasets.transform import extract_frames_new
|
||||
|
||||
PROMPTS = {
|
||||
"image": {
|
||||
"text": "Describe this image and its style to generate a succinct yet informative description. Pay attention to all objects in the image. The description should be useful for AI to re-generate the image. The description should be no more than five sentences. Remember do not exceed 5 sentences.",
|
||||
|
|
@ -101,7 +103,10 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
images = [pil_loader(path)]
|
||||
length = 1
|
||||
else:
|
||||
images, length = extract_frames(sample["path"], points=self.points)
|
||||
# images, length = extract_frames(sample["path"], points=self.points)
|
||||
images, length = extract_frames_new(
|
||||
sample["path"], points=self.points, backend="opencv", return_length=True
|
||||
)
|
||||
if self.resize_size is not None:
|
||||
images_r = []
|
||||
for img in images:
|
||||
|
|
|
|||
|
|
@ -101,6 +101,91 @@ def extract_frames(video_path, input_dir, output, point):
|
|||
return path_new
|
||||
|
||||
|
||||
def extract_frames_new(
    video_path,
    frame_inds=None,
    points=None,
    backend='opencv',
    return_length=False,
):
    """Extract frames from a video with a selectable decoding backend.

    Args:
        video_path (str): path to the video file.
        frame_inds (List[int], optional): absolute frame indices to extract.
            Defaults to [0, 10, 20, 30] when ``points`` is not given.
        points (List[float], optional): relative positions in [0, 1); each is
            multiplied by the total frame count to get a frame index.
            Mutually exclusive with ``frame_inds``.
        backend (str): decoding backend, one of 'av', 'opencv', 'decord'.
        return_length (bool): if True, also return the total frame count.

    Returns:
        List[PIL.Image], or (List[PIL.Image], int) when ``return_length``.
    """
    assert backend in ['av', 'opencv', 'decord']
    # The two index specifications are mutually exclusive. Using a None
    # default (instead of the old mutable-list default [0, 10, 20, 30])
    # means callers that pass only `points` no longer trip this assertion.
    assert (frame_inds is None) or (points is None)
    if frame_inds is None and points is None:
        frame_inds = [0, 10, 20, 30]

    if backend == 'av':
        import av
        container = av.open(video_path)
        try:
            stream = container.streams.video[0]
            total_frames = stream.frames

            if points is not None:
                frame_inds = [int(p * total_frames) for p in points]

            frames = []
            for idx in frame_inds:
                if idx >= total_frames:
                    idx = total_frames - 1  # clamp to the last frame
                # Convert the frame index to a timestamp in av.time_base units.
                target_timestamp = int(idx * av.time_base / stream.average_rate)
                container.seek(target_timestamp)
                frame = next(container.decode(video=0)).to_image()
                frames.append(frame)
        finally:
            container.close()  # always release demuxer/decoder resources

        if return_length:
            return frames, total_frames
        return frames

    elif backend == 'decord':
        import decord
        container = decord.VideoReader(video_path, num_threads=1)
        total_frames = len(container)

        if points is not None:
            frame_inds = [int(p * total_frames) for p in points]

        frame_inds = np.array(frame_inds).astype(np.int32)
        frame_inds[frame_inds >= total_frames] = total_frames - 1  # clamp
        frames = container.get_batch(frame_inds).asnumpy()  # [N, H, W, C]
        frames = [Image.fromarray(x) for x in frames]

        if return_length:
            return frames, total_frames
        return frames

    elif backend == 'opencv':
        import cv2
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            if points is not None:
                frame_inds = [int(p * total_frames) for p in points]

            frames = []
            for idx in frame_inds:
                if idx >= total_frames:
                    idx = total_frames - 1  # clamp to the last frame

                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if not ret:
                    # Previously a failed read crashed inside cvtColor with a
                    # cryptic error on `frame is None`; fail loudly instead.
                    raise ValueError(f"failed to read frame {idx} from {video_path}")
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR
                frames.append(Image.fromarray(frame))
        finally:
            cap.release()  # always free the capture handle

        if return_length:
            return frames, total_frames
        return frames
    else:
        raise ValueError(f"unknown backend: {backend}")
|
||||
|
||||
|
||||
def main(args):
|
||||
data = pd.read_csv(args.input)
|
||||
if args.method == "img_rand_crop":
|
||||
|
|
@ -145,3 +230,14 @@ if __name__ == "__main__":
|
|||
if args.disable_parallel:
|
||||
pandas_has_parallel = False
|
||||
main(args)
|
||||
exit()
|
||||
|
||||
from torchvision.transforms.functional import pil_to_tensor
|
||||
ret = extract_frames_new(
|
||||
'E:/data/video/pexels_new/8974385_scene-0.mp4',
|
||||
frame_inds=[0, 50, 100, 150],
|
||||
backend='opencv')
|
||||
for idx, img in enumerate(ret):
|
||||
save_path = f'./checkpoints/vis/{idx}.png'
|
||||
ret[idx].save(save_path)
|
||||
exit()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import os
|
|||
from datetime import timedelta
|
||||
|
||||
import clip
|
||||
import decord
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
|
@ -17,6 +16,8 @@ from PIL import Image
|
|||
from torchvision.datasets.folder import pil_loader
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.datasets.transform import extract_frames_new
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
|
|
@ -34,16 +35,6 @@ def is_video(filename):
|
|||
return ext in VID_EXTENSIONS
|
||||
|
||||
|
||||
def extract_frames(video_path, points=(0.1, 0.5, 0.9)):
    """Return the frames at the given relative positions as PIL images.

    Each value in ``points`` (in [0, 1)) is scaled by the video's frame
    count to select a frame; out-of-range indices use the last frame.
    """
    reader = decord.VideoReader(video_path, num_threads=1)
    n_total = len(reader)
    inds = (np.array(points) * n_total).astype(np.int32)
    inds[inds >= n_total] = n_total - 1  # clamp to the last valid frame
    batch = reader.get_batch(inds).asnumpy()  # [N, H, W, C]
    return [Image.fromarray(arr) for arr in batch]
|
||||
|
||||
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, csv_path, transform=None, points=(0.1, 0.5, 0.9)):
|
||||
self.csv_path = csv_path
|
||||
|
|
@ -57,7 +48,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
if not is_video(path):
|
||||
images = [pil_loader(path)]
|
||||
else:
|
||||
images = extract_frames(sample["path"], points=self.points)
|
||||
images = extract_frames_new(sample["path"], points=self.points, backend="opencv")
|
||||
images = [self.transform(img) for img in images]
|
||||
images = torch.stack(images)
|
||||
return dict(index=index, images=images)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from torch.utils.data import DataLoader, DistributedSampler
|
|||
from torchvision.datasets.folder import pil_loader
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.datasets.transform import extract_frames_new
|
||||
|
||||
IMG_EXTENSIONS = (
|
||||
".jpg",
|
||||
".jpeg",
|
||||
|
|
@ -32,21 +34,6 @@ def is_video(filename):
|
|||
return ext in VID_EXTENSIONS
|
||||
|
||||
|
||||
def extract_frames(video_path, points=(0.5,)):
    """Extract one PIL image per relative position in ``points`` via PyAV.

    Args:
        video_path (str): path to the video file.
        points (Sequence[float]): relative positions in [0, 1); each is
            multiplied by the total frame count to pick a target frame.
            (Default changed from a mutable list to an equivalent tuple.)

    Returns:
        List[PIL.Image]: one decoded image per point.
    """
    container = av.open(video_path)
    try:
        stream = container.streams.video[0]
        total_frames = stream.frames
        frames = []
        for point in points:
            target_frame = total_frames * point
            # Convert the (fractional) frame index to a timestamp in
            # av.time_base units for seeking.
            target_timestamp = int((target_frame * av.time_base) / stream.average_rate)
            container.seek(target_timestamp)
            frames.append(next(container.decode(video=0)).to_image())
    finally:
        container.close()  # the original leaked the container handle
    return frames
|
||||
|
||||
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, meta_path, transform):
|
||||
self.meta_path = meta_path
|
||||
|
|
@ -58,7 +45,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
path = row["path"]
|
||||
|
||||
if is_video(path):
|
||||
img = extract_frames(path, points=[0.5])[0]
|
||||
img = extract_frames_new(path, points=[0.5], backend='opencv')[0]
|
||||
else:
|
||||
img = pil_loader(path)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,147 +0,0 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
|
||||
from .unimatch import UniMatch
|
||||
|
||||
import decord # isort: skip
|
||||
|
||||
|
||||
def extract_frames_av(video_path, frame_inds=[0, 10, 20, 30]):
    """Extract the frames at ``frame_inds`` from a video via PyAV.

    Args:
        video_path (str): path to the video file.
        frame_inds (List[int]): frame indices to extract; indices past the
            end are clamped to the last frame.

    Returns:
        List[PIL.Image]: one image per requested index.
    """
    container = av.open(video_path)
    # NOTE(review): assumes the stream header reports a frame count; some
    # containers report 0 here — confirm inputs are well-formed files.
    total_frames = container.streams.video[0].frames
    frames = []
    for idx in frame_inds:
        if idx >= total_frames:
            # Clamp out-of-range requests to the last frame.
            idx = total_frames - 1
        # Convert the frame index to a timestamp in av.time_base units.
        target_timestamp = int(
            idx * av.time_base / container.streams.video[0].average_rate
        )
        # NOTE(review): seek() lands on a nearby keyframe, so the decoded
        # frame may not be exactly frame `idx` — confirm this is acceptable.
        container.seek(target_timestamp)
        frame = next(container.decode(video=0)).to_image()
        frames.append(frame)
    return frames
|
||||
|
||||
|
||||
def extract_frames(video_path, frame_inds=[0, 10, 20, 30]):
    """Decode the requested frames of a video as an [N, H, W, C] array."""
    reader = decord.VideoReader(video_path, num_threads=1)
    n_total = len(reader)
    inds = np.array(frame_inds).astype(np.int32)
    inds[inds >= n_total] = n_total - 1  # clamp out-of-range indices
    return reader.get_batch(inds).asnumpy()  # [N, H, W, C]
|
||||
|
||||
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset yielding fixed-size frame stacks for optical-flow scoring.

    Each item is a float tensor of shape [N, C, 320, 576], where N is the
    number of requested frame indices.
    """

    def __init__(self, meta_path, frame_inds=[0, 10, 20, 30]):
        # Path to the metadata CSV and the frame indices to decode per clip.
        self.meta_path = meta_path
        self.meta = pd.read_csv(meta_path)
        self.frame_inds = frame_inds

    def __getitem__(self, index):
        entry = self.meta.iloc[index]
        clip = extract_frames(entry["path"], frame_inds=self.frame_inds)

        # numpy [N, H, W, C] uint8 -> float tensor [N, C, H, W]
        clip = torch.from_numpy(clip).float().permute(0, 3, 1, 2)
        if clip.shape[-2] > clip.shape[-1]:
            # Portrait clips are transposed so every item is landscape.
            clip = clip.transpose(-2, -1)
        clip = F.interpolate(
            clip, size=(320, 576), mode="bilinear", align_corners=True
        )

        return clip

    def __len__(self):
        return len(self.meta)
|
||||
|
||||
|
||||
def main():
    """Score every video in a metadata CSV by mean optical-flow magnitude.

    Reads the CSV given on the command line, runs UniMatch flow estimation
    between consecutive sampled frames of each clip, and writes a copy of
    the CSV with a new ``flow`` column to ``<input>_flow.<ext>``.
    """
    # Deterministic cuDNN so repeated runs produce identical scores.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    parser = argparse.ArgumentParser()
    parser.add_argument("meta_path", type=str, help="Path to the input CSV file")
    parser.add_argument("--bs", type=int, default=4, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
    args = parser.parse_args()

    meta_path = args.meta_path
    # Output goes next to the input: e.g. meta.csv -> meta_flow.csv.
    wo_ext, ext = os.path.splitext(meta_path)
    out_path = f"{wo_ext}_flow{ext}"

    # build model
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # UniMatch configured for the flow task; arguments mirror the
    # gmflow-scale2-regrefine6 checkpoint loaded below.
    model = UniMatch(
        feature_channels=128,
        num_scales=2,
        upsample_factor=4,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        reg_refine=True,
        task="flow",
    ).eval()
    ckpt = torch.load(
        "./pretrained_models/unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth"
    )
    model.load_state_dict(ckpt["model"])
    model = model.to(device)
    # model = torch.nn.DataParallel(model)

    # build dataset
    dataset = VideoTextDataset(meta_path=meta_path, frame_inds=[0, 10, 20, 30])
    # shuffle=False keeps batch order aligned with meta row order, which the
    # positional `index` bookkeeping below relies on.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.bs,
        num_workers=args.num_workers,
        shuffle=False,
    )

    # compute optical flow scores
    dataset.meta["flow"] = np.nan
    index = 0
    for images in tqdm(dataloader):
        # images: [B, N, C, H, W]
        images = images.to(device)
        B = images.shape[0]

        # Pair consecutive frames: flow is estimated frame i -> frame i+1.
        batch_0 = rearrange(images[:, :-1], "B N C H W -> (B N) C H W").contiguous()
        batch_1 = rearrange(images[:, 1:], "B N C H W -> (B N) C H W").contiguous()

        with torch.no_grad():
            res = model(
                batch_0,
                batch_1,
                attn_type="swin",
                attn_splits_list=[2, 8],
                corr_radius_list=[-1, 4],
                prop_radius_list=[-1, 1],
                num_reg_refine=6,
                task="flow",
                pred_bidir_flow=False,
            )
        # Take the final (most refined) flow prediction.
        flow_maps = res["flow_preds"][-1].cpu()  # [B * (N-1), 2, H, W]
        flow_maps = rearrange(flow_maps, "(B N) C H W -> B N H W C", B=B)
        # Score = mean absolute flow over all pairs, pixels, and components.
        flow_scores = flow_maps.abs().mean(dim=[1, 2, 3, 4])
        flow_scores_np = flow_scores.numpy()

        # pandas .loc slicing is end-inclusive, hence index + B - 1.
        dataset.meta.loc[index : index + B - 1, "flow"] = flow_scores_np
        index += B

    dataset.meta.to_csv(out_path, index=False)
    print(f"New meta with optical flow scores saved to '{out_path}'.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -10,38 +10,11 @@ import torch.distributed as dist
|
|||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchvision.transforms.functional import pil_to_tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from .unimatch import UniMatch
|
||||
|
||||
import decord # isort: skip
|
||||
|
||||
|
||||
def extract_frames_av(video_path, frame_inds=[0, 10, 20, 30]):
    """Grab the frames at ``frame_inds`` from a video using PyAV."""
    container = av.open(video_path)
    stream = container.streams.video[0]
    n_total = stream.frames
    out = []
    for want in frame_inds:
        # Out-of-range requests fall back to the last frame.
        idx = n_total - 1 if want >= n_total else want
        # frame index -> timestamp in av.time_base units
        ts = int(idx * av.time_base / stream.average_rate)
        container.seek(ts)
        out.append(next(container.decode(video=0)).to_image())
    return out
|
||||
|
||||
|
||||
def extract_frames(video_path, frame_inds=[0, 10, 20, 30]):
    """Decode the frames at ``frame_inds`` with decord.

    Args:
        video_path (str): path to the video file.
        frame_inds (List[int]): frame indices; out-of-range indices are
            clamped to the last frame.

    Returns:
        An [N, H, W, C] numpy array of decoded frames.
    """
    container = decord.VideoReader(video_path, num_threads=1)
    total_frames = len(container)
    # avg_fps = container.get_avg_fps()

    frame_inds = np.array(frame_inds).astype(np.int32)
    # Clamp any index past the end of the video to the last frame.
    frame_inds[frame_inds >= total_frames] = total_frames - 1
    frames = container.get_batch(frame_inds).asnumpy()  # [N, H, W, C]
    return frames
|
||||
from tools.datasets.transform import extract_frames_new
|
||||
|
||||
|
||||
def merge_scores(gathered_list: list, meta: pd.DataFrame):
|
||||
|
|
@ -69,12 +42,11 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
row = self.meta.iloc[index]
|
||||
images = extract_frames(row["path"], frame_inds=self.frame_inds)
|
||||
# images = [pil_to_tensor(x) for x in images] # [C, H, W]
|
||||
images = extract_frames_new(row["path"], frame_inds=self.frame_inds, backend='opencv')
|
||||
|
||||
# transform
|
||||
images = torch.from_numpy(images).float()
|
||||
images = rearrange(images, "N H W C -> N C H W")
|
||||
images = torch.stack([pil_to_tensor(x) for x in images]) # shape: [N, C, H, W]; dtype: torch.uint8
|
||||
images = images.float()
|
||||
H, W = images.shape[-2:]
|
||||
if H > W:
|
||||
images = rearrange(images, "N C H W -> N C W H")
|
||||
|
|
|
|||
Loading…
Reference in a new issue