2024-03-15 15:06:36 +01:00
import argparse
2024-03-24 09:58:16 +01:00
import html
2024-04-11 07:34:02 +02:00
import json
2024-03-15 15:06:36 +01:00
import os
2024-04-05 17:42:31 +02:00
import random
2024-04-02 08:56:44 +02:00
import re
from functools import partial
2024-04-01 05:45:38 +02:00
from glob import glob
2024-03-15 15:06:36 +01:00
2024-04-13 09:44:24 +02:00
import cv2
2024-03-25 08:12:18 +01:00
import numpy as np
import pandas as pd
2024-05-21 06:05:02 +02:00
from PIL import Image
2024-03-15 15:06:36 +01:00
from tqdm import tqdm
2024-05-31 09:17:01 +02:00
from opensora . datasets . read_video import read_video
2024-04-13 18:20:00 +02:00
2024-06-11 08:21:07 +02:00
from . utils import IMG_EXTENSIONS
2024-03-25 08:12:18 +01:00
# Register tqdm's pandas integration so Series/DataFrame gain progress_apply.
tqdm.pandas()

# pandarallel is optional: when importable, apply() below parallelizes across
# CPU cores; otherwise we fall back to tqdm's serial progress_apply.
try:
    from pandarallel import pandarallel

    PANDA_USE_PARALLEL = True
except ImportError:
    PANDA_USE_PARALLEL = False
2024-03-25 08:36:32 +01:00
2024-04-11 05:48:06 +02:00
def apply(df, func, **kwargs):
    """Apply *func* over *df*, using pandarallel when it is available.

    Falls back to tqdm's serial ``progress_apply`` when pandarallel could not
    be imported at module load time.
    """
    if not PANDA_USE_PARALLEL:
        return df.progress_apply(func, **kwargs)
    return df.parallel_apply(func, **kwargs)
2024-04-15 11:36:20 +02:00
TRAIN_COLUMNS = [ " path " , " text " , " num_frames " , " fps " , " height " , " width " , " aspect_ratio " , " resolution " , " text_len " ]
2024-04-11 05:48:06 +02:00
# ======================================================
# --info
# ======================================================
2024-04-13 18:20:00 +02:00
def get_video_length(cap, method="header"):
    """Return the frame count of an opened ``cv2.VideoCapture``.

    ``method="header"`` trusts the container metadata (fast, may be wrong);
    ``method="set"`` seeks to the end of the stream and reads the resulting
    frame position (slower, more reliable).
    """
    assert method in ["header", "set"]
    if method != "header":
        cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1)
        return int(cap.get(cv2.CAP_PROP_POS_FRAMES))
    return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
2024-05-21 04:11:23 +02:00
def get_info_old(path):
    """Probe *path* (image or video) and return
    ``(num_frames, height, width, aspect_ratio, fps, resolution)``.

    Never raises: any read failure yields ``(0, 0, 0, nan, nan, nan)`` so the
    function is safe to use with DataFrame.apply over dirty data.
    """
    try:
        ext = os.path.splitext(path)[1].lower()
        if ext in IMG_EXTENSIONS:
            im = cv2.imread(path)
            if im is None:
                return 0, 0, 0, np.nan, np.nan, np.nan
            height, width = im.shape[:2]
            num_frames, fps = 1, np.nan  # a still image is one frame, no fps
        else:
            cap = cv2.VideoCapture(path)
            try:
                num_frames, height, width, fps = (
                    get_video_length(cap, method="header"),
                    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                    int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    float(cap.get(cv2.CAP_PROP_FPS)),
                )
            finally:
                cap.release()  # fix: the capture handle was previously leaked
        hw = height * width
        aspect_ratio = height / width if width > 0 else np.nan
        return num_frames, height, width, aspect_ratio, fps, hw
    except Exception:  # narrowed from bare except: don't trap SystemExit/KeyboardInterrupt
        return 0, 0, 0, np.nan, np.nan, np.nan
