Open-Sora/opensora/schedulers/dpms/__init__.py

61 lines
1.6 KiB
Python
Raw Normal View History

2024-03-15 15:06:36 +01:00
from functools import partial
import torch
from opensora.registry import SCHEDULERS
from .dpm_solver import DPMS
@SCHEDULERS.register_module("dpm-solver")
class DPM_SOLVER:
    """Sampler that drives the multistep DPM-Solver (``DPMS``) from text prompts.

    Registered in ``SCHEDULERS`` under the name ``"dpm-solver"``.
    """

    def __init__(self, num_sampling_steps=None, cfg_scale=4.0):
        """Store the solver configuration.

        Args:
            num_sampling_steps: forwarded unchanged as ``steps`` to ``DPMS.sample``.
                NOTE(review): the default ``None`` is passed straight through;
                confirm ``DPMS.sample`` accepts it.
            cfg_scale: guidance scale forwarded to the ``DPMS`` constructor.
        """
        self.cfg_scale = cfg_scale
        self.num_sampling_steps = num_sampling_steps

    def sample(
        self,
        model,
        text_encoder,
        z,
        prompts,
        device,
        additional_args=None,
        mask=None,
        progress=True,
    ):
        """Run DPM-Solver starting from ``z``, conditioned on ``prompts``.

        Args:
            model: object whose ``forward(x, timestep, y, **kwargs)`` is wrapped
                by ``forward_with_dpmsolver`` and handed to ``DPMS``.
            text_encoder: must provide ``encode(prompts)`` returning a dict that
                contains key ``"y"`` (the conditional embedding; any remaining
                entries become model kwargs) and ``null(n)`` giving the
                unconditional embedding for ``n`` prompts.
            z: starting input passed as the first argument to ``DPMS.sample``
                (presumably initial noise latents; confirm with the caller).
            prompts: sized collection of prompts; ``len(prompts)`` sizes the
                null embedding.
            device: accepted for interface compatibility; unused here.
            additional_args: optional dict merged into the model kwargs
                (its entries override those returned by ``encode``).
            mask: not supported; if given, a warning is printed and it is ignored.
            progress: forwarded to ``DPMS.sample``.

        Returns:
            Whatever ``DPMS.sample`` returns.
        """
        if mask is not None:
            print("[WARNING] mask is not supported in dpm-solver, it will be ignored")

        batch_size = len(prompts)
        encoded = text_encoder.encode(prompts)
        # "y" is passed to DPMS separately as the condition; the rest are model kwargs.
        cond_embedding = encoded.pop("y")
        uncond_embedding = text_encoder.null(batch_size)
        if additional_args is not None:
            encoded.update(additional_args)

        solver = DPMS(
            partial(forward_with_dpmsolver, model),
            condition=cond_embedding,
            uncondition=uncond_embedding,
            cfg_scale=self.cfg_scale,
            model_kwargs=encoded,
        )
        return solver.sample(
            z,
            steps=self.num_sampling_steps,
            order=2,
            skip_type="time_uniform",
            method="multistep",
            progress=progress,
        )
def forward_with_dpmsolver(self, x, timestep, y, **kwargs):
    """Call ``self.forward`` and keep only the first half of dim 1.

    DPM-Solver does not need the model's variance prediction. The model
    output is assumed to stack the prediction and the variance terms along
    dim 1 (GLIDE-style layout, see link below), so only the first
    ``chunk(2, dim=1)`` piece is returned.

    ``self`` is the model object, bound positionally via
    ``functools.partial`` at the call site.
    """
    # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
    prediction = self.forward(x, timestep, y, **kwargs)
    # Index rather than unpack: chunk() may yield a single piece when dim 1 has size 1.
    halves = prediction.chunk(2, dim=1)
    return halves[0]