2025-03-12 06:14:22 +01:00
import copy
import os
import re
from enum import Enum
import torch
from torch import nn
from opensora . datasets import save_sample
from opensora . datasets . aspect import get_image_size
from opensora . datasets . utils import read_from_path , rescale_image_by_path
from opensora . utils . logger import log_message
from opensora . utils . prompt_refine import refine_prompts
class SamplingMethod(Enum):
    """Sampling strategies that a sampling option can select."""

    # Open-Sora video generation
    I2V = "i2v"
    # Flux image generation
    DISTILLED = "distill"
    # Open-Sora video generation with inference scaling
    I2VINFERENCESCALING = "i2v_inference_scaling"
2025-03-12 06:14:22 +01:00
def create_tmp_csv ( save_dir : str , prompt : str , ref : str = None , create = True ) - > str :
"""
Create a temporary CSV file with the prompt text .
Args :
save_dir ( str ) : The directory where the CSV file will be saved .
prompt ( str ) : The prompt text .
Returns :
str : The path to the temporary CSV file .
"""
tmp_file = os . path . join ( save_dir , " prompt.csv " )
if not create :
return tmp_file
with open ( tmp_file , " w " , encoding = " utf-8 " ) as f :
if ref is not None :
f . write ( f ' text,ref \n " { prompt } " , " { ref } " ' )
else :
f . write ( f ' text \n " { prompt } " ' )
return tmp_file
def modify_option_to_t2i(sampling_option, distilled: bool = False, img_resolution: str = "1080px"):
    """
    Derive a text-to-image sampling option from an existing sampling option.

    A shallow copy of ``sampling_option`` is made and adjusted; the object
    passed in is not reassigned.

    Args:
        sampling_option: The base sampling option; its ``aspect_ratio`` and
            ``resolution`` attributes are read.
        distilled (bool): When True, the copy's method is set to
            ``SamplingMethod.DISTILLED``.
        img_resolution (str): Resolution name passed to ``get_image_size``.

    Returns:
        The adjusted copy, with a single frame, guidance 4.0, and the base
        option's resolution stored as ``resized_resolution``.
    """
    t2i_option = copy.copy(sampling_option)

    if distilled:
        t2i_option.method = SamplingMethod.DISTILLED

    # A single frame turns the generation into an image
    t2i_option.num_frames = 1
    img_height, img_width = get_image_size(img_resolution, sampling_option.aspect_ratio)
    t2i_option.height = img_height
    t2i_option.width = img_width
    t2i_option.guidance = 4.0
    # Keep the base resolution so the generated image can be rescaled later
    t2i_option.resized_resolution = sampling_option.resolution

    return t2i_option
def get_save_path_name(
    save_dir,
    sub_dir,
    save_prefix="",
    name=None,
    fallback_name=None,
    index=None,
    num_sample_pos=None,  # idx for prompt as path
    prompt_as_path=False,  # save sample with same name as prompt
    prompt=None,
):
    """
    Get the save path (without file extension) for a generated sample.

    Args:
        save_dir: Root output directory.
        sub_dir: Sub-directory inside ``save_dir``.
        save_prefix: Prefix prepended to the file name (not used when
            ``prompt_as_path`` is set).
        name: Explicit sample name; takes precedence over ``fallback_name``.
        fallback_name: Base name used together with ``index`` when ``name``
            is None.
        index: Sample index, zero-padded to four digits in the fallback name.
        num_sample_pos: Position of this sample among repeated generations.
            When positive it is appended as ``_<num_sample_pos>``; None or
            non-positive values add no suffix.
        prompt_as_path: Name the file ``<prompt>-<num_sample_pos>``, with
            leading/trailing periods stripped from the prompt (for vbench).
        prompt: The prompt text; required when ``prompt_as_path`` is True.

    Returns:
        str: ``save_dir/sub_dir/<file name>``.

    Raises:
        ValueError: If ``prompt_as_path`` is True but ``prompt`` is None.
        TypeError: If ``name`` is None and ``fallback_name`` or ``index`` is
            also None (the fallback name cannot be formatted).
    """
    if prompt_as_path:  # for vbench
        if prompt is None:
            raise ValueError("prompt is required when prompt_as_path is True")
        cleaned_prompt = prompt.strip(".")
        fname = f"{cleaned_prompt}-{num_sample_pos}"
    else:
        if name is not None:
            fname = save_prefix + name
        else:
            fname = f"{save_prefix + fallback_name}_{index:04d}"
        # num_sample_pos defaults to None; comparing None with 0 would raise
        if num_sample_pos is not None and num_sample_pos > 0:
            fname += f"_{num_sample_pos}"

    return os.path.join(save_dir, sub_dir, fname)