2024-03-25 08:12:18 +01:00
2024-05-21 04:11:23 +02:00
def get_info(path):
    """Dispatch to image or video probing based on the file extension.

    Returns ``(num_frames, height, width, aspect_ratio, fps, resolution)``;
    on any failure returns ``(0, 0, 0, nan, nan, nan)`` instead of raising.
    """
    try:
        ext = os.path.splitext(path)[1].lower()
        if ext in IMG_EXTENSIONS:
            return get_image_info(path)
        else:
            return get_video_info(path)
    except Exception:  # narrowed from bare except: keep best-effort semantics only for real errors
        return 0, 0, 0, np.nan, np.nan, np.nan
2024-04-13 18:20:00 +02:00
2024-05-21 06:05:02 +02:00
def get_image_info(path, backend="pillow"):
    """Probe an image file and return
    ``(num_frames, height, width, aspect_ratio, fps, resolution)``.

    Args:
        path: image file path.
        backend: ``"pillow"`` (default) or ``"cv2"``.

    Failures yield ``(0, 0, 0, nan, nan, nan)``; an unknown backend raises
    ``ValueError``.
    """
    if backend == "pillow":
        try:
            with open(path, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")
                width, height = img.size
            num_frames, fps = 1, np.nan  # a still image is one frame, no fps
            hw = height * width
            aspect_ratio = height / width if width > 0 else np.nan
            return num_frames, height, width, aspect_ratio, fps, hw
        except Exception:  # narrowed from bare except
            return 0, 0, 0, np.nan, np.nan, np.nan
    elif backend == "cv2":
        try:
            im = cv2.imread(path)
            if im is None:
                return 0, 0, 0, np.nan, np.nan, np.nan
            height, width = im.shape[:2]
            num_frames, fps = 1, np.nan
            hw = height * width
            aspect_ratio = height / width if width > 0 else np.nan
            return num_frames, height, width, aspect_ratio, fps, hw
        except Exception:  # narrowed from bare except
            return 0, 0, 0, np.nan, np.nan, np.nan
    else:
        raise ValueError(f"Unknown backend: {backend}")  # was a bare ValueError
2024-05-21 06:05:02 +02:00
def get_video_info(path, backend="torchvision"):
    """Probe a video file and return
    ``(num_frames, height, width, aspect_ratio, fps, resolution)``.

    Args:
        path: video file path.
        backend: ``"torchvision"`` (default, decodes via opensora's
            ``read_video``) or ``"cv2"`` (reads container metadata only).

    Failures yield ``(0, 0, 0, nan, nan, nan)``; an unknown backend raises
    ``ValueError``.
    """
    if backend == "torchvision":
        try:
            vframes, infos = read_video(path)
            # indices 0/2/3 are read below — assumes (T, C, H, W) layout; TODO confirm against read_video
            num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
            fps = infos["video_fps"] if "video_fps" in infos else np.nan
            hw = height * width
            aspect_ratio = height / width if width > 0 else np.nan
            return num_frames, height, width, aspect_ratio, fps, hw
        except Exception:  # narrowed from bare except
            return 0, 0, 0, np.nan, np.nan, np.nan
    elif backend == "cv2":
        try:
            cap = cv2.VideoCapture(path)
            try:
                num_frames, height, width, fps = (
                    get_video_length(cap, method="header"),
                    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                    int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    float(cap.get(cv2.CAP_PROP_FPS)),
                )
            finally:
                cap.release()  # fix: the capture handle was previously leaked
            hw = height * width
            aspect_ratio = height / width if width > 0 else np.nan
            return num_frames, height, width, aspect_ratio, fps, hw
        except Exception:  # narrowed from bare except
            return 0, 0, 0, np.nan, np.nan, np.nan
    else:
        raise ValueError(f"Unknown backend: {backend}")  # was a bare ValueError
2024-04-11 05:48:06 +02:00
# ======================================================
# --refine-llm-caption
# ======================================================
2024-03-25 08:12:18 +01:00
# Boilerplate openings produced by LLaVA-style captioners; stripped by
# remove_caption_prefix so captions start with actual content.
LLAVA_PREFIX = [
    "The video shows",
    "The video captures",
    "The video features",
    "The video depicts",
    "The video presents",
    "The video features",
    "The video is",
    "In the video,",
    "The image shows",
    "The image captures",
    "The image features",
    "The image depicts",
    "The image presents",
    "The image features",
    "The image is",
    "The image portrays",
    "In the image,",
]


def remove_caption_prefix(caption):
    """Drop a leading LLaVA boilerplate prefix (case of first word ignored)
    and re-capitalize the first remaining letter.

    Returns the caption unchanged when no prefix matches.
    """
    for prefix in LLAVA_PREFIX:
        if caption.startswith(prefix) or caption.startswith(prefix.lower()):
            caption = caption[len(prefix):].strip()
            # Fix: a caption that was nothing but the prefix becomes "",
            # and caption[0] would raise IndexError.
            if caption and caption[0].islower():
                caption = caption[0].upper() + caption[1:]
            return caption
    return caption
2024-03-15 15:06:36 +01:00
2024-04-11 05:48:06 +02:00
# ======================================================
# --merge-cmotion
# ======================================================
2024-04-05 17:42:31 +02:00
# Camera-motion label -> phrase appended to the caption (None = never append).
CMOTION_TEXT = {
    "static": "static",
    "pan_right": "pan right",
    "pan_left": "pan left",
    "zoom_in": "zoom in",
    "zoom_out": "zoom out",
    "tilt_up": "tilt up",
    "tilt_down": "tilt down",
    # "pan/tilt": "The camera is panning.",
    # "dynamic": "The camera is moving.",
    # "unknown": None,
}
# Per-label probability of actually appending the phrase.
CMOTION_PROBS = {
    # hard-coded probabilities
    "static": 1.0,
    "zoom_in": 1.0,
    "zoom_out": 1.0,
    "pan_left": 1.0,
    "pan_right": 1.0,
    "tilt_up": 1.0,
    "tilt_down": 1.0,
    # "dynamic": 1.0,
    # "unknown": 0.0,
    # "pan/tilt": 1.0,
}


def merge_cmotion(caption, cmotion):
    """Append the camera-motion phrase for *cmotion* to *caption* with the
    probability configured in CMOTION_PROBS."""
    motion_text = CMOTION_TEXT[cmotion]
    if motion_text is None:
        return caption
    if random.random() >= CMOTION_PROBS[cmotion]:
        return caption
    return f"{caption} Camera motion: {motion_text}."
2024-04-11 05:48:06 +02:00
# ======================================================
# --lang
# ======================================================
2024-03-25 08:12:18 +01:00
def build_lang_detector(lang_to_detect):
    """Build a predicate ``caption -> bool`` that is True when
    *lang_to_detect* ranks among the first five entries of lingua's
    confidence ranking for the caption.

    Requires the optional ``lingua`` package; currently only ``"en"`` is
    supported.
    """
    from lingua import Language, LanguageDetectorBuilder

    supported = dict(en=Language.ENGLISH)
    assert lang_to_detect in supported
    target = supported[lang_to_detect]
    detector = LanguageDetectorBuilder.from_all_spoken_languages().with_low_accuracy_mode().build()

    def detect_lang(caption):
        ranked = detector.compute_language_confidence_values(caption)
        top_languages = [entry.language for entry in ranked[:5]]
        return target in top_languages

    return detect_lang
2024-03-15 15:06:36 +01:00
2024-04-11 05:48:06 +02:00
# ======================================================
# --clean-caption
# ======================================================
2024-04-02 08:56:44 +02:00
def basic_clean(text):
    """Fix mojibake via ftfy and un-escape possibly double-encoded HTML
    entities, returning the stripped result."""
    import ftfy

    fixed = ftfy.fix_text(text)
    # unescape twice to handle strings that were HTML-escaped twice
    fixed = html.unescape(html.unescape(fixed))
    return fixed.strip()
# Runs of "bad" punctuation (promo symbols, brackets, slashes, asterisks...)
# that clean_caption collapses to a single space.
BAD_PUNCT_REGEX = re.compile(
    r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
)  # noqa
def clean_caption(caption):
    """Normalize a raw web caption for T5 training.

    Pipeline (order matters): URL-decode and lowercase; strip URLs, HTML
    markup and @-handles; remove CJK ranges; normalize dashes and quotes;
    drop ip addresses, ids, hashes, filenames and shop-speak; collapse
    punctuation runs and whitespace. Ported from the DeepFloyd-IF /
    PixArt text-cleaning pipeline.
    """
    import urllib.parse as ul

    from bs4 import BeautifulSoup

    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text
    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)
    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################
    # все виды тире / all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )
    # кавычки к одному стандарту / normalize quotes to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)
    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)
    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)
    # \n
    caption = re.sub(r"\\n", " ", caption)
    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""
    caption = re.sub(BAD_PUNCT_REGEX, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "
    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)
    caption = basic_clean(caption)
    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231
    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)
    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)
    caption.strip()  # NOTE(review): no-op — result discarded; final return strips anyway
    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)
    return caption.strip()
