From e3584b4e4351595dbc91c81d1231e340af6713b7 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Fri, 5 Apr 2024 16:20:39 +0800 Subject: [PATCH] debug --- opensora/models/vae/vae_3d.py | 10 +++++----- scripts/inference-vae.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index e2aebbf..38698a4 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -56,14 +56,14 @@ class ResBlock(nn.Module): device, dtype = x.device, x.dtype input_dim = x.shape[1] residual = x - x = self.norm1(x) + x = self.norm1.to(device,dtype)(x) x = self.activate(x) - x = self.conv1(x) - x = self.norm2(x) + x = self.conv1.to(device,dtype)(x) + x = self.norm2.to(device, dtype)(x) x = self.activate(x) - x = self.conv2(x) + x = self.conv2.to(device, dtype)(x) if input_dim != self.filters: # TODO: what does it do here - residual = self.conv3(residual) + residual = self.conv3.to(device, dtype)(residual) return x + residual def _get_selected_flags(total_len: int, select_len: int, suffix: bool): diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py index 82c2c6a..0da2165 100644 --- a/scripts/inference-vae.py +++ b/scripts/inference-vae.py @@ -32,11 +32,11 @@ def main(): colossalai.launch_from_torch({}) coordinator = DistCoordinator() - if coordinator.world_size > 1: - set_sequence_parallel_group(dist.group.WORLD) - enable_sequence_parallelism = True - else: - enable_sequence_parallelism = False + # if coordinator.world_size > 1: + # set_sequence_parallel_group(dist.group.WORLD) + # enable_sequence_parallelism = True + # else: + # enable_sequence_parallelism = False # ====================================================== # 2. runtime variables