def get_names_from_path(path):
    """
    Get the file name, without directory or extension, from a path.

    Args:
        path (str): The path to the file.

    Returns:
        str: The base name of ``path`` with its last extension removed.
    """
    stem_and_ext = os.path.splitext(os.path.basename(path))
    return stem_and_ext[0]
def process_and_save(
    x: torch.Tensor,
    batch: dict,
    cfg: dict,
    sub_dir: str,
    generate_sampling_option,
    epoch: int,
    start_index: int,
    saving: bool = True,
):
    """
    Compute save names for the generated samples and optionally write them to disk.

    Args:
        x (torch.Tensor): The generated samples; iterated along the first axis.
        batch (dict): Batch data. ``"text"`` is required; ``"name"`` and
            ``"index"`` are used when present.
        cfg: Configuration; ``dataset.data_path`` and ``save_dir`` are read as
            attributes, the remaining options through ``cfg.get``.
        sub_dir (str): Sub-directory of ``cfg.save_dir`` to save into.
        generate_sampling_option: The sampling option used for generation.
        epoch (int): Passed to ``get_save_path_name`` as ``num_sample_pos``.
        start_index (int): Offset added to every entry of ``batch["index"]``.
        saving (bool): When False, only the names are computed.

    Returns:
        list[str]: One name per sample, as returned by ``get_names_from_path``
        for its save path.
    """
    # Base name of the dataset file, used when a sample has no explicit name
    data_file = cfg.dataset.data_path.split("/")[-1]
    fallback_name = data_file.split(".")[0]
    prompt_as_path = cfg.get("prompt_as_path", False)
    fps_save = cfg.get("fps_save", 16)
    save_dir = cfg.save_dir
    save_prefix = cfg.get("save_prefix", "")

    names = batch["name"] if "name" in batch else [None] * len(x)
    if "index" in batch:
        indices = [idx + start_index for idx in batch["index"]]
    else:
        indices = [None] * len(x)
    prompts = batch["text"]

    is_image = generate_sampling_option.num_frames == 1
    ret_names = []
    for sample, name, index, prompt in zip(x, names, indices, prompts):
        # == get save path ==
        save_path = get_save_path_name(
            save_dir,
            sub_dir,
            save_prefix=save_prefix,
            name=name,
            fallback_name=fallback_name,
            index=index,
            num_sample_pos=epoch,
            prompt_as_path=prompt_as_path,
            prompt=prompt,
        )
        ret_names.append(get_names_from_path(save_path))

        if not saving:
            continue

        # == write txt to disk ==
        with open(save_path + ".txt", "w", encoding="utf-8") as f:
            f.write(prompt)

        # == save samples ==
        save_sample(sample, save_path=save_path, fps=fps_save)

        # == resize image for t2i2v ==
        # The resolution attributes are only read when the earlier conditions hold
        needs_rescale = (
            cfg.get("use_t2i2v", False)
            and is_image
            and generate_sampling_option.resolution != generate_sampling_option.resized_resolution
        )
        if needs_rescale:
            log_message("Rescaling image to %s...", generate_sampling_option.resized_resolution)
            height, width = get_image_size(
                generate_sampling_option.resized_resolution, generate_sampling_option.aspect_ratio
            )
            rescale_image_by_path(save_path + ".png", width, height)

    return ret_names
def check_fps_added(sentence):
    """
    Tell whether the sentence already ends with FPS information.

    Returns:
        bool: True when the sentence ends with digits followed by " FPS.".
    """
    return re.search(r"\d+ FPS\.$", sentence) is not None
def ensure_sentence_ends_with_period(sentence: str):
    """
    Strip surrounding whitespace and make sure the sentence ends with a period.
    """
    stripped = sentence.strip()
    return stripped if stripped.endswith(".") else stripped + "."
