mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-27 12:01:32 +02:00
debug
This commit is contained in:
parent
9091419b28
commit
e3584b4e43
|
|
@ -56,14 +56,14 @@ class ResBlock(nn.Module):
|
|||
device, dtype = x.device, x.dtype
|
||||
input_dim = x.shape[1]
|
||||
residual = x
|
||||
x = self.norm1(x)
|
||||
x = self.norm1.to(device,dtype)(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.conv1.to(device,dtype)(x)
|
||||
x = self.norm2.to(device, dtype)(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv2.to(device, dtype)(x)
|
||||
if input_dim != self.filters: # TODO: what does it do here
|
||||
residual = self.conv3(residual)
|
||||
residual = self.conv3.to(device, dtype)(residual)
|
||||
return x + residual
|
||||
|
||||
def _get_selected_flags(total_len: int, select_len: int, suffix: bool):
|
||||
|
|
|
|||
|
|
@ -32,11 +32,11 @@ def main():
|
|||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
enable_sequence_parallelism = True
|
||||
else:
|
||||
enable_sequence_parallelism = False
|
||||
# if coordinator.world_size > 1:
|
||||
# set_sequence_parallel_group(dist.group.WORLD)
|
||||
# enable_sequence_parallelism = True
|
||||
# else:
|
||||
# enable_sequence_parallelism = False
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables
|
||||
|
|
|
|||
Loading…
Reference in a new issue