mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
121 lines
4.3 KiB
Python
121 lines
4.3 KiB
Python
import math
|
|
import random
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
|
|
|
|
@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
) -> None:
    """
    Step the EMA model towards the current model.

    Every parameter of ``model`` that requires grad, except one named exactly
    ``"pos_embed"``, updates the same-named parameter of ``ema_model`` in place:
    ``ema = decay * ema + (1 - decay) * source``.

    Args:
        ema_model: Module holding the averaged weights; modified in place.
        model: Module whose current weights are blended into ``ema_model``.
        optimizer: Only consulted when ``sharded`` is true and a parameter is
            not float32. It must expose ``_param_store.working_to_master_param``,
            a mapping from ``id(param)`` to the tensor used as the update source
            (presumably the full-precision master copy kept by a
            mixed-precision optimizer -- confirm against the caller).
        decay: Weight kept from the previous EMA value on each step.
        sharded: When true, non-float32 parameters are read through the
            optimizer mapping rather than from ``param.data``.

    Raises:
        ValueError: If a master-weight lookup is required but ``optimizer`` is
            ``None``.
        KeyError: If ``ema_model`` has no parameter with a matching name, or the
            optimizer mapping has no entry for a parameter.
    """
    ema_params = OrderedDict(ema_model.named_parameters())

    for name, param in model.named_parameters():
        # The positional embedding and frozen parameters are never averaged.
        if name == "pos_embed" or not param.requires_grad:
            continue

        source = param.data
        if sharded and source.dtype != torch.float32:
            # Low-precision working weights: average the copy tracked by the
            # optimizer instead of the parameter itself.
            if optimizer is None:
                raise ValueError(
                    f"update_ema: parameter '{name}' has dtype {source.dtype}; "
                    "sharded=True requires an optimizer to look up its master weights"
                )
            source = optimizer._param_store.working_to_master_param[id(param)].data

        ema_params[name].mul_(decay).add_(source, alpha=1 - decay)
|
|
|
|
|
|
class MaskGenerator:
    """Sample a masking pattern by name and build boolean per-frame masks.

    ``mask_ratios`` maps a mask name to the probability of drawing it. In the
    returned masks ``True`` marks an untouched frame and ``False`` marks a frame
    selected by the drawn pattern:

    * ``mask_no``: nothing selected.
    * ``mask_quarter_*``: a run of 1 to ``max(1, num_frames // 4)`` frames.
    * ``mask_image_*``: a single frame.
    * suffix ``random`` / ``head`` / ``tail`` / ``head_tail``: the run is placed
      at a random offset, at the start, at the end, or at both ends.
    """

    def __init__(self, mask_ratios):
        """Validate and store the sampling probabilities.

        Args:
            mask_ratios: Mapping of mask name to probability. Names must be
                among the supported ones, each value must lie in [0, 1], and
                the values must sum to 1 (absolute tolerance 1e-6).

        Raises:
            AssertionError: If any of the conditions above is violated.
        """
        valid_mask_names = [
            "mask_no",
            "mask_quarter_random",
            "mask_quarter_head",
            "mask_quarter_tail",
            "mask_quarter_head_tail",
            "mask_image_random",
            "mask_image_head",
            "mask_image_tail",
            "mask_image_head_tail",
        ]
        assert all(
            mask_name in valid_mask_names for mask_name in mask_ratios.keys()
        ), f"mask_name should be one of {valid_mask_names}, got {mask_ratios.keys()}"
        assert all(
            mask_ratio >= 0 for mask_ratio in mask_ratios.values()
        ), f"mask_ratio should be greater than or equal to 0, got {mask_ratios.values()}"
        assert all(
            mask_ratio <= 1 for mask_ratio in mask_ratios.values()
        ), f"mask_ratio should be less than or equal to 1, got {mask_ratios.values()}"
        # sum of mask_ratios should be 1
        assert math.isclose(
            sum(mask_ratios.values()), 1.0, abs_tol=1e-6
        ), f"sum of mask_ratios should be 1, got {sum(mask_ratios.values())}"
        print(f"mask ratios: {mask_ratios}")
        self.mask_ratios = mask_ratios

    def get_mask(self, x):
        """Draw one mask pattern and return it as a 1-D bool tensor.

        Args:
            x: Tensor whose dimension 2 is taken as the frame axis (presumably
                shaped (B, C, T, H, W) -- confirm against callers). Only
                ``x.shape[2]`` and ``x.device`` are read.

        Returns:
            Bool tensor of length ``x.shape[2]`` on ``x.device``. It is all
            ``True`` when there is at most one frame, when ``mask_no`` is drawn,
            or when rounding in the cumulative sum leaves no pattern selected.
        """
        # Inverse-CDF sampling over the configured ratios, in dict order.
        mask_type = random.random()
        mask_name = None
        prob_acc = 0.0
        for mask, mask_ratio in self.mask_ratios.items():
            prob_acc += mask_ratio
            if mask_type < prob_acc:
                mask_name = mask
                break

        num_frames = x.shape[2]
        # Hardcoded condition_frames: at most a quarter of the clip. Clamped to
        # 1 so that clips of 2-3 frames do not make random.randint(1, 0) raise.
        condition_frames_max = max(1, num_frames // 4)

        mask = torch.ones(num_frames, dtype=torch.bool, device=x.device)
        if num_frames <= 1 or mask_name is None or mask_name == "mask_no":
            return mask

        quarter_prefix = "mask_quarter_"
        image_prefix = "mask_image_"
        if mask_name.startswith(quarter_prefix):
            placement = mask_name[len(quarter_prefix):]
            span = random.randint(1, condition_frames_max)
        elif mask_name.startswith(image_prefix):
            placement = mask_name[len(image_prefix):]
            span = 1
        else:
            # Unknown name: leave every frame untouched.
            return mask

        if placement == "random":
            start = random.randint(0, num_frames - span)
            mask[start : start + span] = False
        else:
            if placement in ("head", "head_tail"):
                mask[:span] = False
            if placement in ("tail", "head_tail"):
                mask[-span:] = False

        return mask

    def get_masks(self, x):
        """Return one independently drawn mask per element along dim 0 of ``x``.

        Returns:
            Bool tensor of shape ``(len(x), x.shape[2])``.
        """
        return torch.stack([self.get_mask(x) for _ in range(len(x))], dim=0)
|