def add_fps_info_to_text(text: list[str], fps: int = 16):
    """
    Append FPS information to every prompt that does not already end with it.

    Each prompt is first normalised to end with a period.

    Args:
        text (list[str]): The prompts.
        fps (int): The frame rate to mention.

    Returns:
        list[str]: The modified prompts, in the same order.
    """
    fps_suffix = f" {fps} FPS."
    mod_text = []
    for raw in text:
        sentence = ensure_sentence_ends_with_period(raw)
        mod_text.append(sentence if check_fps_added(sentence) else sentence + fps_suffix)
    return mod_text
def add_motion_score_to_text ( text , motion_score : int | str ) :
"""
Add the motion score to the text .
"""
if motion_score == " dynamic " :
ms = refine_prompts ( text , type = " motion_score " )
return [ f " { t } { ms [ i ] } . " for i , t in enumerate ( text ) ]
else :
return [ f " { t } { motion_score } motion score. " for t in text ]
def add_noise_to_ref(masked_ref: torch.Tensor, masks: torch.Tensor, t: float, sigma_min: float = 1e-5):
    """
    Blend the masked reference with Gaussian noise and re-apply the mask.

    Computes ``masks * ((1 - (1 - sigma_min) * t) * masked_ref + t * noise)``
    where ``noise`` is drawn with ``torch.randn_like(masked_ref)``.

    Args:
        masked_ref (torch.Tensor): The masked reference.
        masks (torch.Tensor): The mask multiplied onto the noisy result.
        t (float): Noise weight.
        sigma_min (float): Offset used in the reference weight.

    Returns:
        torch.Tensor: The masked noisy reference.
    """
    noise = torch.randn_like(masked_ref)
    ref_weight = 1 - (1 - sigma_min) * t
    return masks * (ref_weight * masked_ref + t * noise)
def collect_references_batch(
    reference_paths: list[str],
    cond_type: str,
    model_ae: nn.Module,
    image_size: tuple[int, int],
    is_causal=False,
):
    """
    Load and encode the reference media for every item of a batch.

    Args:
        reference_paths (list[str]): One entry per batch item. An empty string
            means "no reference"; several paths are separated by ";".
        cond_type (str): The condition type: any type containing "v2v"
            (with "head" or "tail"), or one of "i2v_head", "i2v_tail",
            "i2v_loop".
        model_ae (nn.Module): Autoencoder used to encode the references; inputs
            are moved to the device and dtype of its first parameter.
        image_size (tuple[int, int]): Size passed to ``read_from_path``.
        is_causal (bool): When True, one extra frame is taken for v2v references.

    Returns:
        list: For each batch item, None (empty path) or a list of encoded
        references.

    Raises:
        NotImplementedError: For an unknown condition type, or a v2v type
            that names neither "head" nor "tail".
        AssertionError: If a v2v reference has too few frames.
    """
    first_param = next(model_ae.parameters())
    device, dtype = first_param.device, first_param.dtype

    def _load(path):
        # size [C, T, H, W]
        return read_from_path(path, image_size, transform_name="resize_crop")

    def _encode(frames):
        # Add a batch axis for the autoencoder and drop it again afterwards
        latent = model_ae.encode(frames.unsqueeze(0).to(device, dtype))
        return latent.squeeze(0)  # size [C, T, H, W]

    refs_x = []  # refs_x: [batch, ref_num, C, T, H, W]
    for reference_path in reference_paths:
        if reference_path == "":
            refs_x.append(None)
            continue

        paths = reference_path.split(";")
        if "v2v" in cond_type:
            clip = _load(paths[0])
            actual_t = clip.size(1)
            # if reference not long enough, default to shorter ref
            target_t = 64 if (actual_t >= 64 and "easy" in cond_type) else 32
            if is_causal:
                target_t += 1
            assert actual_t >= target_t, f"need at least {target_t} reference frames for v2v generation"
            if "head" in cond_type:  # v2v head
                clip = clip[:, :target_t]
            elif "tail" in cond_type:  # v2v tail
                clip = clip[:, -target_t:]
            else:
                raise NotImplementedError
            ref = [_encode(clip)]
        elif cond_type == "i2v_head":  # take the 1st frame from first ref_path
            ref = [_encode(_load(paths[0])[:, :1])]
        elif cond_type == "i2v_tail":  # take the last frame from last ref_path
            ref = [_encode(_load(paths[-1])[:, -1:])]
        elif cond_type == "i2v_loop":
            # first frame of the first path, then last frame of the last path
            head_latent = _encode(_load(paths[0])[:, :1])
            tail_latent = _encode(_load(paths[-1])[:, -1:])
            ref = [head_latent, tail_latent]
        else:
            raise NotImplementedError(f"Unknown condition type {cond_type}")

        refs_x.append(ref)
    return refs_x
