import math
import random
from collections import OrderedDict

import torch


@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
) -> None:
    """
    Step the EMA model towards the current model.

    For every parameter of ``model`` that requires grad (except ``pos_embed``),
    the same-named parameter of ``ema_model`` is updated in place as
    ``ema = decay * ema + (1 - decay) * source``.

    Args:
        ema_model: Module holding the EMA weights. It must have a parameter for
            every trainable name in ``model`` other than ``pos_embed``.
        model: The module being trained.
        optimizer: Only consulted when ``sharded`` is True and a parameter is
            not float32. In that case the source tensor is looked up in
            ``optimizer._param_store.working_to_master_param`` (keyed by
            ``id(param)``), so it must not be None.
        decay: EMA decay factor.
        sharded: If True, non-float32 parameters are read from the optimizer's
            master copies instead of the working parameters.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # pos_embed is skipped deliberately.
        # NOTE(review): presumably a fixed (non-learned) embedding; confirm.
        if name == "pos_embed" or not param.requires_grad:
            continue
        if sharded and param.data.dtype != torch.float32:
            # Low-precision working param: average the optimizer's master copy instead.
            master_param = optimizer._param_store.working_to_master_param[id(param)]
            param_data = master_param.data
        else:
            param_data = param.data
        ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)


class MaskGenerator:
    """
    Sample per-frame boolean masks according to configured strategy probabilities.

    In a returned mask, frames selected by the sampled strategy (head, tail,
    random span, ...) are ``False`` and all other frames are ``True``.
    NOTE(review): the ``False`` frames are presumably the conditioning frames
    (see the "condition_frames" comment in ``get_mask``); confirm with callers.
    """

    def __init__(self, mask_ratios):
        """
        Validate and store the mask-strategy probabilities.

        Args:
            mask_ratios: Mapping from mask name to sampling probability. Each
                value must lie in [0, 1]. If "mask_no" is absent, it is filled
                in with the remaining probability mass, which must not be
                negative. The final values must sum to 1 (within 1e-6). The
                caller's mapping is not modified.

        Raises:
            AssertionError: If a name is unknown, a ratio is outside [0, 1],
                the ratios exceed 1 when "mask_no" is omitted, or the final
                ratios do not sum to 1.
        """
        # "mask_intepolate" (sic) is a runtime key; renaming it would reject
        # existing inputs that use this spelling.
        valid_mask_names = [
            "mask_no",
            "mask_quarter_random",
            "mask_quarter_head",
            "mask_quarter_tail",
            "mask_quarter_head_tail",
            "mask_image_random",
            "mask_image_head",
            "mask_image_tail",
            "mask_image_head_tail",
            "mask_random",
            "mask_intepolate",
        ]
        assert all(
            mask_name in valid_mask_names for mask_name in mask_ratios.keys()
        ), f"mask_name should be one of {valid_mask_names}, got {mask_ratios.keys()}"
        assert all(
            mask_ratio >= 0 for mask_ratio in mask_ratios.values()
        ), f"mask_ratio should be greater than or equal to 0, got {mask_ratios.values()}"
        assert all(
            mask_ratio <= 1 for mask_ratio in mask_ratios.values()
        ), f"mask_ratio should be less than or equal to 1, got {mask_ratios.values()}"

        # Work on a copy so the caller's dict is not mutated.
        ratios = dict(mask_ratios)

        # sum of mask_ratios should be 1
        if "mask_no" not in ratios:
            remaining = 1.0 - sum(ratios.values())
            # Snap float round-off (e.g. a sum of 1.0000000000000002) to 0 so
            # valid configs are not rejected by the check below.
            if -1e-6 < remaining < 0.0:
                remaining = 0.0
            # Without this check, an implied negative "mask_no" would still
            # make the total sum to exactly 1 and slip through.
            assert (
                remaining >= 0
            ), f"sum of mask_ratios should be at most 1 when mask_no is omitted, got {sum(ratios.values())}"
            ratios["mask_no"] = remaining
        assert math.isclose(
            sum(ratios.values()), 1.0, abs_tol=1e-6
        ), f"sum of mask_ratios should be 1, got {sum(ratios.values())}"
        print(f"mask ratios: {ratios}")
        self.mask_ratios = ratios

    def get_mask(self, x):
        """
        Sample one mask strategy and build a frame mask for ``x``.

        Args:
            x: Tensor whose dimension 2 is treated as the frame axis.
                NOTE(review): presumably laid out as (B, C, T, H, W); confirm.

        Returns:
            Bool tensor of shape ``(num_frames,)`` on ``x.device``. It is all
            True when ``num_frames <= 1`` or when "mask_no" is drawn.
        """
        draw = random.random()
        # Default covers "mask_no" and the case where float round-off leaves
        # the draw beyond the accumulated mass. Either way, nothing is masked.
        mask_name = "mask_no"
        prob_acc = 0.0
        for name, ratio in self.mask_ratios.items():
            prob_acc += ratio
            if draw < prob_acc:
                mask_name = name
                break

        num_frames = x.shape[2]
        # Hardcoded condition_frames: "quarter" strategies mask at most a
        # quarter of the frames. The limit is clamped to at least 1 because
        # random.randint(1, 0) raises ValueError for 2-3 frame inputs.
        condition_frames_max = max(1, num_frames // 4)

        mask = torch.ones(num_frames, dtype=torch.bool, device=x.device)
        if num_frames <= 1:
            return mask

        if mask_name == "mask_quarter_random":
            random_size = random.randint(1, condition_frames_max)
            random_pos = random.randint(0, num_frames - random_size)
            mask[random_pos : random_pos + random_size] = 0
        elif mask_name == "mask_image_random":
            random_size = 1
            random_pos = random.randint(0, num_frames - random_size)
            mask[random_pos : random_pos + random_size] = 0
        elif mask_name == "mask_quarter_head":
            random_size = random.randint(1, condition_frames_max)
            mask[:random_size] = 0
        elif mask_name == "mask_image_head":
            random_size = 1
            mask[:random_size] = 0
        elif mask_name == "mask_quarter_tail":
            random_size = random.randint(1, condition_frames_max)
            mask[-random_size:] = 0
        elif mask_name == "mask_image_tail":
            random_size = 1
            mask[-random_size:] = 0
        elif mask_name == "mask_quarter_head_tail":
            random_size = random.randint(1, condition_frames_max)
            mask[:random_size] = 0
            mask[-random_size:] = 0
        elif mask_name == "mask_image_head_tail":
            random_size = 1
            mask[:random_size] = 0
            mask[-random_size:] = 0
        elif mask_name == "mask_intepolate":
            # Mask every other frame, starting at frame 0 or 1.
            random_start = random.randint(0, 1)
            mask[random_start::2] = 0
        elif mask_name == "mask_random":
            mask_ratio = random.uniform(0.3, 0.7)
            mask = torch.rand(num_frames, device=x.device) > mask_ratio
            # if mask is all False, set the last frame to True
            if not mask.any():
                mask[-1] = 1

        return mask

    def get_masks(self, x):
        """
        Sample an independent mask for each of the ``len(x)`` batch items.

        Returns:
            Bool tensor of shape ``(len(x), x.shape[2])``. An empty batch
            yields a ``(0, x.shape[2])`` tensor instead of failing in
            ``torch.stack``.
        """
        if len(x) == 0:
            return torch.empty((0, x.shape[2]), dtype=torch.bool, device=x.device)
        return torch.stack([self.get_mask(x) for _ in range(len(x))], dim=0)