update docs

This commit is contained in:
Zangwei Zheng 2024-03-17 15:47:48 +08:00
parent 9adf0cafb1
commit 921e3138b3
19 changed files with 179 additions and 152 deletions

.gitignore vendored

@@ -164,6 +164,7 @@ cython_debug/
*.DS_Store
# misc files
data/
dataset/
runs/
checkpoints/


@@ -111,30 +111,30 @@ After installation, we suggest reading [structure.md](docs/structure.md) to lear
| 16×512×512 | 20K HQ | 20k | 2×64 | 35 | |
| 64×512×512 | 50K HQ | | 4×64 | | |
Our model's weights are partially initialized from [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha). The number of parameters is 724M. More information about training can be found in our **[report](/docs/report_v1.md)**. More about the dataset can be found in [dataset.md](/docs/dataset.md).
Our model's weights are partially initialized from [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha). The number of parameters is 724M. More information about training can be found in our **[report](/docs/report_v1.md)**. More about the dataset can be found in [dataset.md](/docs/dataset.md). HQ means high quality.
**LIMITATION**: Our model is trained on a limited budget. The quality and text alignment are relatively poor. The model performs poorly, especially when generating humans, and cannot follow detailed instructions. We are working on improving the quality and text alignment.
:warning: **LIMITATION**: Our model is trained on a limited budget. The quality and text alignment are relatively poor. The model performs poorly, especially when generating humans, and cannot follow detailed instructions. We are working on improving the quality and text alignment.
## Inference
To run inference with our provided weights, first download [T5](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) weights into `pretrained_models/t5_ckpts/t5-v1_1-xxl`. Then run the following commands to generate samples. See [here](docs/structure.md#inference-config-demos) to customize the configuration.
```bash
# Sample 16x256x256 (may take less than 1 min)
# Sample 16x256x256 (5s/sample)
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path ./path/to/your/ckpt.pth
# Sample 16x512x512 (may take less than 1 min)
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py
# Sample 16x512x512 (20s/sample, 100 time steps)
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py --ckpt-path ./path/to/your/ckpt.pth
# Sample 64x512x512 (may take 1 min or more)
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py
# Sample 64x512x512 (40s/sample, 100 time steps)
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py --ckpt-path ./path/to/your/ckpt.pth
# Sample 64x512x512 with sequence parallelism (may take 1 min or more)
# Sample 64x512x512 with sequence parallelism (30s/sample, 100 time steps)
# sequence parallelism is enabled automatically when nproc_per_node is larger than 1
torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py
torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py --ckpt-path ./path/to/your/ckpt.pth
```
For inference with other models, see [here](docs/commands.md) for more instructions.
The speed is tested on H800 GPUs. For inference with other models, see [here](docs/commands.md) for more instructions.
## Data Processing (WIP)
@@ -166,7 +166,7 @@ For training other models and advanced usage, see [here](docs/commands.md) for m
## Acknowledgement
* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers.
* [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT): An acceleration for DiT training. OpenDiT's team provides valuable suggestions on acceleration of our training process.
* [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT): An acceleration for DiT training. We adopt valuable acceleration strategies from OpenDiT for our training process.
* [PixArt](https://github.com/PixArt-alpha/PixArt-alpha): An open-source DiT-based text-to-image model.
* [Latte](https://github.com/Vchitect/Latte): An attempt to efficiently train DiT for video.
* [StabilityAI VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original): A powerful image VAE model.


@@ -22,9 +22,9 @@ We also try to use a 3D patch embedder in DiT. However, with 2x downsampling on
## Data is the key to high quality
We find that the number and quality of data have a great impact on the quality of generated videos, even larger than the model architecture and training strategy. At this time, we only prepared the first split (366K video clips) from [HD-VG-130M](https://github.com/daooshee/HD-VG-130M). The quality of these videos varies greatly, and the captions are not that accurate. Thus, we further collect 20k relatively high-quality videos from [Pexels](https://www.pexels.com/), which provides free-license videos. We label the videos with LLaVA, an image captioning model, using three frames and a designed prompt.
We find that the number and quality of data have a great impact on the quality of generated videos, even larger than the model architecture and training strategy. At this time, we only prepared the first split (366K video clips) from [HD-VG-130M](https://github.com/daooshee/HD-VG-130M). The quality of these videos varies greatly, and the captions are not that accurate. Thus, we further collect 20k relatively high-quality videos from [Pexels](https://www.pexels.com/), which provides free-license videos. We label the videos with LLaVA, an image captioning model, using three frames and a designed prompt. With a designed prompt, LLaVA can generate good-quality captions.
[figure]
![Caption](https://i0.imgs.ovh/2024/03/16/eXdvC.png)
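The three-frame sampling described above can be sketched as follows. This is a minimal sketch: `sample_frame_indices` is a hypothetical helper, while the 20%/50%/80% positions and the "fewer than 3 frames" skip rule follow the `extract_frames` utility elsewhere in this commit.

```python
def sample_frame_indices(num_frames, points=(0.2, 0.5, 0.8)):
    # Clips shorter than 3 frames are skipped by the captioning scripts.
    if num_frames < 3:
        return []
    # Pick one frame index at each fractional position of the clip.
    return [int(num_frames * p) for p in points]
```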
As we lay more emphasis on the quality of data, we prepare to collect more data and build a video preprocessing pipeline in our next version.


@@ -58,7 +58,7 @@ class DatasetFromCSV(torch.utils.data.Dataset):
self.samples = list(reader)
ext = self.samples[0][0].split(".")[-1]
if ext.lower() in ["mp4", "avi", "mov", "mkv"]:
if ext.lower() in ("mp4", "avi", "mov", "mkv"):
self.is_video = True
else:
assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"

tools/caption/README.md Normal file

@@ -0,0 +1,69 @@
import argparse
import csv
import os
import requests
import tqdm
from .utils import extract_frames, prompts, read_video_list
def get_caption(frame, prompt, api_key):
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
payload = {
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt,
},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[0]}"}},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[1]}"}},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[2]}"}},
],
}
],
"max_tokens": 300,
}
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60)
caption = response.json()["choices"][0]["message"]["content"]
caption = caption.replace("\n", " ")
return caption
def main(args):
# ======================================================
# 1. read video list
# ======================================================
videos = read_video_list(args.video_folder, args.output_file)
f = open(args.output_file, "a")
writer = csv.writer(f)
# ======================================================
# 2. generate captions
# ======================================================
for video in tqdm.tqdm(videos):
video_path = os.path.join(args.video_folder, video)
frame, length = extract_frames(video_path, base_64=True)
if len(frame) < 3:
continue
prompt = prompts[args.prompt]
caption = get_caption(frame, prompt, args.key)
writer.writerow((video, caption, length))
f.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("video_folder", type=str)
parser.add_argument("output_file", type=str)
parser.add_argument("--prompt", type=str, default="three_frames")
parser.add_argument("--key", type=str)
args = parser.parse_args()
main(args)
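The `image_url` entries in the GPT-4V payload above embed each frame as a base64 data URL. A minimal sketch of that construction, with `to_data_url` as a hypothetical helper not part of this commit:

```python
import base64

def to_data_url(jpeg_bytes):
    # Encode raw JPEG bytes as a data URL, matching the
    # "data:image/jpeg;base64,..." strings in the request payload.
    b64 = base64.b64encode(jpeg_bytes).decode("utf-8")
    return f"data:image/jpeg;base64,{b64}"
```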


@@ -1,8 +1,8 @@
import argparse
import csv
import os
import warnings
import cv2
import torch
from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
@@ -10,43 +10,12 @@ from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava.model.llava_arch import unpad_image
from llava.utils import disable_torch_init
from PIL import Image
from tqdm import tqdm
from .utils import extract_frames, prompts, read_video_list
disable_torch_init()
prompts = {
"three_frames": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be less than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliffs edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
}
def get_filelist(file_path):
Filelist = []
for home, dirs, files in os.walk(file_path):
for filename in files:
Filelist.append(os.path.join(home, filename))
return Filelist
def get_video_length(cap):
return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
def extract_frames(video_path, points=[0.2, 0.5, 0.8]):
cap = cv2.VideoCapture(video_path)
length = get_video_length(cap)
points = [int(length * point) for point in points]
frames = []
if length < 3:
return frames, length
for point in points:
cap.set(cv2.CAP_PROP_POS_FRAMES, point)
ret, frame = cap.read()
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
frames.append(frame)
return frames, length
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
@@ -274,43 +243,44 @@ def prepare_inputs_labels_for_multimodal
@torch.inference_mode()
def main(args):
bs = args.bs
video_folder = args.video_folder
processed_videos = []
if os.path.exists(args.output_file):
with open(args.output_file, "r") as f:
reader = csv.reader(f)
samples = list(reader)
processed_videos = [sample[0] for sample in samples]
# ======================================================
# 1. read video list
# ======================================================
videos = read_video_list(args.video_folder, args.output_file)
f = open(args.output_file, "a")
writer = csv.writer(f)
# ======================================================
# 2. load model and prepare prompts
# ======================================================
model_path = "liuhaotian/llava-v1.6-34b"
query = prompts["three_frames"]
query = prompts[args.prompt]
print(f"Prompt: {query}")
conv = conv_templates["chatml_direct"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query)
prompt = conv.get_prompt()
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path),
)
with warnings.catch_warnings():
warnings.simplefilter("ignore") # Pytorch non-meta copying warning fills out the console
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path),
)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).to(model.device)
videos = get_filelist(video_folder)
print(f"Dataset contains {len(videos)} videos.")
videos = [video for video in videos if video not in processed_videos]
print(f"Processing {len(videos)} new videos.")
# ======================================================
# 3. generate captions
# ======================================================
bs = args.bs
for i in tqdm(range(0, len(videos), bs)):
# prepare a batch of inputs
video_files = videos[i : i + bs]
frames = []
video_lengths = []
for video_file in video_files:
frame, length = extract_frames(os.path.join(video_folder, video_file))
frame, length = extract_frames(os.path.join(args.video_folder, video_file))
if len(frame) < 3:
continue
frames.append(frame)
@@ -373,8 +343,10 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--video_folder", type=str, required=True)
parser.add_argument("video_folder", type=str)
parser.add_argument("output_file", type=str)
parser.add_argument("--bs", type=int, default=32)
parser.add_argument("--output_file", type=str, default="video_captions.csv")
parser.add_argument("--prompt", type=str, default="three_frames")
args = parser.parse_args()
main(args)

tools/caption/utils.py Normal file

@@ -0,0 +1,67 @@
import base64
import csv
import os
import cv2
from PIL import Image
prompts = {
"naive": "Describe the video",
"three_frames": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be less than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliffs edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
}
def get_filelist(file_path):
Filelist = []
VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
for home, dirs, files in os.walk(file_path):
for filename in files:
ext = filename.split(".")[-1]
if ext in VID_EXTENSIONS:
Filelist.append(filename)
return Filelist
def get_video_length(cap):
return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def extract_frames(video_path, points=(0.2, 0.5, 0.8), base_64=False):
cap = cv2.VideoCapture(video_path)
length = get_video_length(cap)
points = [int(length * point) for point in points]
frames = []
if length < 3:
return frames, length
for point in points:
cap.set(cv2.CAP_PROP_POS_FRAMES, point)
ret, frame = cap.read()
if not base_64:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
else:
_, buffer = cv2.imencode(".jpg", frame)
frame = base64.b64encode(buffer).decode("utf-8")
frames.append(frame)
return frames, length
def read_video_list(video_folder, output_file):
processed_videos = []
if os.path.exists(output_file):
with open(output_file, "r") as f:
reader = csv.reader(f)
samples = list(reader)
processed_videos = [sample[0] for sample in samples]
# read video list
videos = get_filelist(video_folder)
print(f"Dataset contains {len(videos)} videos.")
videos = [video for video in videos if video not in processed_videos]
print(f"Processing {len(videos)} new videos.")
return videos
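The resume behavior of `read_video_list` (skip any video already recorded in the output CSV, so a captioning run can be restarted safely) can be sketched in isolation. `filter_processed` below is a hypothetical helper, not part of this commit:

```python
def filter_processed(videos, processed):
    # Keep only videos without a caption yet, preserving the scan order.
    done = set(processed)
    return [v for v in videos if v not in done]
```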


@@ -1,82 +0,0 @@
import argparse
import base64
import csv
import os
import cv2
import requests
import tqdm
# OpenAI API Key
api_key = ""
def get_caption(frame):
prompt = "The middle frame from a video clip are given. Describe this video and its style to generate a description for the video. The description should be useful for AI to re-generate the video. Here are some examples of good descriptions:\n\n 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.\n2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.\n 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliffs edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway."
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
payload = {
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt,
},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}},
],
}
],
"max_tokens": 300,
}
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60)
caption = response.json()["choices"][0]["message"]["content"]
return caption
# Function to encode the image
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def extract_frames(video_path):
cap = cv2.VideoCapture(video_path)
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
point = length // 2
cap.set(cv2.CAP_PROP_POS_FRAMES, point)
ret, frame = cap.read()
_, buffer = cv2.imencode(".jpg", frame)
img_base64 = base64.b64encode(buffer).decode("utf-8")
return img_base64
def main(args):
processed_videos = []
if os.path.exists(args.output_file):
with open(args.output_file, "r") as f:
reader = csv.reader(f)
samples = list(reader)
processed_videos = [sample[0] for sample in samples]
f = open(args.output_file, "a")
writer = csv.writer(f)
for video in tqdm.tqdm(os.listdir(args.video_folder)):
if video in processed_videos:
continue
video_path = os.path.join(args.video_folder, video)
base64_image = extract_frames(video_path)
caption = get_caption(base64_image)
caption = caption.replace("\n", " ")
writer.writerow([video, caption])
f.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--video_folder", type=str, required=True, help="Path to the folder containing the videos.")
parser.add_argument("--output_file", type=str, default="video_captions.csv", help="Path to the output file.")
args = parser.parse_args()
main(args)

tools/datasets/README.md Normal file