Open-Sora/opensora/utils/train_utils.py

134 lines
4.8 KiB
Python
Raw Normal View History

2024-03-25 15:20:10 +01:00
import math
2024-03-23 15:06:19 +01:00
import random
2024-03-15 15:06:36 +01:00
from collections import OrderedDict
import torch
@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
) -> None:
    """
    Step the EMA model towards the current model.

    Every trainable parameter of ``model`` (except a top-level parameter named
    ``"pos_embed"``) is blended in place into the same-named parameter of
    ``ema_model``: ``ema = decay * ema + (1 - decay) * source``.

    Args:
        ema_model: module holding the running average. It must have a
            parameter with the same name for every parameter that is updated.
        model: module being trained.
        optimizer: only consulted when ``sharded`` is True and a trainable
            parameter of ``model`` is not float32. It must expose
            ``_param_store.working_to_master_param`` keyed by ``id(param)``.
            The master copy found there is averaged instead of the working
            parameter. (Presumably a ZeRO-style mixed-precision optimizer;
            confirm with the training script.)
        decay: EMA decay factor.
        sharded: when True, non-float32 parameters are read from the
            optimizer's master copies as described above. When False,
            ``param.data`` is always used.

    Raises:
        ValueError: if a master copy is needed but ``optimizer`` is None.
        KeyError: if ``ema_model`` lacks a parameter that ``model`` updates.
        Both are raised before any EMA parameter has been modified.
    """
    ema_params = dict(ema_model.named_parameters())

    # Resolve every (target, source) pair first, so a configuration error
    # cannot leave the EMA model partially updated.
    pairs = []
    for name, param in model.named_parameters():
        # A top-level parameter named "pos_embed" and frozen parameters are
        # never averaged.
        if name == "pos_embed" or not param.requires_grad:
            continue
        if sharded and param.data.dtype != torch.float32:
            if optimizer is None:
                raise ValueError(
                    f"update_ema: parameter {name!r} has dtype {param.data.dtype}; "
                    "sharded=True requires the optimizer to look up its master copy"
                )
            source = optimizer._param_store.working_to_master_param[id(param)].data
        else:
            source = param.data
        pairs.append((ema_params[name], source))

    for ema_param, source in pairs:
        ema_param.mul_(decay).add_(source, alpha=1 - decay)
