fix cfg_channel (#217)

This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-03-25 21:15:16 +08:00 committed by GitHub
parent 1a913cd21b
commit e826311de4

View file

@ -22,6 +22,7 @@ class IDDPM(SpacedDiffusion):
rescale_learned_sigmas=False,
diffusion_steps=1000,
cfg_scale=4.0,
cfg_channel=None,
):
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
if use_kl:
@ -49,6 +50,7 @@ class IDDPM(SpacedDiffusion):
)
self.cfg_scale = cfg_scale
self.cfg_channel = cfg_channel
def sample(
self,
@ -68,7 +70,7 @@ class IDDPM(SpacedDiffusion):
if additional_args is not None:
model_args.update(additional_args)
forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale)
forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale, cfg_channel=self.cfg_channel)
samples = self.p_sample_loop(
forward,
z.shape,