def text_preprocessing(text, use_text_preprocessing: bool = True):
    """Clean *text* for training; lowercase+strip only when cleaning is
    disabled."""
    if not use_text_preprocessing:
        return text.lower().strip()
    # The exact text cleaning as was in the training stage (applied twice,
    # matching the original pipeline).
    return clean_caption(clean_caption(text))
2024-04-05 17:42:31 +02:00
2024-04-11 07:34:02 +02:00
# ======================================================
# load caption
# ======================================================
def load_caption(path, ext):
    """Load the caption stored beside *path* (same stem, ``.json`` suffix).

    Only ``ext == "json"`` is supported; the JSON file must contain a
    ``"caption"`` key. Returns ``""`` on any failure (unsupported ext,
    missing/unreadable/malformed file).
    """
    try:
        assert ext in ["json"]
        # Fix: os.path.splitext only splits the final extension; the previous
        # path.split(".")[0] broke on paths with dots in directory names.
        json_path = os.path.splitext(path)[0] + ".json"
        with open(json_path, "r") as f:
            data = json.load(f)
        return data["caption"]
    except Exception:  # narrowed from bare except
        return ""
2024-04-11 07:34:02 +02:00
2024-06-02 09:49:46 +02:00
# ======================================================
# --score-to-text
# ======================================================
# Probability of dropping each score annotation so the model also sees
# captions without score suffixes.
DROP_SCORE_PROB = 0.2


