Mirror of https://github.com/hpcaitech/Open-Sora.git (synced 2026-04-15 20:36:58 +02:00)

Commit 921e3138b3 ("update docs"), parent 9adf0cafb1
.gitignore (vendored): 1 line changed
@@ -164,6 +164,7 @@ cython_debug/
 *.DS_Store
 
 # misc files
 data/
 dataset/
 runs/
+checkpoints/
README.md: 22 lines changed
@@ -111,30 +111,30 @@ After installation, we suggest reading [structure.md](docs/structure.md) to lear
 | 16×512×512 | 20K HQ | 20k | 2×64 | 35 | |
 | 64×512×512 | 50K HQ | | 4×64 | | |
 
-Our model's weight is partially initialized from [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha). The number of parameters is 724M. More information about training can be found in our **[report](/docs/report_v1.md)**. More about the dataset can be found in [dataset.md](/docs/dataset.md).
+Our model's weight is partially initialized from [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha). The number of parameters is 724M. More information about training can be found in our **[report](/docs/report_v1.md)**. More about the dataset can be found in [dataset.md](/docs/dataset.md). HQ means high quality.
 
-**LIMITATION**: Our model is trained on a limited budget. The quality and text alignment are relatively poor. The model performs badly, especially on generating human beings, and cannot follow detailed instructions. We are working on improving the quality and text alignment.
+:warning: **LIMITATION**: Our model is trained on a limited budget. The quality and text alignment are relatively poor. The model performs badly, especially on generating human beings, and cannot follow detailed instructions. We are working on improving the quality and text alignment.
 
 ## Inference
 
 To run inference with our provided weights, first download [T5](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) weights into `pretrained_models/t5_ckpts/t5-v1_1-xxl`. Then run the following commands to generate samples. See [here](docs/structure.md#inference-config-demos) to customize the configuration.
 
 ```bash
-# Sample 16x256x256 (may take less than 1 min)
+# Sample 16x256x256 (5s/sample)
 torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path ./path/to/your/ckpt.pth
 
-# Sample 16x512x512 (may take less than 1 min)
-torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py
+# Sample 16x512x512 (20s/sample, 100 time steps)
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py --ckpt-path ./path/to/your/ckpt.pth
 
-# Sample 64x512x512 (may take 1 min or more)
-torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py
+# Sample 64x512x512 (40s/sample, 100 time steps)
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py --ckpt-path ./path/to/your/ckpt.pth
 
-# Sample 64x512x512 with sequence parallelism (may take 1 min or more)
+# Sample 64x512x512 with sequence parallelism (30s/sample, 100 time steps)
 # sequence parallelism is enabled automatically when nproc_per_node is larger than 1
-torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py
+torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py --ckpt-path ./path/to/your/ckpt.pth
 ```
 
-For inference with other models, see [here](docs/commands.md) for more instructions.
+The speed is tested on H800 GPUs. For inference with other models, see [here](docs/commands.md) for more instructions.
 
 ## Data Processing (WIP)
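As a reading aid for the per-sample speeds quoted in the updated README, here is a naive back-of-envelope sketch (not repo code; it assumes linear scaling and ignores model-loading and startup time):

```python
# Per-sample generation times from the README above (H800, 100 time steps).
SECONDS_PER_SAMPLE = {"16x256x256": 5, "16x512x512": 20, "64x512x512": 40}

def estimated_total(config, n_samples):
    # Naive estimate: ignores model loading, assumes perfectly linear scaling.
    return SECONDS_PER_SAMPLE[config] * n_samples

print(estimated_total("64x512x512", 10))  # 400
```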

@@ -166,7 +166,7 @@ For training other models and advanced usage, see [here](docs/commands.md) for m
 ## Acknowledgement
 
 * [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers.
-* [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT): An acceleration for DiT training. OpenDiT's team provides valuable suggestions on acceleration of our training process.
+* [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT): An acceleration for DiT training. We adopt valuable acceleration strategies from OpenDiT in our training process.
 * [PixArt](https://github.com/PixArt-alpha/PixArt-alpha): An open-source DiT-based text-to-image model.
 * [Latte](https://github.com/Vchitect/Latte): An attempt to efficiently train DiT for video.
 * [StabilityAI VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original): A powerful image VAE model.
@@ -22,9 +22,9 @@ We also try to use a 3D patch embedder in DiT. However, with 2x downsampling on
 
 ## Data is the key to high quality
 
-We find that the number and quality of data have a great impact on the quality of generated videos, even larger than the model architecture and training strategy. At this time, we only prepared the first split (366K video clips) from [HD-VG-130M](https://github.com/daooshee/HD-VG-130M). The quality of these videos varies greatly, and the captions are not that accurate. Thus, we further collect 20k relatively high-quality videos from [Pexels](https://www.pexels.com/), which provides free-license videos. We label the videos with LLaVA, an image captioning model, using three frames and a designed prompt.
+We find that the number and quality of data have a great impact on the quality of generated videos, even larger than the model architecture and training strategy. At this time, we only prepared the first split (366K video clips) from [HD-VG-130M](https://github.com/daooshee/HD-VG-130M). The quality of these videos varies greatly, and the captions are not that accurate. Thus, we further collect 20k relatively high-quality videos from [Pexels](https://www.pexels.com/), which provides free-license videos. We label the videos with LLaVA, an image captioning model, using three frames and a designed prompt. With the designed prompt, LLaVA can generate good-quality captions.
 
-[figure]
+
 
 As we lay more emphasis on the quality of data, we plan to collect more data and build a video preprocessing pipeline in our next version.

@@ -58,7 +58,7 @@ class DatasetFromCSV(torch.utils.data.Dataset):
         self.samples = list(reader)
 
         ext = self.samples[0][0].split(".")[-1]
-        if ext.lower() in ["mp4", "avi", "mov", "mkv"]:
+        if ext.lower() in ("mp4", "avi", "mov", "mkv"):
             self.is_video = True
         else:
             assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
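The changed membership test can be exercised stand-alone. This is a hypothetical helper (not repo code) mirroring the logic above; swapping the list literal for a tuple does not change the behavior of `in`:

```python
# Hypothetical helper mirroring the extension check in the diff above.
VIDEO_EXTS = ("mp4", "avi", "mov", "mkv")

def is_video_file(filename: str) -> bool:
    # Take the text after the last dot and compare case-insensitively.
    ext = filename.split(".")[-1].lower()
    return ext in VIDEO_EXTS

print(is_video_file("clip.MP4"))   # True
print(is_video_file("frame.png"))  # False
```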
tools/caption/README.md (new file, 0 lines)
tools/caption/__init__.py (new file, 0 lines)
tools/caption/caption_gpt4.py (new file, 69 lines)

@@ -0,0 +1,69 @@
+import argparse
+import csv
+import os
+
+import requests
+import tqdm
+
+from .utils import extract_frames, prompts, read_video_list
+
+
+def get_caption(frame, prompt, api_key):
+    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+    payload = {
+        "model": "gpt-4-vision-preview",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": prompt,
+                    },
+                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[0]}"}},
+                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[1]}"}},
+                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[2]}"}},
+                ],
+            }
+        ],
+        "max_tokens": 300,
+    }
+    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60)
+    caption = response.json()["choices"][0]["message"]["content"]
+    caption = caption.replace("\n", " ")
+    return caption
+
+
+def main(args):
+    # ======================================================
+    # 1. read video list
+    # ======================================================
+    videos = read_video_list(args.video_folder, args.output_file)
+    f = open(args.output_file, "a")
+    writer = csv.writer(f)
+
+    # ======================================================
+    # 2. generate captions
+    # ======================================================
+    for video in tqdm.tqdm(videos):
+        video_path = os.path.join(args.video_folder, video)
+        frame, length = extract_frames(video_path, base_64=True)
+        if len(frame) < 3:
+            continue
+
+        prompt = prompts[args.prompt]
+        caption = get_caption(frame, prompt, args.key)
+
+        writer.writerow((video, caption, length))
+    f.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("video_folder", type=str)
+    parser.add_argument("output_file", type=str)
+    parser.add_argument("--prompt", type=str, default="three_frames")
+    parser.add_argument("--key", type=str)
+    args = parser.parse_args()
+
+    main(args)
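The request body that `get_caption()` above posts to the OpenAI chat-completions endpoint is one text part followed by three base64 frames. A generalized sketch (hypothetical helper, not repo code) that builds the same structure for any number of frames:

```python
# Sketch of the chat-completions payload built in get_caption() above:
# one text part plus one image part per base64-encoded JPEG frame.
def build_payload(frames, prompt, max_tokens=300):
    content = [{"type": "text", "text": prompt}]
    for frame in frames:
        # Each frame is embedded inline as a base64 data URL.
        content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}})
    return {
        "model": "gpt-4-vision-preview",
        "messages": [{"role": "user", "content": content}],
        "max_tokens": max_tokens,
    }

payload = build_payload(["AAAA", "BBBB", "CCCC"], "Describe the video")
print(len(payload["messages"][0]["content"]))  # 4
```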

@@ -1,8 +1,8 @@
 import argparse
 import csv
 import os
+import warnings
 
-import cv2
 import torch
 from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
 from llava.conversation import conv_templates
@@ -10,43 +10,12 @@ from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path
 from llava.model.builder import load_pretrained_model
 from llava.model.llava_arch import unpad_image
 from llava.utils import disable_torch_init
-from PIL import Image
 from tqdm import tqdm
 
+from .utils import extract_frames, prompts, read_video_list
+
 disable_torch_init()
 
-prompts = {
-    "three_frames": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be less than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
-}
-
-
-def get_filelist(file_path):
-    Filelist = []
-    for home, dirs, files in os.walk(file_path):
-        for filename in files:
-            Filelist.append(os.path.join(home, filename))
-    return Filelist
-
-
-def get_video_length(cap):
-    return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-
-def extract_frames(video_path, points=[0.2, 0.5, 0.8]):
-    cap = cv2.VideoCapture(video_path)
-    length = get_video_length(cap)
-    points = [int(length * point) for point in points]
-    frames = []
-    if length < 3:
-        return frames, length
-    for point in points:
-        cap.set(cv2.CAP_PROP_POS_FRAMES, point)
-        ret, frame = cap.read()
-        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        frame = Image.fromarray(frame)
-        frames.append(frame)
-    return frames, length
-
-
 def prepare_inputs_labels_for_multimodal(
     self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
@@ -274,43 +243,44 @@ def prepare_inputs_labels_for_multimodal(
 
 @torch.inference_mode()
 def main(args):
-    bs = args.bs
-    video_folder = args.video_folder
-
-    processed_videos = []
-    if os.path.exists(args.output_file):
-        with open(args.output_file, "r") as f:
-            reader = csv.reader(f)
-            samples = list(reader)
-            processed_videos = [sample[0] for sample in samples]
+    # ======================================================
+    # 1. read video list
+    # ======================================================
+    videos = read_video_list(args.video_folder, args.output_file)
+    f = open(args.output_file, "a")
+    writer = csv.writer(f)
 
+    # ======================================================
+    # 2. load model and prepare prompts
+    # ======================================================
     model_path = "liuhaotian/llava-v1.6-34b"
-    query = prompts["three_frames"]
+    query = prompts[args.prompt]
+    print(f"Prompt: {query}")
     conv = conv_templates["chatml_direct"].copy()
     conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query)
     prompt = conv.get_prompt()
 
-    tokenizer, model, image_processor, context_len = load_pretrained_model(
-        model_path=model_path,
-        model_base=None,
-        model_name=get_model_name_from_path(model_path),
-    )
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")  # Pytorch non-meta copying warning fills out the console
+        tokenizer, model, image_processor, context_len = load_pretrained_model(
+            model_path=model_path,
+            model_base=None,
+            model_name=get_model_name_from_path(model_path),
+        )
     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
     input_ids = input_ids.unsqueeze(0).to(model.device)
 
-    videos = get_filelist(video_folder)
-    print(f"Dataset contains {len(videos)} videos.")
-    videos = [video for video in videos if video not in processed_videos]
-    print(f"Processing {len(videos)} new videos.")
+    # ======================================================
+    # 3. generate captions
+    # ======================================================
+    bs = args.bs
     for i in tqdm(range(0, len(videos), bs)):
         # prepare a batch of inputs
         video_files = videos[i : i + bs]
         frames = []
         video_lengths = []
         for video_file in video_files:
-            frame, length = extract_frames(os.path.join(video_folder, video_file))
+            frame, length = extract_frames(os.path.join(args.video_folder, video_file))
             if len(frame) < 3:
                 continue
             frames.append(frame)
@@ -373,8 +343,10 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--video_folder", type=str, required=True)
+    parser.add_argument("video_folder", type=str)
+    parser.add_argument("output_file", type=str)
     parser.add_argument("--bs", type=int, default=32)
-    parser.add_argument("--output_file", type=str, default="video_captions.csv")
+    parser.add_argument("--prompt", type=str, default="three_frames")
     args = parser.parse_args()
 
     main(args)
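The rewritten `main()` above walks the video list in strides of `bs`. A minimal stand-alone sketch of that slicing (hypothetical helper, not repo code):

```python
# Batching as done by `for i in range(0, len(videos), bs): videos[i : i + bs]`.
def batches(items, bs):
    # The final batch may be shorter than bs.
    return [items[i : i + bs] for i in range(0, len(items), bs)]

print(batches(["v1.mp4", "v2.mp4", "v3.mp4", "v4.mp4", "v5.mp4"], 2))
```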
tools/caption/utils.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+import base64
+import csv
+import os
+
+import cv2
+from PIL import Image
+
+prompts = {
+    "naive": "Describe the video",
+    "three_frames": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be less than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
+}
+
+
+def get_filelist(file_path):
+    Filelist = []
+    VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
+    for home, dirs, files in os.walk(file_path):
+        for filename in files:
+            ext = filename.split(".")[-1]
+            if ext in VID_EXTENSIONS:
+                Filelist.append(filename)
+    return Filelist
+
+
+def get_video_length(cap):
+    return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+
+def encode_image(image_path):
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+def extract_frames(video_path, points=(0.2, 0.5, 0.8), base_64=False):
+    cap = cv2.VideoCapture(video_path)
+    length = get_video_length(cap)
+    points = [int(length * point) for point in points]
+    frames = []
+    if length < 3:
+        return frames, length
+    for point in points:
+        cap.set(cv2.CAP_PROP_POS_FRAMES, point)
+        ret, frame = cap.read()
+        if not base_64:
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+        else:
+            _, buffer = cv2.imencode(".jpg", frame)
+            frame = base64.b64encode(buffer).decode("utf-8")
+        frames.append(frame)
+    return frames, length
+
+
+def read_video_list(video_folder, output_file):
+    processed_videos = []
+    if os.path.exists(output_file):
+        with open(output_file, "r") as f:
+            reader = csv.reader(f)
+            samples = list(reader)
+            processed_videos = [sample[0] for sample in samples]
+
+    # read video list
+    videos = get_filelist(video_folder)
+    print(f"Dataset contains {len(videos)} videos.")
+    videos = [video for video in videos if video not in processed_videos]
+    print(f"Processing {len(videos)} new videos.")
+    return videos
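`extract_frames()` above samples at fixed relative positions (20%, 50%, 80% of the clip). The index computation in isolation (hypothetical helper, not repo code):

```python
# Frame indices at fixed relative positions, as in extract_frames() above.
def sample_points(length, points=(0.2, 0.5, 0.8)):
    # int() truncates, so indices always stay within [0, length - 1].
    return [int(length * p) for p in points]

print(sample_points(100))  # [20, 50, 80]
print(sample_points(30))   # [6, 15, 24]
```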

@@ -1,82 +0,0 @@
-import argparse
-import base64
-import csv
-import os
-
-import cv2
-import requests
-import tqdm
-
-# OpenAI API Key
-api_key = ""
-
-
-def get_caption(frame):
-    prompt = "The middle frame from a video clip are given. Describe this video and its style to generate a description for the video. The description should be useful for AI to re-generate the video. Here are some examples of good descriptions:\n\n 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.\n2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.\n 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway."
-
-    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
-    payload = {
-        "model": "gpt-4-vision-preview",
-        "messages": [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": prompt,
-                    },
-                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}},
-                ],
-            }
-        ],
-        "max_tokens": 300,
-    }
-    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60)
-    caption = response.json()["choices"][0]["message"]["content"]
-    return caption
-
-
-# Function to encode the image
-def encode_image(image_path):
-    with open(image_path, "rb") as image_file:
-        return base64.b64encode(image_file.read()).decode("utf-8")
-
-
-def extract_frames(video_path):
-    cap = cv2.VideoCapture(video_path)
-    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    point = length // 2
-    cap.set(cv2.CAP_PROP_POS_FRAMES, point)
-    ret, frame = cap.read()
-    _, buffer = cv2.imencode(".jpg", frame)
-    img_base64 = base64.b64encode(buffer).decode("utf-8")
-    return img_base64
-
-
-def main(args):
-    processed_videos = []
-    if os.path.exists(args.output_file):
-        with open(args.output_file, "r") as f:
-            reader = csv.reader(f)
-            samples = list(reader)
-            processed_videos = [sample[0] for sample in samples]
-
-    f = open(args.output_file, "a")
-    writer = csv.writer(f)
-    for video in tqdm.tqdm(os.listdir(args.video_folder)):
-        if video in processed_videos:
-            continue
-        video_path = os.path.join(args.video_folder, video)
-        base64_image = extract_frames(video_path)
-        caption = get_caption(base64_image)
-        caption = caption.replace("\n", " ")
-        writer.writerow([video, caption])
-    f.close()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--video_folder", type=str, required=True, help="Path to the folder containing the videos.")
-    parser.add_argument("--output_file", type=str, default="video_captions.csv", help="Path to the output file.")
-    args = parser.parse_args()
-    main(args)
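Both the deleted script above and the new `read_video_list()` resume interrupted runs the same way: videos already written to the output CSV are skipped. A stand-alone sketch of that filtering (hypothetical helper, not repo code; it reads CSV text from a string instead of a file):

```python
import csv
import io

# Skip videos whose names already appear in column 0 of the output CSV.
def filter_processed(videos, csv_text):
    processed = {row[0] for row in csv.reader(io.StringIO(csv_text)) if row}
    return [v for v in videos if v not in processed]

done = "a.mp4,caption one\nb.mp4,caption two\n"
print(filter_processed(["a.mp4", "b.mp4", "c.mp4"], done))  # ['c.mp4']
```

Because the output file is then opened in append mode, rerunning the script only adds rows for the remaining videos.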
tools/datasets/README.md (new file, 0 lines)
tools/intepolate/README.md (new file, 0 lines)