This commit is contained in:
Shen-Chenhui 2024-04-05 16:20:39 +08:00
parent 9091419b28
commit e3584b4e43
2 changed files with 10 additions and 10 deletions

View file

@ -56,14 +56,14 @@ class ResBlock(nn.Module):
device, dtype = x.device, x.dtype
input_dim = x.shape[1]
residual = x
x = self.norm1(x)
x = self.norm1.to(device,dtype)(x)
x = self.activate(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.conv1.to(device,dtype)(x)
x = self.norm2.to(device, dtype)(x)
x = self.activate(x)
x = self.conv2(x)
x = self.conv2.to(device, dtype)(x)
if input_dim != self.filters: # TODO: what does it do here
residual = self.conv3(residual)
residual = self.conv3.to(device, dtype)(residual)
return x + residual
def _get_selected_flags(total_len: int, select_len: int, suffix: bool):

View file

@ -32,11 +32,11 @@ def main():
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
else:
enable_sequence_parallelism = False
# if coordinator.world_size > 1:
# set_sequence_parallel_group(dist.group.WORLD)
# enable_sequence_parallelism = True
# else:
# enable_sequence_parallelism = False
# ======================================================
# 2. runtime variables