def score_to_text(data):
    """Append available aesthetic/motion scores to a sample's caption.

    Each score present in *data* is kept with probability
    ``1 - DROP_SCORE_PROB`` and formatted as
    ``"caption [aesthetic score: X.X, motion score: Y.Y]"``.
    """
    text = data["text"]
    parts = []
    # aesthetic score
    if "aes" in data:
        if random.random() > DROP_SCORE_PROB:
            parts.append(f"aesthetic score: {data['aes']:.1f}")
    # optical-flow (motion) score
    if "flow" in data:
        if random.random() > DROP_SCORE_PROB:
            parts.append(f"motion score: {data['flow']:.1f}")
    if parts:
        text = f"{text} [{', '.join(parts)}]"
    return text
2024-04-11 05:48:06 +02:00
# ======================================================
# read & write
# ======================================================
2024-04-04 08:51:36 +02:00
2024-04-11 05:48:06 +02:00
def read_file(input_path):
    """Load a dataset table (``.csv`` or ``.parquet``) into a DataFrame.

    Raises NotImplementedError for any other extension.
    """
    readers = ((".csv", pd.read_csv), (".parquet", pd.read_parquet))
    for suffix, reader in readers:
        if input_path.endswith(suffix):
            return reader(input_path)
    raise NotImplementedError(f"Unsupported file format: {input_path}")
2024-03-25 08:12:18 +01:00
2024-04-11 05:48:06 +02:00
def save_file(data, output_path):
    """Write DataFrame *data* to ``.csv`` or ``.parquet`` at *output_path*,
    creating the parent directory if needed.

    Returns whatever the underlying pandas writer returns (None).
    Raises NotImplementedError for other extensions.
    """
    output_dir = os.path.dirname(output_path)
    if output_dir != "":
        # exist_ok avoids the check-then-create race of the previous
        # os.path.exists() guard.
        os.makedirs(output_dir, exist_ok=True)
    if output_path.endswith(".csv"):
        return data.to_csv(output_path, index=False)
    elif output_path.endswith(".parquet"):
        return data.to_parquet(output_path, index=False)
    else:
        raise NotImplementedError(f"Unsupported file format: {output_path}")
