mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
35 lines
1 KiB
Python
35 lines
1 KiB
Python
from collections import OrderedDict
|
|
|
|
import torch
|
|
|
|
|
|
@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
) -> None:
    """
    Step the EMA model towards the current model.

    Every trainable parameter of ``model`` (except one named exactly
    ``"pos_embed"``) is blended in place into the same-named parameter of
    ``ema_model``::

        ema = decay * ema + (1 - decay) * source

    Args:
        ema_model: Module holding the running average; updated in place.
        model: Module whose current parameter values are folded in.
        optimizer: Only consulted when ``sharded`` is true and a parameter is
            not float32. It must then expose
            ``_param_store.working_to_master_param``, a mapping keyed by
            ``id(param)``.
        decay: Weight kept from the previous EMA value.
        sharded: When true, non-float32 parameters are read from the
            optimizer's mapping instead of from ``model`` directly.

    Raises:
        ValueError: If ``sharded`` is true, a trainable parameter is not
            float32, and ``optimizer`` is ``None``.
        KeyError: If ``ema_model`` has no parameter with a matching name, or
            the optimizer mapping has no entry for a parameter.
    """
    ema_params = OrderedDict(ema_model.named_parameters())

    for name, param in model.named_parameters():
        # "pos_embed" and frozen parameters are never averaged.
        if name == "pos_embed" or not param.requires_grad:
            continue

        if sharded and param.dtype != torch.float32:
            # Fail with a clear message instead of an AttributeError on None.
            if optimizer is None:
                raise ValueError(
                    f"update_ema: parameter '{name}' has dtype {param.dtype}; "
                    "an optimizer is required when sharded=True and parameters are not float32"
                )
            # NOTE(review): presumably the optimizer's float32 master copy of
            # this working parameter -- confirm against the optimizer in use.
            source = optimizer._param_store.working_to_master_param[id(param)].data
        else:
            source = param.data

        ema_params[name].mul_(decay).add_(source, alpha=1 - decay)
|
|
|
|
|
|
|