mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
Merge branch 'dev/v1.0.1' of github.com:hpcaitech/Open-Sora-dev into dev/Rflow
This commit is contained in:
commit
62d96a2ced
|
|
@ -22,7 +22,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = False
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
|
|||
70
configs/opensora-v1-1/train/stage1.py
Normal file
70
configs/opensora-v1-1/train/stage1.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VariableVideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=None,
|
||||
frame_interval=3,
|
||||
image_size=(None, None),
|
||||
transform_name="resize_crop",
|
||||
)
|
||||
bucket_config = { # 6s/it
|
||||
"240p": {16: (1.0, 17), 32: (1.0, 9), 64: (1.0, 4), 128: (1.0, 2)},
|
||||
"256": {1: (1.0, 254)},
|
||||
"512": {1: (0.5, 86)},
|
||||
"480p": {1: (0.4, 54), 16: (0.4, 4), 32: (0.0, None)},
|
||||
"720p": {16: (0.1, 2), 32: (0.0, None)}, # No examples now
|
||||
"1024": {1: (0.3, 20)},
|
||||
"1080p": {1: (0.4, 8)},
|
||||
}
|
||||
mask_ratios = {
|
||||
"mask_no": 0.9,
|
||||
"mask_random": 0.06,
|
||||
"mask_head": 0.01,
|
||||
"mask_tail": 0.01,
|
||||
"mask_head_tail": 0.02,
|
||||
}
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="STDiT2-XL/2",
|
||||
from_pretrained=None,
|
||||
input_sq_size=512, # pretrained model is trained on 512x512
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
micro_batch_size=4,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="DeepFloyd/t5-v1_1-xxl",
|
||||
model_max_length=200,
|
||||
shardformer=True,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="iddpm-speed",
|
||||
timestep_respacing="",
|
||||
)
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 500
|
||||
load = None
|
||||
|
||||
batch_size = None
|
||||
lr = 2e-5
|
||||
grad_clip = 1.0
|
||||
|
|
@ -28,7 +28,7 @@ scheduler = dict(
|
|||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Condition
|
||||
prompt_path = None
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ scheduler = dict(
|
|||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Condition
|
||||
prompt_path = None
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ scheduler = dict(
|
|||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Condition
|
||||
prompt_path = None
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ scheduler = dict(
|
|||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Condition
|
||||
prompt_path = None
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ scheduler = dict(
|
|||
cfg_scale=7.0,
|
||||
cfg_channel=3, # or None
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Condition
|
||||
prompt_path = "./assets/texts/t2v_samples.txt"
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ scheduler = dict(
|
|||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ scheduler = dict(
|
|||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
@ -23,7 +23,13 @@ model = dict(
|
|||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
mask_ratios = [0.7, 0.15, 0.05, 0.05, 0.05]
|
||||
mask_ratios = {
|
||||
"mask_no": 0.7,
|
||||
"mask_random": 0.15,
|
||||
"mask_head": 0.05,
|
||||
"mask_tail": 0.05,
|
||||
"mask_head_tail": 0.05,
|
||||
}
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
|
|
|
|||
|
|
@ -23,7 +23,13 @@ model = dict(
|
|||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
|
||||
mask_ratios = {
|
||||
"mask_no": 0.5,
|
||||
"mask_random": 0.29,
|
||||
"mask_head": 0.07,
|
||||
"mask_tail": 0.07,
|
||||
"mask_head_tail": 0.07,
|
||||
}
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
|
|||
|
|
@ -1,55 +0,0 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (512, 512)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
|
||||
# Define acceleration
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="STDiT-XL/2",
|
||||
space_scale=1.0,
|
||||
time_scale=1.0,
|
||||
from_pretrained=None,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
mask_ratios = [0.7, 0.15, 0.05, 0.05, 0.05]
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
micro_batch_size=128,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="DeepFloyd/t5-v1_1-xxl",
|
||||
model_max_length=120,
|
||||
shardformer=True,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="iddpm",
|
||||
timestep_respacing="",
|
||||
)
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 1000
|
||||
load = None
|
||||
|
||||
batch_size = 8
|
||||
lr = 2e-5
|
||||
grad_clip = 1.0
|
||||
|
|
@ -1,16 +1,16 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (512, 512)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(512, 512),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = False
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
num_frames = 360
|
||||
frame_interval = 1
|
||||
image_size = (512, 512)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=360,
|
||||
frame_interval=3,
|
||||
image_size=(512, 512),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
# Define acceleration
|
||||
dtype = "bf16"
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
num_frames = 64
|
||||
frame_interval = 2
|
||||
image_size = (512, 512)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(512, 512),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2-seq"
|
||||
plugin = "zero2"
|
||||
sp_size = 2
|
||||
|
||||
# Define model
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
num_frames = 64
|
||||
frame_interval = 2
|
||||
image_size = (512, 512)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=64,
|
||||
frame_interval=3,
|
||||
image_size=(512, 512),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ scheduler = dict(
|
|||
num_sampling_steps=20,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
dtype = "bf16"
|
||||
|
||||
# prompt_path = "./assets/texts/t2i_samples.txt"
|
||||
prompt = [
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = False
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
num_frames = 1
|
||||
frame_interval = 1
|
||||
image_size = (512, 512)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = True
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=1,
|
||||
frame_interval=3,
|
||||
image_size=(512, 512),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
num_frames = 64
|
||||
frame_interval = 2
|
||||
image_size = (512, 512)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=64,
|
||||
frame_interval=3,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="PixArt-XL/2",
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
import argparse
|
||||
import json
|
||||
|
||||
import pandas as pd
|
||||
from pandarallel import pandarallel
|
||||
|
||||
|
||||
def process_csv(original_csv_path, original_video_format="mp4"):
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
|
||||
# Construct the new CSV file name by adding '_caption' before the '.csv' extension
|
||||
caption_file = original_csv_path.replace(".csv", "_caption.csv")
|
||||
df = pd.read_csv(original_csv_path)
|
||||
|
||||
# Add a new column for captions initialized with 'None'
|
||||
df["caption"] = "None"
|
||||
|
||||
def process_row(row):
|
||||
path = row["path"]
|
||||
json_path = path.replace(original_video_format, "json")
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
row["caption"] = json_data["caption"]
|
||||
return row
|
||||
|
||||
# Iterate over each row to replace video format with json in the path, and extract captions
|
||||
df = df.parallel_apply(process_row, axis=1)
|
||||
|
||||
# Save the modified DataFrame to a new CSV file
|
||||
df.to_csv(caption_file, index=False)
|
||||
print(f"New CSV file with captions is saved as {caption_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Setup argument parser
|
||||
parser = argparse.ArgumentParser(description="Process a CSV file to add video captions.")
|
||||
parser.add_argument("csv_path", type=str, help="The path to the original CSV file.")
|
||||
parser.add_argument("video_format", nargs="?", default="mp4", help="The original video format (default: mp4).")
|
||||
|
||||
# Parse arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Call the function with the provided arguments
|
||||
process_csv(args.csv_path, args.video_format)
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -21,7 +21,10 @@
|
|||
# T5: https://github.com/google-research/text-to-text-transfer-transformer
|
||||
# --------------------------------------------------------
|
||||
|
||||
import html
|
||||
import re
|
||||
|
||||
import ftfy
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
|
|
@ -96,12 +99,8 @@ class T5Embedder:
|
|||
self.hf_token = hf_token
|
||||
|
||||
assert from_pretrained in self.available_models
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
from_pretrained, cache_dir=cache_dir
|
||||
)
|
||||
self.model = T5EncoderModel.from_pretrained(
|
||||
from_pretrained, cache_dir=cache_dir, **t5_model_kwargs
|
||||
).eval()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(from_pretrained, cache_dir=cache_dir)
|
||||
self.model = T5EncoderModel.from_pretrained(from_pretrained, cache_dir=cache_dir, **t5_model_kwargs).eval()
|
||||
self.model_max_length = model_max_length
|
||||
|
||||
def get_text_embeddings(self, texts):
|
||||
|
|
@ -185,3 +184,142 @@ class T5Encoder:
|
|||
def null(self, n):
|
||||
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
|
||||
return null_y
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
BAD_PUNCT_REGEX = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
|
||||
def clean_caption(caption):
|
||||
import urllib.parse as ul
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = basic_clean(caption)
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
|
||||
def text_preprocessing(text, use_text_preprocessing: bool = True):
|
||||
if use_text_preprocessing:
|
||||
# The exact text cleaning as was in the training stage:
|
||||
text = clean_caption(text)
|
||||
text = clean_caption(text)
|
||||
return text
|
||||
else:
|
||||
return text.lower().strip()
|
||||
|
|
|
|||
|
|
@ -169,11 +169,12 @@ def save(
|
|||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
sampler_start_idx = step * batch_size if batch_size is not None else None
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"global_step": global_step,
|
||||
"sample_start_index": step * batch_size,
|
||||
"sample_start_index": sampler_start_idx,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from mmengine.runner import set_random_seed
|
|||
from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.datasets.utils import read_from_path
|
||||
from opensora.models.text_encoder.t5 import text_preprocessing
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
|
|
@ -65,11 +66,13 @@ def process_prompts(prompts, num_loop):
|
|||
for i in range(0, len(prompt_list), 2):
|
||||
start_loop = int(prompt_list[i])
|
||||
text = prompt_list[i + 1]
|
||||
text = text_preprocessing(text)
|
||||
end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop
|
||||
text_list.extend([text] * (end_loop - start_loop))
|
||||
assert len(text_list) == num_loop
|
||||
ret_prompts.append(text_list)
|
||||
else:
|
||||
prompt = text_preprocessing(prompt)
|
||||
ret_prompts.append([prompt] * num_loop)
|
||||
return ret_prompts
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from mmengine.runner import set_random_seed
|
|||
|
||||
from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.models.text_encoder.t5 import text_preprocessing
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
|
|
@ -97,6 +98,7 @@ def main():
|
|||
for i in range(0, len(prompts), cfg.batch_size):
|
||||
# 4.2 sample in hidden space
|
||||
batch_prompts = prompts[i : i + cfg.batch_size]
|
||||
batch_prompts = [text_preprocessing(prompt) for prompt in batch_prompts]
|
||||
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
|
||||
|
||||
# 4.3. diffusion sampling
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from glob import glob
|
|||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from .utils import IMG_EXTENSIONS
|
||||
|
|
@ -31,6 +32,8 @@ def apply(df, func, **kwargs):
|
|||
return df.progress_apply(func, **kwargs)
|
||||
|
||||
|
||||
TRAIN_COLUMNS = ["path", "text", "num_frames", "fps", "height", "width", "aspect_ratio", "resolution", "text_len"]
|
||||
|
||||
# ======================================================
|
||||
# --info
|
||||
# ======================================================
|
||||
|
|
@ -71,14 +74,15 @@ def get_info(path):
|
|||
|
||||
|
||||
def get_video_info(path):
|
||||
import torchvision
|
||||
|
||||
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
num_frames, height, width = vframes.shape[0], vframes.shape[1], vframes.shape[2]
|
||||
aspect_ratio = height / width
|
||||
fps = np.nan
|
||||
resolution = height * width
|
||||
return num_frames, height, width, aspect_ratio, fps, resolution
|
||||
try:
|
||||
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
|
||||
aspect_ratio = height / width
|
||||
fps = np.nan
|
||||
resolution = height * width
|
||||
return num_frames, height, width, aspect_ratio, fps, resolution
|
||||
except:
|
||||
return 0, 0, 0, np.nan, np.nan, np.nan
|
||||
|
||||
|
||||
# ======================================================
|
||||
|
|
@ -134,13 +138,13 @@ CMOTION_TEXT = {
|
|||
}
|
||||
CMOTION_PROBS = {
|
||||
# hard-coded probabilities
|
||||
"static": 0.1,
|
||||
"dynamic": 0.25,
|
||||
"static": 1.0,
|
||||
"dynamic": 1.0,
|
||||
"unknown": 0.0,
|
||||
"zoom in": 0.8,
|
||||
"zoom in": 1.0,
|
||||
"zoom out": 1.0,
|
||||
"pan left": 1.0,
|
||||
"pan right": 0.1,
|
||||
"pan right": 1.0,
|
||||
"tilt up": 1.0,
|
||||
"tilt down": 1.0,
|
||||
"pan/tilt": 1.0,
|
||||
|
|
@ -356,6 +360,9 @@ def read_file(input_path):
|
|||
|
||||
|
||||
def save_file(data, output_path):
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if not os.path.exists(output_dir) and output_dir != "":
|
||||
os.makedirs(output_dir)
|
||||
if output_path.endswith(".csv"):
|
||||
return data.to_csv(output_path, index=False)
|
||||
elif output_path.endswith(".parquet"):
|
||||
|
|
@ -410,6 +417,12 @@ def main(args):
|
|||
data = pd.merge(data, data_new[cols_to_use], on="path", how="inner")
|
||||
print(f"Intersection number of samples: {len(data)}.")
|
||||
|
||||
# train columns
|
||||
if args.train_column:
|
||||
all_columns = data.columns
|
||||
columns_to_drop = all_columns.difference(TRAIN_COLUMNS)
|
||||
data = data.drop(columns=columns_to_drop)
|
||||
|
||||
# get output path
|
||||
output_path = get_output_path(args, input_name)
|
||||
|
||||
|
|
@ -469,20 +482,18 @@ def main(args):
|
|||
data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath))
|
||||
if args.abspath is not None:
|
||||
data["path"] = apply(data["path"], lambda x: os.path.join(args.abspath, x))
|
||||
if args.merge_cmotion:
|
||||
data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1)
|
||||
if args.refine_llm_caption:
|
||||
assert "text" in data.columns
|
||||
data["text"] = apply(data["text"], remove_caption_prefix)
|
||||
if args.unescape:
|
||||
assert "text" in data.columns
|
||||
data["text"] = apply(data["text"], html.unescape)
|
||||
if args.clean_caption:
|
||||
assert "text" in data.columns
|
||||
data["text"] = apply(
|
||||
data["text"],
|
||||
partial(text_preprocessing, use_text_preprocessing=True),
|
||||
)
|
||||
if args.merge_cmotion:
|
||||
data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1)
|
||||
|
||||
if args.count_num_token is not None:
|
||||
assert "text" in data.columns
|
||||
data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"]))
|
||||
|
|
@ -505,7 +516,11 @@ def main(args):
|
|||
assert "num_frames" in data.columns
|
||||
data = data[data["num_frames"] <= args.fmax]
|
||||
if args.hwmax is not None:
|
||||
assert "resolution" in data.columns
|
||||
if "resolution" not in data.columns:
|
||||
height = data["height"]
|
||||
width = data["width"]
|
||||
data["resolution"] = height * width
|
||||
breakpoint()
|
||||
data = data[data["resolution"] <= args.hwmax]
|
||||
if args.aesmin is not None:
|
||||
assert "aes" in data.columns
|
||||
|
|
@ -549,6 +564,7 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns"
|
||||
)
|
||||
parser.add_argument("--train-column", action="store_true", help="only keep the train column")
|
||||
|
||||
# IO-related
|
||||
parser.add_argument("--info", action="store_true", help="get the basic information of each video and image")
|
||||
|
|
@ -578,7 +594,6 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--clean-caption", action="store_true", help="modify the caption according to T5 pipeline to suit training"
|
||||
)
|
||||
parser.add_argument("--unescape", action="store_true", help="unescape the caption")
|
||||
parser.add_argument("--merge-cmotion", action="store_true", help="merge the camera motion to the caption")
|
||||
parser.add_argument(
|
||||
"--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption"
|
||||
|
|
@ -641,8 +656,6 @@ def get_output_path(args, input_name):
|
|||
# caption processing
|
||||
if args.refine_llm_caption:
|
||||
name += "_llm"
|
||||
if args.unescape:
|
||||
name += "_unescape"
|
||||
if args.clean_caption:
|
||||
name += "_clean"
|
||||
if args.merge_cmotion:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ def extract_frames(
|
|||
points=None,
|
||||
backend="opencv",
|
||||
return_length=False,
|
||||
num_frames=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -34,7 +35,10 @@ def extract_frames(
|
|||
import av
|
||||
|
||||
container = av.open(video_path)
|
||||
total_frames = container.streams.video[0].frames
|
||||
if num_frames is not None:
|
||||
total_frames = num_frames
|
||||
else:
|
||||
total_frames = container.streams.video[0].frames
|
||||
|
||||
if points is not None:
|
||||
frame_inds = [int(p * total_frames) for p in points]
|
||||
|
|
@ -56,8 +60,10 @@ def extract_frames(
|
|||
import decord
|
||||
|
||||
container = decord.VideoReader(video_path, num_threads=1)
|
||||
total_frames = len(container)
|
||||
# avg_fps = container.get_avg_fps()
|
||||
if num_frames is not None:
|
||||
total_frames = num_frames
|
||||
else:
|
||||
total_frames = len(container)
|
||||
|
||||
if points is not None:
|
||||
frame_inds = [int(p * total_frames) for p in points]
|
||||
|
|
@ -73,7 +79,10 @@ def extract_frames(
|
|||
|
||||
elif backend == "opencv":
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
if num_frames is not None:
|
||||
total_frames = num_frames
|
||||
else:
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
if points is not None:
|
||||
frame_inds = [int(p * total_frames) for p in points]
|
||||
|
|
@ -90,7 +99,6 @@ def extract_frames(
|
|||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame = Image.fromarray(frame)
|
||||
except Exception as e:
|
||||
breakpoint()
|
||||
print(f"Error reading frame {video_path}: {e}")
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
|
|
|
|||
|
|
@ -45,10 +45,14 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
if not is_video(path):
|
||||
images = [pil_loader(path)]
|
||||
else:
|
||||
images = extract_frames(sample["path"], points=self.points, backend="opencv")
|
||||
num_frames = None
|
||||
if "num_frames" in sample:
|
||||
num_frames = sample["num_frames"]
|
||||
images = extract_frames(sample["path"], points=self.points, backend="opencv", num_frames=num_frames)
|
||||
images = [self.transform(img) for img in images]
|
||||
images = torch.stack(images)
|
||||
return dict(index=index, images=images)
|
||||
ret = dict(index=index, images=images)
|
||||
return ret
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
|
|
|||
Loading…
Reference in a new issue