def prepare_inference_condition(
    z: torch.Tensor,
    mask_cond: str,
    ref_list: list[list[torch.Tensor]] = None,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare the visual condition for the model, using causal vae.

    Args:
        z (torch.Tensor): The latent noise tensor, of shape [B, C, T, H, W].
        mask_cond (str): The condition type: "t2v", "i2v_head", "i2v_tail",
            "i2v_loop", "v2v_head", "v2v_tail", "v2v_head_easy" or
            "v2v_tail_easy".
        ref_list: list of lists of media (image/video) for i2v and v2v
            condition, each of shape [C, T', H, W]; len(ref_list) == B.
            ref_list[i] is the list of media for batch item i, so one item can
            have several references (e.g. [ref_image_1, ref_image_2] for
            i2v_loop). An entry may be None, in which case that item falls
            back to t2v. ref_list itself may be None only for "t2v".
        causal (bool): When True, one extra latent frame is conditioned on
            for the v2v conditions.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(masks, masked_z)`` where
        ``masks`` has shape [B, 1, T, H, W] and is 1 on conditioned frames,
        and ``masked_z`` has shape [B, C, T, H, W] and holds the reference
        latents on those frames (zeros elsewhere). Both use z's device and
        dtype.

    Raises:
        AssertionError: If ``ref_list`` is None for a condition other than
            "t2v", or ``mask_cond`` is unknown.
    """
    # z has shape [b, c, t, h, w], where b is the batch size
    B, C, T, H, W = z.shape
    # Allocate directly on z's device/dtype instead of building on CPU and moving
    masks = torch.zeros(B, 1, T, H, W, device=z.device, dtype=z.dtype)
    masked_z = torch.zeros(B, C, T, H, W, device=z.device, dtype=z.dtype)

    if ref_list is None:
        assert mask_cond == "t2v", f"reference is required for {mask_cond}"
        # Nothing to condition on; indexing ref_list below would fail on None
        return masks, masked_z

    for i in range(B):
        ref = ref_list[i]

        if ref is None:
            # warning message
            if mask_cond != "t2v":
                print("no reference found. will default to cond_type t2v!")
            continue
        if T <= 1:  # image: no temporal conditioning
            continue

        # Apply the selected mask condition directly on the masks tensor
        if mask_cond == "i2v_head":  # equivalent to masking the first timestep
            masks[i, :, 0, :, :] = 1
            masked_z[i, :, 0, :, :] = ref[0][:, 0, :, :]
        elif mask_cond == "i2v_tail":  # mask the last timestep
            masks[i, :, -1, :, :] = 1
            masked_z[i, :, -1, :, :] = ref[-1][:, -1, :, :]
        elif mask_cond in ("v2v_head", "v2v_head_easy"):
            # 8 conditioned latent frames, 16 for the "easy" variant
            k = (16 if mask_cond == "v2v_head_easy" else 8) + int(causal)
            masks[i, :, :k, :, :] = 1
            masked_z[i, :, :k, :, :] = ref[0][:, :k, :, :]
        elif mask_cond in ("v2v_tail", "v2v_tail_easy"):
            k = (16 if mask_cond == "v2v_tail_easy" else 8) + int(causal)
            masks[i, :, -k:, :, :] = 1
            masked_z[i, :, -k:, :, :] = ref[0][:, -k:, :, :]
        elif mask_cond == "i2v_loop":  # mask first and last timesteps
            masks[i, :, 0, :, :] = 1
            masks[i, :, -1, :, :] = 1
            masked_z[i, :, 0, :, :] = ref[0][:, 0, :, :]
            masked_z[i, :, -1, :, :] = ref[-1][:, -1, :, :]  # last frame of last referenced content
        else:
            # "t2v" is the fallback case where no specific condition is specified
            assert mask_cond == "t2v", f"Unknown mask condition {mask_cond}"

    return masks, masked_z