2024-03-23 15:06:19 +01:00
class MaskGenerator:
    """Randomly sample temporal (per-frame) boolean masks.

    Each call to :meth:`get_mask` picks one masking strategy according to the
    configured probabilities. It returns a bool vector over frames in which
    the frames selected by that strategy are ``False`` and all others ``True``.
    """

    def __init__(self, mask_ratios):
        """
        Args:
            mask_ratios: mapping from mask name to the probability of choosing
                it. Every name must be one of the names listed below, and
                every ratio must lie in [0, 1]. If ``"mask_no"`` is absent, it
                is added with the leftover probability ``1 - sum(ratios)``.
                The caller's mapping is not modified. Dict order defines the
                order in which strategies are tried during sampling.

        Raises:
            ValueError: unknown mask name; ratio outside [0, 1]; ratios that
                sum to more than 1 while ``"mask_no"`` must be inferred; or
                ratios that do not sum to 1.
        """
        # NOTE: "mask_intepolate" is misspelled, but it is a config key that
        # existing configs use, so it is kept verbatim.
        valid_mask_names = [
            "mask_no",
            "mask_quarter_random",
            "mask_quarter_head",
            "mask_quarter_tail",
            "mask_quarter_head_tail",
            "mask_image_random",
            "mask_image_head",
            "mask_image_tail",
            "mask_image_head_tail",
            "mask_random",
            "mask_intepolate",
        ]
        # Copy so filling in "mask_no" does not mutate the caller's config.
        mask_ratios = dict(mask_ratios)

        if not all(mask_name in valid_mask_names for mask_name in mask_ratios):
            raise ValueError(f"mask_name should be one of {valid_mask_names}, got {mask_ratios.keys()}")
        if not all(mask_ratio >= 0 for mask_ratio in mask_ratios.values()):
            raise ValueError(f"mask_ratio should be greater than or equal to 0, got {mask_ratios.values()}")
        if not all(mask_ratio <= 1 for mask_ratio in mask_ratios.values()):
            raise ValueError(f"mask_ratio should be less than or equal to 1, got {mask_ratios.values()}")

        if "mask_no" not in mask_ratios:
            remaining = 1.0 - sum(mask_ratios.values())
            # A negative remainder would make the total exactly 1 and pass the
            # sum check below, while silently skewing the sampling
            # probabilities. Reject it. Tiny float residue is clamped to 0.
            if remaining < -1e-6:
                raise ValueError(
                    f"sum of mask_ratios should be at most 1 when mask_no is omitted, "
                    f"got {sum(mask_ratios.values())}"
                )
            mask_ratios["mask_no"] = max(remaining, 0.0)

        # sum of mask_ratios should be 1
        total = sum(mask_ratios.values())
        if not math.isclose(total, 1.0, abs_tol=1e-6):
            raise ValueError(f"sum of mask_ratios should be 1, got {total}")
        print(f"mask ratios: {mask_ratios}")
        self.mask_ratios = mask_ratios

    def get_mask(self, x):
        """Sample one temporal mask.

        Args:
            x: tensor whose dimension 2 is the frame axis. Presumably
                (B, C, T, H, W); confirm with the caller. Only its shape and
                device are used.

        Returns:
            Bool tensor of shape ``(x.shape[2],)`` on ``x.device``. Frames
            chosen by the sampled strategy are ``False``. Clips with at most
            one frame are returned unmasked (all ``True``).
        """
        # Pick a strategy by walking the cumulative distribution. If float
        # rounding leaves the draw above the total, mask_name stays None and
        # nothing is masked.
        mask_type = random.random()
        mask_name = None
        prob_acc = 0.0
        for name, ratio in self.mask_ratios.items():
            prob_acc += ratio
            if mask_type < prob_acc:
                mask_name = name
                break

        num_frames = x.shape[2]
        # Hardcoded: "quarter" strategies mask up to a quarter of the frames.
        # Allow at least one, so that 2-3 frame clips do not make
        # random.randint(1, 0) raise.
        condition_frames_max = max(1, num_frames // 4)
        mask = torch.ones(num_frames, dtype=torch.bool, device=x.device)
        if num_frames <= 1:
            return mask

        if mask_name == "mask_quarter_random":
            random_size = random.randint(1, condition_frames_max)
            random_pos = random.randint(0, num_frames - random_size)
            mask[random_pos : random_pos + random_size] = 0
        elif mask_name == "mask_image_random":
            random_pos = random.randint(0, num_frames - 1)
            mask[random_pos] = 0
        elif mask_name == "mask_quarter_head":
            mask[: random.randint(1, condition_frames_max)] = 0
        elif mask_name == "mask_image_head":
            mask[0] = 0
        elif mask_name == "mask_quarter_tail":
            mask[-random.randint(1, condition_frames_max) :] = 0
        elif mask_name == "mask_image_tail":
            mask[-1] = 0
        elif mask_name == "mask_quarter_head_tail":
            random_size = random.randint(1, condition_frames_max)
            mask[:random_size] = 0
            mask[-random_size:] = 0
        elif mask_name == "mask_image_head_tail":
            mask[0] = 0
            mask[-1] = 0
        elif mask_name == "mask_intepolate":
            # Mask every other frame, starting at frame 0 or frame 1.
            mask[random.randint(0, 1) :: 2] = 0
        elif mask_name == "mask_random":
            mask_ratio = random.uniform(0.3, 0.7)
            mask = torch.rand(num_frames, device=x.device) > mask_ratio
            # if mask is all False, set the last frame to True
            if not mask.any():
                mask[-1] = 1
        return mask

    def get_masks(self, x):
        """Sample an independent mask for each entry of ``x`` along dim 0.

        Returns:
            Bool tensor of shape ``(len(x), x.shape[2])``. An empty batch
            yields an empty ``(0, x.shape[2])`` tensor instead of failing in
            ``torch.stack``.
        """
        if len(x) == 0:
            return torch.empty((0, x.shape[2]), dtype=torch.bool, device=x.device)
        return torch.stack([self.get_mask(x) for _ in range(len(x))], dim=0)