2024-03-25 08:12:18 +01:00
2024-04-11 05:48:06 +02:00
def read_data(input_paths):
    """Expand glob patterns in *input_paths*, read every matching csv/parquet
    file, and concatenate them into one DataFrame.

    Returns ``(data, input_name)`` where ``input_name`` joins the file stems
    with ``"+"`` and is later used to derive the default output filename.
    Exits the process when nothing could be loaded.
    """
    data = []
    input_name = ""
    input_list = []
    for input_path in input_paths:
        input_list.extend(glob(input_path))
    print("Input files:", input_list)
    for i, input_path in enumerate(input_list):
        # skip paths that disappeared between globbing and reading
        if not os.path.exists(input_path):
            continue
        data.append(read_file(input_path))
        input_name += os.path.basename(input_path).split(".")[0]
        if i != len(input_list) - 1:
            input_name += "+"
        print(f"Loaded {len(data[-1])} samples from '{input_path}'.")
    if len(data) == 0:
        print(f"No samples to process. Exit.")
        exit()
    # sort=False keeps the column order of the first frame
    data = pd.concat(data, ignore_index=True, sort=False)
    print(f"Total number of samples: {len(data)}")
    return data, input_name
# ======================================================
# main
# ======================================================
# To add a new method, register it in the main, parse_args, and get_output_path functions, and update the doc at /tools/datasets/README.md#documentation
def main(args):
    """Run the dataset-processing pipeline driven by the parsed CLI *args*.

    Stage order is significant: load inputs, set difference/intersection,
    derive output path, build lazy helpers (lang detector, tokenizer),
    IO probing, caption filtering, path/caption processing, sorting,
    score filtering, shuffle/head, column pruning, then (sharded) saving.
    """
    # reading data
    data, input_name = read_data(args.input)

    # make difference: drop rows whose path appears in the difference csv
    if args.difference is not None:
        data_diff = pd.read_csv(args.difference)
        print(f"Difference csv contains {len(data_diff)} samples.")
        data = data[~data["path"].isin(data_diff["path"])]
        input_name += f"-{os.path.basename(args.difference).split('.')[0]}"
        print(f"Filtered number of samples: {len(data)}.")

    # make intersection: inner-join on path, pulling in only new columns
    if args.intersection is not None:
        data_new = pd.read_csv(args.intersection)
        print(f"Intersection csv contains {len(data_new)} samples.")
        cols_to_use = data_new.columns.difference(data.columns)

        col_on = "path"
        # if 'id' in data.columns and 'id' in data_new.columns:
        # col_on = 'id'
        cols_to_use = cols_to_use.insert(0, col_on)
        data = pd.merge(data, data_new[cols_to_use], on=col_on, how="inner")
        print(f"Intersection number of samples: {len(data)}.")

    # get output path
    output_path = get_output_path(args, input_name)

    # preparation (expensive helpers built once, only when needed)
    if args.lang is not None:
        detect_lang = build_lang_detector(args.lang)
    if args.count_num_token == "t5":
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/t5-v1_1-xxl")

    # IO-related
    if args.load_caption is not None:
        assert "path" in data.columns
        data["text"] = apply(data["path"], load_caption, ext=args.load_caption)
    if args.info:
        info = apply(data["path"], get_info)
        (
            data["num_frames"],
            data["height"],
            data["width"],
            data["aspect_ratio"],
            data["fps"],
            data["resolution"],
        ) = zip(*info)
    if args.video_info:
        info = apply(data["path"], get_video_info)
        (
            data["num_frames"],
            data["height"],
            data["width"],
            data["aspect_ratio"],
            data["fps"],
            data["resolution"],
        ) = zip(*info)
    if args.ext:
        assert "path" in data.columns
        data = data[apply(data["path"], os.path.exists)]

    # filtering
    if args.remove_url:
        assert "text" in data.columns
        data = data[~data["text"].str.contains(r"(?P<url>https?://[^\s]+)", regex=True)]
    if args.lang is not None:
        assert "text" in data.columns
        data = data[data["text"].progress_apply(detect_lang)]  # cannot parallelize
    if args.remove_empty_path:
        assert "path" in data.columns
        data = data[data["path"].str.len() > 0]
        data = data[~data["path"].isna()]
    if args.remove_empty_caption:
        assert "text" in data.columns
        data = data[data["text"].str.len() > 0]
        data = data[~data["text"].isna()]
    if args.remove_path_duplication:
        assert "path" in data.columns
        data = data.drop_duplicates(subset=["path"])
    if args.path_subset:
        data = data[data["path"].str.contains(args.path_subset)]

    # processing
    if args.relpath is not None:
        data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath))
    if args.abspath is not None:
        data["path"] = apply(data["path"], lambda x: os.path.join(args.abspath, x))
    if args.path_to_id:
        data["id"] = apply(data["path"], lambda x: os.path.splitext(os.path.basename(x))[0])
    if args.merge_cmotion:
        data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1)
    if args.refine_llm_caption:
        assert "text" in data.columns
        data["text"] = apply(data["text"], remove_caption_prefix)
    if args.append_text is not None:
        assert "text" in data.columns
        data["text"] = data["text"] + args.append_text
    if args.score_to_text:
        data["text"] = apply(data, score_to_text, axis=1)
    if args.clean_caption:
        assert "text" in data.columns
        data["text"] = apply(
            data["text"],
            partial(text_preprocessing, use_text_preprocessing=True),
        )
    if args.count_num_token is not None:
        assert "text" in data.columns
        data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"]))
    if args.update_text is not None:
        # overwrite "text" for rows whose path appears in the update csv
        data_new = pd.read_csv(args.update_text)
        num_updated = data.path.isin(data_new.path).sum()
        print(f"Number of updated samples: {num_updated}.")
        data = data.set_index("path")
        data_new = data_new[["path", "text"]].set_index("path")
        data.update(data_new)
        data = data.reset_index()

    # sort
    if args.sort is not None:
        data = data.sort_values(by=args.sort, ascending=False)
    if args.sort_ascending is not None:
        data = data.sort_values(by=args.sort_ascending, ascending=True)

    # filtering
    if args.filesize:
        assert "path" in data.columns
        data["filesize"] = apply(data["path"], lambda x: os.stat(x).st_size / 1024 / 1024)
    if args.fsmax is not None:
        assert "filesize" in data.columns
        data = data[data["filesize"] <= args.fsmax]
    if args.remove_empty_caption:
        # applied again: earlier processing steps may have emptied captions
        assert "text" in data.columns
        data = data[data["text"].str.len() > 0]
        data = data[~data["text"].isna()]
    if args.fmin is not None:
        assert "num_frames" in data.columns
        data = data[data["num_frames"] >= args.fmin]
    if args.fmax is not None:
        assert "num_frames" in data.columns
        data = data[data["num_frames"] <= args.fmax]
    if args.fpsmax is not None:
        assert "fps" in data.columns
        # keep NaN fps rows (images) rather than dropping them
        data = data[(data["fps"] <= args.fpsmax) | np.isnan(data["fps"])]
    if args.hwmax is not None:
        if "resolution" not in data.columns:
            height = data["height"]
            width = data["width"]
            data["resolution"] = height * width
        data = data[data["resolution"] <= args.hwmax]
    if args.aesmin is not None:
        assert "aes" in data.columns
        data = data[data["aes"] >= args.aesmin]
    if args.matchmin is not None:
        assert "match" in data.columns
        data = data[data["match"] >= args.matchmin]
    if args.flowmin is not None:
        assert "flow" in data.columns
        data = data[data["flow"] >= args.flowmin]
    if args.remove_text_duplication:
        data = data.drop_duplicates(subset=["text"], keep="first")
    if args.img_only:
        data = data[data["path"].str.lower().str.endswith(IMG_EXTENSIONS)]
    if args.vid_only:
        data = data[~data["path"].str.lower().str.endswith(IMG_EXTENSIONS)]

    # process data
    if args.shuffle:
        data = data.sample(frac=1).reset_index(drop=True)  # shuffle
    if args.head is not None:
        data = data.head(args.head)

    # train columns
    if args.train_column:
        all_columns = data.columns
        columns_to_drop = all_columns.difference(TRAIN_COLUMNS)
        data = data.drop(columns=columns_to_drop)

    print(f"Filtered number of samples: {len(data)}.")

    # shard data
    if args.shard is not None:
        sharded_data = np.array_split(data, args.shard)
        for i in range(args.shard):
            output_path_part = output_path.split(".")
            output_path_s = ".".join(output_path_part[:-1]) + f"_{i}." + output_path_part[-1]
            save_file(sharded_data[i], output_path_s)
            print(f"Saved {len(sharded_data[i])} samples to {output_path_s}.")
    else:
        save_file(data, output_path)
        print(f"Saved {len(data)} samples to {output_path}.")
2024-04-11 05:48:06 +02:00
def parse_args():
    """Build and parse the CLI for the dataset utility.

    Flags are grouped to mirror the stage order in main(); see
    /tools/datasets/README.md#documentation for full usage.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, nargs="+", help="path to the input dataset")
    parser.add_argument("--output", type=str, default=None, help="output path")
    parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"])
    parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
    parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    # special case
    parser.add_argument("--shard", type=int, default=None, help="shard the dataset")
    parser.add_argument("--sort", type=str, default=None, help="sort by column")
    parser.add_argument("--sort-ascending", type=str, default=None, help="sort by column (ascending order)")
    parser.add_argument("--difference", type=str, default=None, help="get difference from the dataset")
    parser.add_argument(
        "--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns"
    )
    parser.add_argument("--train-column", action="store_true", help="only keep the train column")
    # IO-related
    parser.add_argument("--info", action="store_true", help="get the basic information of each video and image")
    parser.add_argument("--video-info", action="store_true", help="get the basic information of each video")
    parser.add_argument("--ext", action="store_true", help="check if the file exists")
    parser.add_argument(
        "--load-caption", type=str, default=None, choices=["json", "txt"], help="load the caption from json or txt"
    )
    # path processing
    parser.add_argument("--relpath", type=str, default=None, help="modify the path to relative path by root given")
    parser.add_argument("--abspath", type=str, default=None, help="modify the path to absolute path by root given")
    parser.add_argument("--path-to-id", action="store_true", help="add id based on path")
    parser.add_argument(
        "--path-subset", type=str, default=None, help="extract a subset data containing the given `path-subset` value"
    )
    parser.add_argument(
        "--remove-empty-path",
        action="store_true",
        help="remove rows with empty path",  # caused by transform, cannot read path
    )
    # caption filtering
    parser.add_argument(
        "--remove-empty-caption",
        action="store_true",
        help="remove rows with empty caption",
    )
    parser.add_argument("--remove-url", action="store_true", help="remove rows with url in caption")
    parser.add_argument("--lang", type=str, default=None, help="remove rows with other language")
    parser.add_argument("--remove-path-duplication", action="store_true", help="remove rows with duplicated path")
    parser.add_argument("--remove-text-duplication", action="store_true", help="remove rows with duplicated caption")
    # caption processing
    parser.add_argument("--refine-llm-caption", action="store_true", help="modify the caption generated by LLM")
    parser.add_argument(
        "--clean-caption", action="store_true", help="modify the caption according to T5 pipeline to suit training"
    )
    parser.add_argument("--merge-cmotion", action="store_true", help="merge the camera motion to the caption")
    parser.add_argument(
        "--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption"
    )
    parser.add_argument("--append-text", type=str, default=None, help="append text to the caption")
    parser.add_argument("--score-to-text", action="store_true", help="convert score to text")
    parser.add_argument("--update-text", type=str, default=None, help="update the text with the given text")
    # score filtering
    parser.add_argument("--filesize", action="store_true", help="get the filesize of each video and image in MB")
    parser.add_argument("--fsmax", type=int, default=None, help="filter the dataset by maximum filesize")
    parser.add_argument("--fmin", type=int, default=None, help="filter the dataset by minimum number of frames")
    parser.add_argument("--fmax", type=int, default=None, help="filter the dataset by maximum number of frames")
    parser.add_argument("--hwmax", type=int, default=None, help="filter the dataset by maximum resolution")
    parser.add_argument("--aesmin", type=float, default=None, help="filter the dataset by minimum aes score")
    parser.add_argument("--matchmin", type=float, default=None, help="filter the dataset by minimum match score")
    parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score")
    parser.add_argument("--fpsmax", type=float, default=None, help="filter the dataset by maximum fps")
    parser.add_argument("--img-only", action="store_true", help="only keep the image data")
    parser.add_argument("--vid-only", action="store_true", help="only keep the video data")
    # data processing
    parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle the dataset")
    parser.add_argument("--head", type=int, default=None, help="return the first n rows of data")
    return parser.parse_args()
def get_output_path(args, input_name):
    """Build the output path for the processed meta file.

    If ``--output`` was given, it is returned unchanged. Otherwise the output
    name is derived from ``input_name`` by appending one suffix per active
    option (in a fixed order), and the file is placed next to the first input
    with the extension taken from ``args.format``.

    Args:
        args: parsed CLI namespace (see ``parse_args``).
        input_name: base name of the input meta file (without extension).

    Returns:
        str: the path to write the processed meta file to.
    """
    if args.output is not None:
        return args.output

    dir_path = os.path.dirname(args.input[0])
    parts = [input_name]

    # sort: the two options are mutually exclusive but share one suffix
    if args.sort is not None:
        assert args.sort_ascending is None
        parts.append("_sort")
    if args.sort_ascending is not None:
        assert args.sort is None
        parts.append("_sort")

    # (attribute, none_check, suffix template), in the exact order the
    # suffixes are appended to the name. none_check=True means the option is
    # active when it is not None; otherwise plain truthiness is used. The
    # template's "{}" (if any) is filled with the option's value.
    # NOTE: IO-related steps (info/vinfo/ext/load) come first, then path
    # processing, caption filtering, caption processing, and score filtering.
    suffix_rules = (
        ("info", False, "_info"),
        ("video_info", False, "_vinfo"),
        ("ext", False, "_ext"),
        ("load_caption", False, "_load{}"),
        ("relpath", True, "_relpath"),
        ("abspath", True, "_abspath"),
        ("remove_empty_path", False, "_noemptypath"),
        ("remove_empty_caption", False, "_noempty"),
        ("remove_url", False, "_nourl"),
        ("lang", True, "_{}"),
        ("remove_path_duplication", False, "_noduppath"),
        ("remove_text_duplication", False, "_noduptext"),
        ("path_subset", False, "_subset"),
        ("refine_llm_caption", False, "_llm"),
        ("clean_caption", False, "_clean"),
        ("merge_cmotion", False, "_cmcaption"),
        ("count_num_token", False, "_ntoken"),
        ("append_text", True, "_appendtext"),
        ("score_to_text", False, "_score2text"),
        ("update_text", True, "_update"),
        ("filesize", False, "_filesize"),
        ("fsmax", True, "_fsmax{}"),
        ("fmin", True, "_fmin{}"),
        ("fmax", True, "_fmax{}"),
        ("fpsmax", True, "_fpsmax{}"),
        ("hwmax", True, "_hwmax{}"),
        ("aesmin", True, "_aesmin{}"),
        ("matchmin", True, "_matchmin{}"),
        ("flowmin", True, "_flowmin{}"),
        ("img_only", False, "_img"),
        ("vid_only", False, "_vid"),
    )
    for attr, none_check, template in suffix_rules:
        value = getattr(args, attr)
        active = (value is not None) if none_check else bool(value)
        if active:
            parts.append(template.format(value))

    # data processing steps embed a second value, so they stay hand-written
    if args.shuffle:
        parts.append(f"_shuffled_seed{args.seed}")
    if args.head is not None:
        parts.append(f"_first_{args.head}_data")

    return os.path.join(dir_path, "{}.{}".format("".join(parts), args.format))
2024-03-25 08:12:18 +01:00
if __name__ == "__main__":
    args = parse_args()

    # --disable-parallel overrides the import-time pandarallel detection;
    # this rebinds the module-level flag (valid at module scope, no `global`).
    if args.disable_parallel:
        PANDA_USE_PARALLEL = False
    if PANDA_USE_PARALLEL:
        init_kwargs = {"progress_bar": True}
        if args.num_workers is not None:
            init_kwargs["nb_workers"] = args.num_workers
        pandarallel.initialize(**init_kwargs)

    # seed both RNGs so shuffling/sampling is reproducible
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)

    main(args)