add optical flow

This commit is contained in:
xyupeng 2024-04-01 20:00:27 +08:00
parent 5ade5e5984
commit 833c70f2ae
14 changed files with 2123 additions and 0 deletions

View file

@ -0,0 +1 @@
`python tools/optical_flow/inference.py --meta_path ./data/Panda-70M/processed/meta/test_intact_cut_head-100.csv`

View file

View file

@ -0,0 +1,145 @@
import os
# os.chdir('../..')
print(f'Current working directory: {os.getcwd()}')
import argparse
import av
import decord
import numpy as np
import pandas as pd
from einops import rearrange
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.transforms.functional import pil_to_tensor
from unimatch import UniMatch
def extract_frames_av(video_path, frame_inds=(0, 10, 20, 30)):
    """Decode selected frames of a video with PyAV.

    Args:
        video_path: Path of the video file to open.
        frame_inds: Frame indices to extract. When the stream reports its frame
            count, indices past the end are clamped to the last frame.

    Returns:
        A list with one image (the result of ``VideoFrame.to_image()``) per
        entry of ``frame_inds``.
    """
    frames = []
    # The ``with`` block guarantees the container (file handle + decoder) is
    # released even if decoding raises; the original never closed it.
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        total_frames = stream.frames
        for idx in frame_inds:
            # PyAV reports ``frames == 0`` when the count is unknown; clamping in
            # that case would turn every index into -1 (a negative timestamp).
            if total_frames > 0 and idx >= total_frames:
                idx = total_frames - 1
            target_timestamp = int(idx * av.time_base / stream.average_rate)
            container.seek(target_timestamp)
            # NOTE(review): seek() lands on a nearby keyframe, so the decoded
            # frame is only approximately frame ``idx`` - confirm this is acceptable.
            frame = next(container.decode(video=0)).to_image()
            frames.append(frame)
    return frames
def extract_frames(video_path, frame_inds=[0, 10, 20, 30]):
    """Read the requested frames of a video with decord.

    Args:
        video_path: Path of the video file to open.
        frame_inds: Frame indices to fetch; any index at or beyond the video
            length is replaced by the index of the last frame.

    Returns:
        The batch returned by ``VideoReader.get_batch`` converted with
        ``asnumpy()`` (laid out as [N, H, W, C]).
    """
    reader = decord.VideoReader(video_path, num_threads=1)
    num_frames = len(reader)

    indices = np.array(frame_inds).astype(np.int32)
    past_end = indices >= num_frames
    indices[past_end] = num_frames - 1

    batch = reader.get_batch(indices)
    return batch.asnumpy()  # [N, H, W, C]
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset that yields a few resized frames for every row of a meta CSV.

    The CSV is read with pandas and must provide a ``path`` column holding the
    video file locations.
    """

    def __init__(self, meta_path, frame_inds=[0, 10, 20, 30]):
        """Load the meta table and remember which frame indices to sample."""
        self.meta_path = meta_path
        self.meta = pd.read_csv(meta_path)
        self.frame_inds = frame_inds

    def __len__(self):
        """Number of rows (videos) in the meta table."""
        return len(self.meta)

    def __getitem__(self, index):
        """Return the sampled frames of video ``index`` as a float tensor.

        The frames are moved to channel-first layout, portrait clips (H > W) get
        their two spatial axes swapped, and everything is bilinearly resized to
        320 x 576.
        """
        video_path = self.meta.iloc[index]["path"]
        frames = extract_frames(video_path, frame_inds=self.frame_inds)

        clip = rearrange(torch.from_numpy(frames).float(), 'N H W C -> N C H W')
        height, width = clip.shape[-2:]
        if height > width:
            # swap the spatial axes so the longer side becomes the width
            clip = rearrange(clip, 'N C H W -> N C W H')

        return F.interpolate(clip, size=(320, 576), mode='bilinear', align_corners=True)
def main():
    """Score every video of a meta CSV by its mean optical-flow magnitude.

    Reads the CSV given by ``--meta_path``, runs UniMatch on consecutive pairs
    of the sampled frames of each video, stores the mean absolute flow value
    in a new ``flow`` column and writes the table to ``<meta>_flow<ext>``.
    """
    parser = argparse.ArgumentParser()
    # Required: without it ``os.path.splitext(None)`` below fails with an opaque TypeError.
    parser.add_argument("--meta_path", type=str, required=True, help="Path to the input CSV file")
    parser.add_argument("--bs", type=int, default=4, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
    parser.add_argument(
        "--ckpt_path", type=str, default=None,
        help="Path to a pretrained UniMatch checkpoint (a dict holding the state dict under 'model')",
    )
    args = parser.parse_args()

    meta_path = args.meta_path
    wo_ext, ext = os.path.splitext(meta_path)
    out_path = f'{wo_ext}_flow{ext}'

    # build model
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = UniMatch(
        feature_channels=128,
        num_scales=2,
        upsample_factor=4,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        reg_refine=True,
        task='flow',
    )
    if args.ckpt_path is not None:
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        ckpt = torch.load(args.ckpt_path, map_location=device)
        model.load_state_dict(ckpt['model'])
    else:
        # Without pretrained weights the scores come from a randomly initialised network.
        print('Warning: no --ckpt_path given; UniMatch runs with randomly initialised weights.')
    model = model.to(device)
    model.eval()  # inference only: make sure no layer runs in training mode
    model = torch.nn.DataParallel(model)

    # build dataset
    dataset = VideoTextDataset(meta_path=meta_path, frame_inds=[0, 10, 20, 30])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.bs,
        num_workers=args.num_workers,
        shuffle=False,  # row order must match the meta table for the write-back below
    )

    # compute optical flow scores
    dataset.meta["flow"] = np.nan
    index = 0
    for images in tqdm(dataloader):
        images = images.to(device)
        B = images.shape[0]

        # pair every sampled frame with its successor
        batch_0 = rearrange(images[:, :-1], 'B N C H W -> (B N) C H W').contiguous()
        batch_1 = rearrange(images[:, 1:], 'B N C H W -> (B N) C H W').contiguous()

        with torch.no_grad():
            res = model(
                batch_0, batch_1,
                attn_type='swin',
                attn_splits_list=[2, 8],
                corr_radius_list=[-1, 4],
                prop_radius_list=[-1, 1],
                num_reg_refine=6,
                task='flow',
                pred_bidir_flow=False,
            )
            flow_maps = res['flow_preds'][-1].cpu()  # [B * (N-1), 2, H, W]
            flow_maps = rearrange(flow_maps, '(B N) C H W -> B N H W C', B=B)
            flow_scores = flow_maps.abs().mean(dim=[1, 2, 3, 4])
            flow_scores_np = flow_scores.numpy()

        # ``.loc`` slicing is label based and inclusive, hence the ``- 1``
        dataset.meta.loc[index: index + B - 1, "flow"] = flow_scores_np
        index += B

    dataset.meta.to_csv(out_path, index=False)
    print(f"New meta with optical flow scores saved to '{out_path}'.")


if __name__ == "__main__":
    main()

View file

@ -0,0 +1 @@
from .unimatch import UniMatch

View file

@ -0,0 +1,253 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d
def single_head_full_attention(q, k, v):
    """Single-head scaled dot-product attention over the whole sequence.

    ``q``, ``k`` and ``v`` must be rank-3 tensors laid out as [B, L, C]. The
    logits are scaled by the square root of the channel size of ``q``.
    """
    assert q.dim() == k.dim() == v.dim() == 3

    channels = q.size(2)
    logits = torch.matmul(q, k.transpose(1, 2)) / (channels ** .5)  # [B, L, L]
    weights = logits.softmax(dim=2)  # [B, L, L]
    return torch.matmul(weights, v)  # [B, L, C]
def single_head_full_attention_1d(q, k, v,
                                  h=None,
                                  w=None,
                                  ):
    """Single-head attention restricted to each image row.

    ``q``, ``k`` and ``v`` are [B, L, C] with ``L == h * w``. They are reshaped
    to [B, H, W, C] and every position attends only to the W positions of its
    own row. Returns a [B, H*W, C] tensor.
    """
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    q_rows, k_rows, v_rows = (t.view(b, h, w, c) for t in (q, k, v))  # [B, H, W, C]

    logits = torch.matmul(q_rows, k_rows.transpose(2, 3)) / (c ** 0.5)  # [B, H, W, W]
    weights = torch.softmax(logits, dim=-1)
    return torch.matmul(weights, v_rows).view(b, -1, c)  # [B, H*W, C]
def single_head_split_window_attention(q, k, v,
                                       num_splits=1,
                                       with_shift=False,
                                       h=None,
                                       w=None,
                                       attn_mask=None,
                                       ):
    """Single-head attention computed inside ``num_splits x num_splits`` windows.

    ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

    ``q``, ``k`` and ``v`` are [B, L, C] with ``L == h * w``. With
    ``with_shift`` the maps are rolled by half a window before splitting and
    rolled back afterwards; ``attn_mask`` (required in that case) is added to
    the attention logits. Returns a [B, H*W, C] tensor.
    """
    assert q.dim() == k.dim() == v.dim() == 3
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    num_windows = b * num_splits * num_splits
    win_h = h // num_splits
    win_w = w // num_splits

    q, k, v = (t.view(b, h, w, c) for t in (q, k, v))  # [B, H, W, C]

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_h = win_h // 2
        shift_w = win_w // 2
        q, k, v = (torch.roll(t, shifts=(-shift_h, -shift_w), dims=(1, 2)) for t in (q, k, v))

    # [B*K*K, H/K, W/K, C] each
    q, k, v = (split_feature(t, num_splits=num_splits, channel_last=True) for t in (q, k, v))

    q_win = q.view(num_windows, -1, c)
    k_win = k.view(num_windows, -1, c)
    logits = torch.matmul(q_win, k_win.permute(0, 2, 1)) / (c ** 0.5)  # [B*K*K, H/K*W/K, H/K*W/K]
    if with_shift:
        logits += attn_mask.repeat(b, 1, 1)

    weights = torch.softmax(logits, dim=-1)
    out = torch.matmul(weights, v.view(num_windows, -1, c))  # [B*K*K, H/K*W/K, C]
    out = merge_splits(out.view(num_windows, win_h, win_w, c),
                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]

    if with_shift:
        # undo the roll applied before the split
        out = torch.roll(out, shifts=(shift_h, shift_w), dims=(1, 2))

    return out.view(b, -1, c)
def single_head_split_window_attention_1d(q, k, v,
                                          relative_position_bias=None,
                                          num_splits=1,
                                          with_shift=False,
                                          h=None,
                                          w=None,
                                          attn_mask=None,
                                          ):
    """Single-head attention inside ``num_splits`` horizontal windows per row.

    ``q``, ``k`` and ``v`` are [B, L, C] with ``L == h * w``; each row is split
    along its width. With ``with_shift`` the rows are rolled by half a window
    first and ``attn_mask`` ([K, W/K, W/K], required then) is added to the
    logits. ``relative_position_bias`` is accepted but not used. Returns a
    [B, H*W, C] tensor.
    """
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    num_windows = b * num_splits * h
    win_w = w // num_splits

    q, k, v = (t.view(b * h, w, c) for t in (q, k, v))  # [B*H, W, C]

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_w = win_w // 2
        q, k, v = (torch.roll(t, shifts=-shift_w, dims=1) for t in (q, k, v))

    # [B*H*K, W/K, C] each
    q, k, v = (split_feature_1d(t, num_splits=num_splits) for t in (q, k, v))

    q_win = q.view(num_windows, -1, c)
    k_win = k.view(num_windows, -1, c)
    logits = torch.matmul(q_win, k_win.permute(0, 2, 1)) / (c ** 0.5)  # [B*H*K, W/K, W/K]
    if with_shift:
        # attn_mask: [K, W/K, W/K]
        logits += attn_mask.repeat(b * h, 1, 1)  # [B*H*K, W/K, W/K]

    weights = torch.softmax(logits, dim=-1)
    out = torch.matmul(weights, v.view(num_windows, -1, c))  # [B*H*K, W/K, C]
    out = merge_splits_1d(out, h, num_splits=num_splits)  # [B, H, W, C]

    if with_shift:
        # undo the roll applied before the split
        out = torch.roll(out, shifts=shift_w, dims=2)

    return out.view(b, -1, c)
class SelfAttnPropagation(nn.Module):
    """Propagate a flow field with self-attention computed on image features.

    Query and key are derived from ``feature0``; the value is the flow.
    """

    def __init__(self, in_channels,
                 **kwargs,
                 ):
        super().__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        # Xavier-initialise every weight matrix (biases are left untouched)
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, feature0, flow,
                local_window_attn=False,
                local_window_radius=1,
                **kwargs,
                ):
        """Return the propagated flow, [B, 2, H, W].

        ``feature0`` is [B, C, H, W] and ``flow`` is [B, 2, H, W]. With
        ``local_window_attn`` the attention is limited to a window of radius
        ``local_window_radius`` around every pixel.
        """
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
                                                  local_window_radius=local_window_radius)

        b, c, h, w = feature0.size()
        tokens = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # NOTE: the key is deliberately computed from the already projected
        # query (``k_proj(q_proj(x))``) rather than from the raw tokens. The
        # original author kept this so the released models need no re-training;
        # both projections are linear, so they can be merged.
        query = self.q_proj(tokens)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]
        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        logits = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
        attn = torch.softmax(logits, dim=-1)

        propagated = torch.matmul(attn, value)  # [B, H*W, 2]
        return propagated.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

    def forward_local_window_attn(self, feature0, flow,
                                  local_window_radius=1,
                                  ):
        """Propagate ``flow`` using attention over a (2R+1) x (2R+1) neighbourhood.

        ``flow`` may have 2 channels (flow) or 1 channel (disparity or depth).
        Neighbourhoods are gathered with ``F.unfold`` and zero padding.
        """
        assert flow.size(1) in (1, 2)  # flow or disparity or depth
        assert local_window_radius > 0

        b, c, h, w = feature0.size()
        value_channel = flow.size(1)
        kernel_size = 2 * local_window_radius + 1
        window_area = kernel_size ** 2

        tokens = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
        query = self.q_proj(tokens).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]
        key_map = self.k_proj(tokens).permute(0, 2, 1).reshape(b, c, h, w)

        key_windows = F.unfold(key_map, kernel_size=kernel_size,
                               padding=local_window_radius)  # [B, C*(2R+1)^2, H*W]
        key_windows = key_windows.view(b, c, window_area, h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, window_area)  # [B*H*W, C, (2R+1)^2]

        flow_windows = F.unfold(flow, kernel_size=kernel_size,
                                padding=local_window_radius)  # [B, 2*(2R+1)^2, H*W]
        flow_windows = flow_windows.view(b, value_channel, window_area, h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, window_area, value_channel)  # [B*H*W, (2R+1)^2, 2]

        logits = torch.matmul(query, key_windows) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]
        attn = torch.softmax(logits, dim=-1)

        propagated = torch.matmul(attn, flow_windows).view(b, h, w, value_channel)
        return propagated.permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]

View file

@ -0,0 +1,117 @@
import torch.nn as nn
from .trident_conv import MultiScaleTridentConv
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with normalisation and ReLU plus a skip connection.

    When the stride is not 1 or the channel count changes, the skip path is a
    1x1 convolution followed by a normalisation layer; otherwise the input is
    added unchanged.
    """

    def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
                 ):
        super().__init__()

        conv_kwargs = dict(kernel_size=3, dilation=dilation, padding=dilation, bias=False)
        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv_kwargs)
        self.conv2 = nn.Conv2d(planes, planes, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)

        changes_shape = stride != 1 or in_planes != planes
        if changes_shape:
            # norm3 is registered both as an attribute and as the second entry
            # of ``downsample``; this layout is kept on purpose so existing
            # state dict keys keep matching.
            self.norm3 = norm_layer(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
        else:
            self.downsample = None

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum of both paths."""
        residual = self.relu(self.norm1(self.conv1(x)))
        residual = self.relu(self.norm2(self.conv2(residual)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + residual)
class CNNEncoder(nn.Module):
    """Convolutional feature extractor built from residual stages.

    With ``num_output_scales == 1`` the third stage uses stride 2 and a single
    feature map is returned. With more scales the third stage keeps stride 1
    and a ``MultiScaleTridentConv`` produces one feature map per scale.
    """

    def __init__(self, output_dim=128,
                 norm_layer=nn.InstanceNorm2d,
                 num_output_scales=1,
                 **kwargs,
                 ):
        super().__init__()
        self.num_branch = num_output_scales

        stem_dim, mid_dim, top_dim = 64, 96, 128

        self.conv1 = nn.Conv2d(3, stem_dim, kernel_size=7, stride=2, padding=3, bias=False)  # 1/2
        self.norm1 = norm_layer(stem_dim)
        self.relu1 = nn.ReLU(inplace=True)

        # ``in_planes`` tracks the channel count entering the next stage
        self.in_planes = stem_dim
        self.layer1 = self._make_layer(stem_dim, stride=1, norm_layer=norm_layer)  # 1/2
        self.layer2 = self._make_layer(mid_dim, stride=2, norm_layer=norm_layer)  # 1/4

        # highest resolution 1/4 or 1/8
        last_stride = 2 if num_output_scales == 1 else 1
        self.layer3 = self._make_layer(top_dim, stride=last_stride,
                                       norm_layer=norm_layer,
                                       )  # 1/4 or 1/8

        self.conv2 = nn.Conv2d(top_dim, output_dim, 1, 1, 0)

        if self.num_branch > 1:
            branch_strides = {2: (1, 2), 3: (1, 2, 4), 4: (1, 2, 4, 8)}
            if self.num_branch not in branch_strides:
                raise ValueError

            self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
                                                      kernel_size=3,
                                                      strides=branch_strides[self.num_branch],
                                                      paddings=1,
                                                      num_branch=self.num_branch,
                                                      )

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                # affine parameters may be absent, e.g. for non-affine norms
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        """Build a stage of two residual blocks; only the first one is strided."""
        blocks = [
            ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation),
            ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation),
        ]
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return a list of feature maps (one entry per output scale)."""
        x = self.relu1(self.norm1(self.conv1(x)))

        # layer1: 1/2, layer2: 1/4, layer3: 1/8 or 1/4
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        x = self.conv2(x)

        if self.num_branch > 1:
            return self.trident_conv([x] * self.num_branch)  # high to low res
        return [x]

View file

@ -0,0 +1,195 @@
import torch
import torch.nn.functional as F
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Build a batch of pixel-coordinate grids.

    Args:
        b: Batch size.
        h: Grid height.
        w: Grid width.
        homogeneous: If true, append a third channel filled with ones.
        device: Device on which the grid is created (default device if None).

    Returns:
        A float tensor of shape [B, 2, H, W] (or [B, 3, H, W] when
        ``homogeneous``). Channel 0 holds the x (column) index, channel 1 the
        y (row) index.
    """
    # Broadcasting replaces ``torch.meshgrid`` here: calling meshgrid without
    # the ``indexing`` argument is deprecated and warns on every call. Creating
    # the ranges on ``device`` also avoids a CPU -> device copy.
    x = torch.arange(w, device=device).view(1, w).expand(h, w)  # [H, W]
    y = torch.arange(h, device=device).view(h, 1).expand(h, w)  # [H, W]

    stacks = [x, y]
    if homogeneous:
        stacks.append(torch.ones(h, w, dtype=x.dtype, device=x.device))  # [H, W]

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]
    return grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Build a grid of evenly spaced (x, y) offsets.

    Args:
        h_min, h_max: First and last value along the vertical axis.
        w_min, w_max: First and last value along the horizontal axis.
        len_h, len_w: Number of samples along each axis.
        device: Device on which the grid is created; must not be None.

    Returns:
        A float tensor of shape [len_h, len_w, 2] whose last axis is (x, y).
    """
    assert device is not None

    # Broadcasting replaces ``torch.meshgrid`` + transpose: calling meshgrid
    # without the ``indexing`` argument is deprecated and warns on every call.
    x = torch.linspace(w_min, w_max, len_w, device=device).view(1, len_w).expand(len_h, len_w)
    y = torch.linspace(h_min, h_max, len_h, device=device).view(len_h, 1).expand(len_h, len_w)
    return torch.stack((x, y), -1).float()  # [H, W, 2]
def normalize_coords(coords, h, w):
    """Map pixel coordinates of an ``h`` x ``w`` image to the range [-1, 1].

    ``coords`` is [B, H, W, 2] with (x, y) in its last axis; pixel 0 maps to -1
    and pixel ``w - 1`` (resp. ``h - 1``) maps to 1.
    """
    half_extent = torch.tensor([(w - 1) / 2., (h - 1) / 2.],
                               dtype=torch.float32, device=coords.device)
    centred = coords - half_extent
    return centred / half_extent  # [-1, 1]
def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
    """Sample ``img`` at pixel-space coordinates with ``F.grid_sample``.

    Args:
        img: [B, C, H, W] tensor to sample from.
        sample_coords: [B, 2, H, W] pixel coordinates (x in channel 0, y in
            channel 1). A tensor whose second dimension is not 2 is treated as
            [B, H, W, 2] and permuted.
        mode: Interpolation mode passed to ``grid_sample``.
        padding_mode: Padding mode passed to ``grid_sample``.
        return_mask: Also return a mask of the in-bounds samples.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` with a boolean [B, H, W]
        mask when ``return_mask`` is set.
    """
    if sample_coords.size(1) != 2:  # [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    _, _, h, w = sample_coords.shape

    # Normalize to [-1, 1]
    x_norm = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_norm = 2 * sample_coords[:, 1] / (h - 1) - 1
    sampling_grid = torch.stack([x_norm, y_norm], dim=-1)  # [B, H, W, 2]

    sampled = F.grid_sample(img, sampling_grid, mode=mode, padding_mode=padding_mode,
                            align_corners=True)
    if not return_mask:
        return sampled

    inside = (x_norm >= -1) & (y_norm >= -1) & (x_norm <= 1) & (y_norm <= 1)  # [B, H, W]
    return sampled, inside
def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
    """Warp ``feature`` by sampling it at each pixel position plus ``flow``.

    ``feature`` is [B, C, H, W] and ``flow`` must have 2 channels. When
    ``mask`` is set, the in-bounds mask of ``bilinear_sample`` is returned as a
    second value.
    """
    b, _, h, w = feature.size()
    assert flow.size(1) == 2

    target_coords = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]
    return bilinear_sample(feature, target_coords,
                           padding_mode=padding_mode, return_mask=mask)
def forward_backward_consistency_check(fwd_flow, bwd_flow,
                                       alpha=0.01,
                                       beta=0.5
                                       ):
    """Flag pixels whose forward and backward flows disagree.

    ``fwd_flow`` and ``bwd_flow`` are [B, 2, H, W]. A pixel is flagged when the
    norm of (flow + opposite flow warped to it) exceeds
    ``alpha * flow_mag + beta``. The alpha and beta values follow UnFlow
    (https://arxiv.org/abs/1711.07837).

    Returns:
        ``(fwd_occ, bwd_occ)``, two float masks of shape [B, H, W].
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    magnitude = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    bwd_at_fwd = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    fwd_at_bwd = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    limit = alpha * magnitude + beta

    fwd_occ = (torch.norm(fwd_flow + bwd_at_fwd, dim=1) > limit).float()  # [B, H, W]
    bwd_occ = (torch.norm(bwd_flow + fwd_at_bwd, dim=1) > limit).float()
    return fwd_occ, bwd_occ
def back_project(depth, intrinsics):
    """Back project 2D pixel coordinates to 3D points.

    Args:
        depth: [B, H, W] depth map.
        intrinsics: [B, 3, 3] camera intrinsics.

    Returns:
        [B, 3, H, W] points: inverse intrinsics applied to the homogeneous
        pixel grid, scaled by ``depth``.
    """
    b, h, w = depth.shape

    pixels = coords_grid(b, h, w, homogeneous=True, device=depth.device)  # [B, 3, H, W]
    inv_k = torch.inverse(intrinsics)  # [B, 3, 3]

    rays = inv_k.bmm(pixels.view(b, 3, -1)).view(b, 3, h, w)
    return rays * depth.unsqueeze(1)  # [B, 3, H, W]
def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
    """Transform 3D points from the reference camera to the target camera.

    Args:
        points_ref: [B, 3, H, W] points in the reference frame.
        extrinsics_ref: [B, 4, 4] reference extrinsics.
        extrinsics_tgt: [B, 4, 4] target extrinsics.
        extrinsics_rel: [B, 4, 4] relative pose; when None it is computed as
            ``extrinsics_tgt @ inverse(extrinsics_ref)``.

    Returns:
        [B, 3, H, W] points in the target frame.
    """
    b, _, h, w = points_ref.shape

    if extrinsics_rel is None:
        extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref))  # [B, 4, 4]

    rotation = extrinsics_rel[:, :3, :3]
    translation = extrinsics_rel[:, :3, -1:]

    moved = torch.bmm(rotation, points_ref.view(b, 3, -1)) + translation  # [B, 3, H*W]
    return moved.view(b, 3, h, w)  # [B, 3, H, W]
def reproject(points_tgt, intrinsics, return_mask=False):
    """Project 3D points of the target view onto its image plane.

    Args:
        points_tgt: [B, 3, H, W] points.
        intrinsics: [B, 3, 3] camera intrinsics.
        return_mask: Also return which projections fall inside the image.

    Returns:
        [B, 2, H, W] pixel coordinates in image scale, or
        ``(pixel_coords, mask)`` with a boolean [B, H, W] mask.
    """
    b, _, h, w = points_tgt.shape

    projected = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w)  # [B, 3, H, W]

    # clamp the depth to avoid dividing by (near) zero
    z = projected[:, 2].clamp(min=1e-3)
    u = projected[:, 0] / z
    v = projected[:, 1] / z

    pixel_coords = torch.stack([u, v], dim=1).view(b, 2, h, w)  # [B, 2, H, W] in image scale
    if not return_mask:
        return pixel_coords

    # valid mask in pixel space
    in_view = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & (
        pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1))  # [B, H, W]
    return pixel_coords, in_view
def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
                     return_mask=False):
    """Compute reprojection sample coordinates.

    Chains ``back_project`` -> ``camera_transform`` -> ``reproject``.

    Returns:
        [B, 2, H, W] coordinates in image scale, or ``(coords, mask)`` when
        ``return_mask`` is set.
    """
    points_ref = back_project(depth_ref, intrinsics)  # [B, 3, H, W]
    points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt,
                                  extrinsics_rel=extrinsics_rel)

    # ``reproject`` already returns a (coords, mask) pair when a mask is requested
    return reproject(points_tgt, intrinsics, return_mask=return_mask)
def compute_flow_with_depth_pose(depth_ref, intrinsics,
                                 extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
                                 return_mask=False):
    """Compute the flow implied by a depth map and a camera pose.

    The flow is the difference between the reprojected coordinates and the
    original pixel grid.

    Returns:
        [B, 2, H, W] rigid flow, or ``(rigid_flow, mask)`` when ``return_mask``
        is set.
    """
    b, h, w = depth_ref.shape
    coords_init = coords_grid(b, h, w, device=depth_ref.device)  # [B, 2, H, W]

    reprojected = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
                                   extrinsics_rel=extrinsics_rel,
                                   return_mask=return_mask)

    if return_mask:
        reproj_coords, mask = reprojected  # [B, 2, H, W], [B, H, W]
        return reproj_coords - coords_init, mask

    return reprojected - coords_init

View file

@ -0,0 +1,279 @@
import torch
import torch.nn.functional as F
from .geometry import coords_grid, generate_window_grid, normalize_coords
def global_correlation_softmax(feature0, feature1,
                               pred_bidir_flow=False,
                               ):
    """Estimate flow from a softmax over the all-pairs feature correlation.

    ``feature0`` and ``feature1`` are [B, C, H, W]. Every pixel of
    ``feature0`` is matched against all pixels of ``feature1``; the flow is
    the probability-weighted target position minus the pixel's own position.
    With ``pred_bidir_flow`` the backward direction is appended along the
    batch axis.

    Returns:
        ``(flow, prob)``: flow is [B, 2, H, W] and prob is [B, H*W, H*W]
        (batch size doubled when bidirectional).
    """
    batch, c, h, w = feature0.shape

    # global correlation
    tokens0 = feature0.view(batch, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    tokens1 = feature1.view(batch, c, -1)  # [B, C, H*W]
    correlation = torch.matmul(tokens0, tokens1) / (c ** 0.5)  # [B, H*W, H*W]

    # flow from softmax
    init_grid = coords_grid(batch, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.view(batch, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    if pred_bidir_flow:
        # the transposed correlation matches feature1 against feature0
        correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0)  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        batch = batch * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
    correspondence = torch.matmul(prob, grid).view(batch, h, w, 2).permute(0, 3, 1, 2)  # [B, 2, H, W]

    # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
    return correspondence - init_grid, prob
def local_correlation_softmax(feature0, feature1, local_radius,
                              padding_mode='zeros',
                              ):
    """Estimate flow from a softmax over a local correlation window.

    Every pixel of ``feature0`` ([B, C, H, W]) is correlated with the
    (2R+1) x (2R+1) neighbourhood around the same position in ``feature1``;
    neighbours outside the image are excluded from the softmax.

    Returns:
        ``(flow, match_prob)``: flow is [B, 2, H, W] and match_prob is
        [B, H*W, (2R+1)^2].
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    side = 2 * local_radius + 1
    offsets = generate_window_grid(-local_radius, local_radius,
                                   -local_radius, local_radius,
                                   side, side, device=feature0.device)  # [2R+1, 2R+1, 2]
    offsets = offsets.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + offsets  # [B, H*W, (2R+1)^2, 2]

    # exclude coords that are out of image space
    xs = sample_coords[:, :, :, 0]
    ys = sample_coords[:, :, :, 1]
    valid_x = (xs >= 0) & (xs < w)  # [B, H*W, (2R+1)^2]
    valid_y = (ys >= 0) & (ys < h)  # [B, H*W, (2R+1)^2]
    valid = valid_x & valid_y  # used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)
    window_feature = F.grid_sample(feature1, sample_coords_norm,
                                   padding_mode=padding_mode, align_corners=True
                                   ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5)  # [B, H*W, (2R+1)^2]

    # mask invalid locations
    corr[~valid] = -1e9

    match_prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    correspondence = torch.matmul(match_prob.unsqueeze(-2), sample_coords).squeeze(-2).view(
        b, h, w, 2).permute(0, 3, 1, 2)  # [B, 2, H, W]

    return correspondence - coords_init, match_prob
def local_correlation_with_flow(feature0, feature1,
                                flow,
                                local_radius,
                                padding_mode='zeros',
                                dilation=1,
                                ):
    """Correlate ``feature0`` with a local window of ``feature1`` displaced by ``flow``.

    The (2R+1) x (2R+1) window (spaced by ``dilation``) is centred on each
    pixel position shifted by ``flow``. ``flow`` may be the float ``0.``
    instead of a tensor, in which case no shift is applied.

    Returns:
        Correlation volume of shape [B, (2R+1)^2, H, W].
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    side = 2 * local_radius + 1
    offsets = generate_window_grid(-local_radius, local_radius,
                                   -local_radius, local_radius,
                                   side, side, device=feature0.device)  # [2R+1, 2R+1, 2]
    offsets = offsets.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + offsets * dilation  # [B, H*W, (2R+1)^2, 2]

    # flow can be zero when using features after transformer
    if isinstance(flow, float):
        assert flow == 0.
    else:
        per_pixel_shift = flow.view(b, 2, -1).permute(0, 2, 1).unsqueeze(-2)
        sample_coords = sample_coords + per_pixel_shift  # [B, H*W, (2R+1)^2, 2]

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)
    window_feature = F.grid_sample(feature1, sample_coords_norm,
                                   padding_mode=padding_mode, align_corners=True
                                   ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5)  # [B, H*W, (2R+1)^2]

    return corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()  # [B, (2R+1)^2, H, W]
def global_correlation_softmax_stereo(feature0, feature1,
                                      ):
    """Estimate disparity from a softmax over the row-wise feature correlation.

    Every pixel of ``feature0`` ([B, C, H, W]) is matched against the pixels
    of the same row of ``feature1`` whose column index is not larger than its
    own, which keeps the disparity non-negative.

    Returns:
        ``(disparity, prob)``: disparity is [B, 1, H, W] at feature
        resolution and prob is [B, H, W, W].
    """
    # global correlation on horizontal direction
    b, c, h, w = feature0.shape

    columns = torch.linspace(0, w - 1, w, device=feature0.device)  # [W]

    rows0 = feature0.permute(0, 2, 3, 1)  # [B, H, W, C]
    rows1 = feature1.permute(0, 2, 1, 3)  # [B, H, C, W]

    correlation = torch.matmul(rows0, rows1) / (c ** 0.5)  # [B, H, W, W]

    # mask subsequent positions to make disparity positive
    upper = torch.triu(torch.ones((w, w)), diagonal=1).type_as(rows0)  # [W, W]
    valid_mask = (upper == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1)  # [B, H, W, W]
    correlation[~valid_mask] = -1e9

    prob = F.softmax(correlation, dim=-1)  # [B, H, W, W]

    correspondence = (columns.view(1, 1, 1, w) * prob).sum(-1)  # [B, H, W]

    # NOTE: unlike flow, disparity is typically positive
    disparity = columns.view(1, 1, w).repeat(b, h, 1) - correspondence  # [B, H, W]

    return disparity.unsqueeze(1), prob  # feature resolution
def local_correlation_softmax_stereo(feature0, feature1, local_radius,
                                     ):
    """Estimate disparity from a softmax over a local horizontal window.

    Every pixel of ``feature0`` ([B, C, H, W]) is correlated with the 2R+1
    positions of ``feature1`` on the same row within ``local_radius`` columns;
    positions outside the image are excluded from the softmax.

    Returns:
        ``(flow_x, match_prob)``: flow_x is the negated horizontal flow,
        [B, 1, H, W], and match_prob is [B, H*W, 2R+1].
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous()  # [B, H*W, 2]

    # a single row of 2R+1 horizontal offsets
    window_len = 2 * local_radius + 1
    offsets = generate_window_grid(0, 0,
                                   -local_radius, local_radius,
                                   1, window_len, device=feature0.device)  # [1, 2R+1, 2]
    offsets = offsets.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1), 2]
    sample_coords = coords.unsqueeze(-2) + offsets  # [B, H*W, (2R+1), 2]

    # exclude coords that are out of image space
    xs = sample_coords[:, :, :, 0]
    ys = sample_coords[:, :, :, 1]
    valid_x = (xs >= 0) & (xs < w)  # [B, H*W, (2R+1)]
    valid_y = (ys >= 0) & (ys < h)  # [B, H*W, (2R+1)]
    valid = valid_x & valid_y  # used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)
    window_feature = F.grid_sample(feature1, sample_coords_norm,
                                   padding_mode='zeros', align_corners=True
                                   ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)]
    feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5)  # [B, H*W, (2R+1)]

    # mask invalid locations
    corr[~valid] = -1e9

    match_prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)]

    correspondence = torch.matmul(match_prob.unsqueeze(-2),
                                  sample_coords).squeeze(-2).view(
        b, h, w, 2).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]

    flow = correspondence - coords_init  # flow at feature resolution

    # keep only the horizontal component, negated
    flow_x = -flow[:, :1]  # [B, 1, H, W]
    return flow_x, match_prob
def correlation_softmax_depth(feature0, feature1,
                              intrinsics,
                              pose,
                              depth_candidates,
                              depth_from_argmax=False,
                              pred_bidir_depth=False,
                              ):
    """Estimate depth from a softmax over per-candidate feature correlations.

    ``feature1`` is warped to the view of ``feature0`` once per entry of
    ``depth_candidates`` ([B, D, H, W], actually inverse depths) and
    correlated with ``feature0``. With ``pred_bidir_depth`` both directions
    are processed by stacking the swapped inputs along the batch axis.

    Returns:
        ``(depth, match_prob)``: depth is [B, 1, H, W] and match_prob is
        [B, D, H, W].
    """
    b, c, h, w = feature0.size()
    assert depth_candidates.dim() == 4  # [B, D, H, W]
    scale_factor = c ** 0.5

    if pred_bidir_depth:
        feature0, feature1 = (torch.cat((feature0, feature1), dim=0),
                              torch.cat((feature1, feature0), dim=0))
        intrinsics = intrinsics.repeat(2, 1, 1)
        pose = torch.cat((pose, torch.inverse(pose)), dim=0)
        depth_candidates = depth_candidates.repeat(2, 1, 1, 1)

    # depth candidates are actually inverse depth
    warped_feature1 = warp_with_pose_depth_candidates(feature1, intrinsics, pose,
                                                      1. / depth_candidates,
                                                      )  # [B, C, D, H, W]

    correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor  # [B, D, H, W]
    match_prob = F.softmax(correlation, dim=1)  # [B, D, H, W]

    if not depth_from_argmax:
        depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True)  # [B, 1, H, W]
        return depth, match_prob

    # for cross-task transfer (flow -> depth), extract depth with argmax at test time
    best = torch.argmax(match_prob, dim=1, keepdim=True)
    depth = torch.gather(depth_candidates, dim=1, index=best)
    return depth, match_prob
def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth,
                                    clamp_min_depth=1e-3,
                                    ):
    """Warp ``feature1`` once per depth candidate using the camera pose.

    Args:
        feature1: [B, C, H, W]
        intrinsics: [B, 3, 3]
        pose: [B, 4, 4]
        depth: [B, D, H, W]
        clamp_min_depth: Lower bound applied to the projected depth before
            dividing by it.

    Returns:
        Warped features of shape [B, C, D, H, W].
    """
    assert intrinsics.size(1) == intrinsics.size(2) == 3
    assert pose.size(1) == pose.size(2) == 4
    assert depth.dim() == 4

    b, d, h, w = depth.size()
    c = feature1.size(1)

    # NOTE(review): the source indentation was lost; the sampling grid is
    # assumed to be built without gradients while the final grid_sample runs
    # with gradients enabled - confirm against the upstream file.
    with torch.no_grad():
        # pixel coordinates
        pixels = coords_grid(b, h, w, homogeneous=True, device=depth.device)  # [B, 3, H, W]

        # back project to 3D and transform viewpoint
        rays = torch.inverse(intrinsics).bmm(pixels.view(b, 3, -1))  # [B, 3, H*W]
        rotated = torch.bmm(pose[:, :3, :3], rays).unsqueeze(2).repeat(1, 1, d, 1)
        points = rotated * depth.view(b, 1, d, h * w)  # [B, 3, D, H*W]
        points = points + pose[:, :3, -1:].unsqueeze(-1)  # [B, 3, D, H*W]

        # reproject to 2D image plane
        points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w)  # [B, 3, D, H*W]
        pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth)  # [B, 2, D, H*W]

        # normalize to [-1, 1]
        x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1
        y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1
        grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, D, H*W, 2]

    # sample features
    sampled = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear',
                            padding_mode='zeros',
                            align_corners=True)
    return sampled.view(b, c, d, h, w)  # [B, C, D, H, W]

View file

@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
import torch
import torch.nn as nn
import math
class PositionEmbeddingSine(nn.Module):
    """Sinusoidal 2D position embedding.

    This is a more standard version of the position embedding, very similar to
    the one used by the Attention is all you need paper, generalized to work on
    images. The output has ``2 * num_pos_feats`` channels: the y embedding
    followed by the x embedding.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")

        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x):
        """Return the position embedding for ``x`` ([B, C, H, W]).

        Only the shape and device of ``x`` are used. The result is
        [B, 2 * num_pos_feats, H, W].
        """
        b, _, h, w = x.size()
        ones = torch.ones((b, h, w), device=x.device)  # [B, H, W]

        # 1-based row / column index of every pixel
        y_embed = ones.cumsum(1, dtype=torch.float32)
        x_embed = ones.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        freq_index = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (freq_index // 2) / self.num_pos_feats)

        def _encode(embed):
            # interleave sin (even channels) and cos (odd channels)
            angles = embed[:, :, :, None] / dim_t
            return torch.stack((angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()),
                               dim=4).flatten(3)

        pos_x = _encode(x_embed)
        pos_y = _encode(y_embed)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

View file

@ -0,0 +1,119 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
    """Two-layer convolutional head: 3x3 conv, ReLU, 3x3 conv."""

    def __init__(self, input_dim=128, hidden_dim=256,
                 out_dim=2,
                 ):
        super().__init__()

        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map ``x`` (``input_dim`` channels) to ``out_dim`` channels; spatial size is kept."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class SepConvGRU(nn.Module):
    """Convolutional GRU applied as a horizontal pass followed by a vertical pass.

    The first set of gates uses 1 x K kernels and the second K x 1 kernels.
    """

    def __init__(self, hidden_dim=128, input_dim=192 + 128,
                 kernel_size=5,
                 ):
        super().__init__()
        padding = (kernel_size - 1) // 2
        in_dim = hidden_dim + input_dim

        horizontal = dict(kernel_size=(1, kernel_size), padding=(0, padding))
        vertical = dict(kernel_size=(kernel_size, 1), padding=(padding, 0))

        self.convz1 = nn.Conv2d(in_dim, hidden_dim, **horizontal)
        self.convr1 = nn.Conv2d(in_dim, hidden_dim, **horizontal)
        self.convq1 = nn.Conv2d(in_dim, hidden_dim, **horizontal)

        self.convz2 = nn.Conv2d(in_dim, hidden_dim, **vertical)
        self.convr2 = nn.Conv2d(in_dim, hidden_dim, **vertical)
        self.convq2 = nn.Conv2d(in_dim, hidden_dim, **vertical)

    @staticmethod
    def _gru_pass(h, x, conv_z, conv_r, conv_q):
        """Run one GRU update of ``h`` given ``x`` with the supplied gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(conv_z(hx))  # update gate
        r = torch.sigmoid(conv_r(hx))  # reset gate
        q = torch.tanh(conv_q(torch.cat([r * h, x], dim=1)))  # candidate state
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """Update the hidden state ``h`` with the input ``x`` and return it."""
        # horizontal
        h = self._gru_pass(h, x, self.convz1, self.convr1, self.convq1)
        # vertical
        return self._gru_pass(h, x, self.convz2, self.convr2, self.convq2)
class BasicMotionEncoder(nn.Module):
    """Fuse a correlation volume and the current flow into motion features.

    The output always has 128 channels: ``128 - flow_channels`` learned
    channels followed by the unmodified input flow.
    """

    def __init__(self, corr_channels=324, flow_channels=2):
        super().__init__()
        # correlation branch
        self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        # flow branch
        self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # fusion of both branches; leaves room to append the raw flow
        self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1)

    def forward(self, flow, corr):
        """Return motion features [B, 128, H, W] from ``flow`` and ``corr``."""
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
class BasicUpdateBlock(nn.Module):
    """One refinement step: motion encoder -> separable ConvGRU -> flow head.

    Unless ``bilinear_up`` is set, an extra head also predicts the weights
    of a convex upsampling mask with ``downsample_factor ** 2 * 9`` channels.
    """

    def __init__(self,
                 corr_channels=324,
                 hidden_dim=128,
                 context_dim=128,
                 downsample_factor=8,
                 flow_dim=2,
                 bilinear_up=False,
                 ):
        super().__init__()
        self.encoder = BasicMotionEncoder(corr_channels=corr_channels, flow_channels=flow_dim)
        # the GRU input is the context features concatenated with motion features
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256, out_dim=flow_dim)

        # no mask head is needed when the caller upsamples bilinearly
        self.mask = None if bilinear_up else nn.Sequential(
            nn.Conv2d(hidden_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow):
        """Update hidden state ``net``; return ``(net, mask, delta_flow)``.

        ``mask`` is ``None`` when the block was built with ``bilinear_up=True``.
        """
        gru_input = torch.cat([inp, self.encoder(flow, corr)], dim=1)
        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)
        mask = self.mask(net) if self.mask is not None else None
        return net, mask, delta_flow

View file

@ -0,0 +1,294 @@
import torch
import torch.nn as nn
from .attention import (single_head_full_attention, single_head_split_window_attention,
single_head_full_attention_1d, single_head_split_window_attention_1d)
from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d
class TransformerLayer(nn.Module):
    """Single-head attention layer with an optional feed-forward network.

    ``source`` provides the queries and ``target`` the keys/values. The
    attended message is merged, layer-normalized, optionally passed through
    an FFN together with ``source``, and finally added to ``source``.
    """

    # attention modes that treat self- and cross-attention differently
    _STEREO_ATTN_TYPES = ('self_swin2d_cross_1d', 'self_swin2d_cross_swin1d')

    def __init__(self,
                 d_model=128,
                 nhead=1,
                 no_ffn=False,
                 ffn_dim_expansion=4,
                 ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.no_ffn = no_ffn

        # attention projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2  # FFN input is [source, message]
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def _require_single_head(self):
        # we observe that multihead attention slows down the speed and increases
        # the memory consumption without bringing obvious performance gains and
        # thus the implementation is removed
        if self.nhead > 1:
            raise NotImplementedError

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                shifted_window_attn_mask_1d=None,
                attn_type='swin',
                with_shift=False,
                attn_num_splits=None,
                ):
        """Attend from ``source`` to ``target`` and return the updated source.

        Args:
            source, target: [B, L, C] token sequences.
            height, width: spatial size of the feature map the tokens come from.
            shifted_window_attn_mask: mask for shifted 2D window attention.
            shifted_window_attn_mask_1d: mask for shifted 1D window attention;
                required for ``'self_swin2d_cross_swin1d'`` cross-attention
                when more than one split is used.
            attn_type: ``'swin'``, ``'self_swin2d_cross_1d'`` or
                ``'self_swin2d_cross_swin1d'``; anything else uses full
                attention.
            with_shift: whether windows are shifted for this layer.
            attn_num_splits: number of windows per spatial dimension.
                ``None`` (the default) is treated as 1, i.e. no windowing.

        Raises:
            NotImplementedError: if a windowed / stereo mode is requested
                with ``nhead > 1``.
        """
        # The default used to crash on ``None > 1``; None now means "no split".
        num_splits = 1 if attn_num_splits is None else attn_num_splits

        # source, target: [B, L, C]
        query = self.q_proj(source)  # [B, L, C]
        key = self.k_proj(target)  # [B, L, C]
        value = self.v_proj(target)  # [B, L, C]

        if attn_type == 'swin' and num_splits > 1:  # self, cross-attn: both swin 2d
            self._require_single_head()
            message = single_head_split_window_attention(query, key, value,
                                                         num_splits=num_splits,
                                                         with_shift=with_shift,
                                                         h=height,
                                                         w=width,
                                                         attn_mask=shifted_window_attn_mask,
                                                         )
        elif attn_type in self._STEREO_ATTN_TYPES:
            # for stereo: 2d attn in self-attn, 1d attn in cross-attn
            self._require_single_head()

            # Self-attention is detected by comparing the raw inputs. This
            # full reduction is only paid for in the modes that need it.
            is_self_attn = (source - target).abs().max() < 1e-6

            if is_self_attn:
                if num_splits > 1:
                    # self attn shift window
                    message = single_head_split_window_attention(query, key, value,
                                                                 num_splits=num_splits,
                                                                 with_shift=with_shift,
                                                                 h=height,
                                                                 w=width,
                                                                 attn_mask=shifted_window_attn_mask,
                                                                 )
                else:
                    # full 2d attn
                    message = single_head_full_attention(query, key, value)  # [N, L, C]
            elif attn_type == 'self_swin2d_cross_swin1d' and num_splits > 1:
                assert shifted_window_attn_mask_1d is not None
                # cross attn 1d shift
                message = single_head_split_window_attention_1d(query, key, value,
                                                                num_splits=num_splits,
                                                                with_shift=with_shift,
                                                                h=height,
                                                                w=width,
                                                                attn_mask=shifted_window_attn_mask_1d,
                                                                )
            else:
                # cross attn 1d (full)
                message = single_head_full_attention_1d(query, key, value,
                                                        h=height,
                                                        w=width,
                                                        )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message
class TransformerBlock(nn.Module):
    """Self-attention (no FFN) followed by cross-attention with an FFN."""

    def __init__(self,
                 d_model=128,
                 nhead=1,
                 ffn_dim_expansion=4,
                 ):
        super().__init__()
        layer_kwargs = dict(d_model=d_model, nhead=nhead, ffn_dim_expansion=ffn_dim_expansion)
        self.self_attn = TransformerLayer(no_ffn=True, **layer_kwargs)
        self.cross_attn_ffn = TransformerLayer(**layer_kwargs)

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                shifted_window_attn_mask_1d=None,
                attn_type='swin',
                with_shift=False,
                attn_num_splits=None,
                ):
        """Refine ``source`` ([B, L, C]) using itself and then ``target``."""
        # arguments shared by both attention layers
        shared = dict(height=height,
                      width=width,
                      shifted_window_attn_mask=shifted_window_attn_mask,
                      attn_type=attn_type,
                      with_shift=with_shift,
                      attn_num_splits=attn_num_splits,
                      )

        # self attention (the 1D mask is not passed to this layer)
        source = self.self_attn(source, source, **shared)

        # cross attention and ffn
        return self.cross_attn_ffn(source, target,
                                   shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
                                   **shared)
class FeatureTransformer(nn.Module):
    """Stack of TransformerBlocks that jointly enhances two feature maps.

    Both directions (feature0 -> feature1 and feature1 -> feature0) are
    processed in a single pass by concatenating them along the batch axis.
    """

    def __init__(self,
                 num_layers=6,
                 d_model=128,
                 nhead=1,
                 ffn_dim_expansion=4,
                 ):
        super(FeatureTransformer, self).__init__()

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList([
            TransformerBlock(d_model=d_model,
                             nhead=nhead,
                             ffn_dim_expansion=ffn_dim_expansion,
                             )
            for _ in range(num_layers)])

        # Xavier init for every weight matrix (1D params keep their defaults)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, feature1,
                attn_type='swin',
                attn_num_splits=None,
                **kwargs,
                ):
        """Return the enhanced ``(feature0, feature1)``, each [B, C, H, W].

        Args:
            feature0, feature1: [B, C, H, W] feature maps with C == d_model.
            attn_type: attention mode forwarded to every layer; windowed
                attention is used when it contains ``'swin'``.
            attn_num_splits: number of windows per spatial dimension.
                ``None`` (the default) is treated as 1, i.e. no windowing.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        b, c, h, w = feature0.shape
        assert self.d_model == c

        # The default used to crash on ``None > 1``; None now means "no split".
        num_splits = 1 if attn_num_splits is None else attn_num_splits
        use_windows = 'swin' in attn_type and num_splits > 1

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        # 2d attention
        if use_windows:
            # global and refine use different number of splits
            window_size_h = h // num_splits
            window_size_w = w // num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # 1d attention
        if use_windows and 'swin1d' in attn_type:
            window_size_w = w // num_splits

            # compute attn mask once
            shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
                input_w=w,
                window_size_w=window_size_w,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K, W/K, W/K]
        else:
            shifted_window_attn_mask_1d = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for i, layer in enumerate(self.layers):
            concat0 = layer(concat0, concat1,
                            height=h,
                            width=w,
                            attn_type=attn_type,
                            # windows are shifted on every other layer
                            with_shift=use_windows and i % 2 == 1,
                            attn_num_splits=num_splits,
                            shifted_window_attn_mask=shifted_window_attn_mask,
                            shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
                            )

            # update feature1: the targets are the updated sources, swapped
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]

        return feature0, feature1

View file

@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
class MultiScaleTridentConv(nn.Module):
    """Weight-shared convolution applied to several branches.

    All branches use the same weight/bias but each branch has its own
    stride and padding. During training, or when ``test_branch_idx == -1``,
    ``forward`` expects one input per branch. Otherwise it expects a single
    input and applies the stride/padding of branch ``test_branch_idx``.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            strides=1,
            paddings=0,
            dilations=1,
            dilation=1,
            groups=1,
            num_branch=1,
            test_branch_idx=-1,
            bias=False,
            norm=None,
            activation=None,
    ):
        super(MultiScaleTridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation  # single dilation shared by every branch

        # a scalar setting is broadcast to every branch
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(p) for p in paddings]
        # stored for reference only; forward() uses ``self.dilation``
        self.dilations = [_pair(d) for d in dilations]
        self.strides = [_pair(s) for s in strides]

        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        # one padding and one stride per branch
        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        """Convolve a list of tensors and return a list of outputs.

        Args:
            inputs: list with ``num_branch`` tensors when all branches run,
                or a single-element list when only the test branch runs.

        Returns:
            List of output tensors, one per processed branch, after the
            optional ``norm`` and ``activation``.
        """
        all_branches = self.training or self.test_branch_idx == -1
        num_branch = self.num_branch if all_branches else 1
        assert len(inputs) == num_branch

        if all_branches:
            outputs = [
                F.conv2d(x, self.weight, self.bias, stride, padding, self.dilation, self.groups)
                for x, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # Use the branch selected for inference. The previous expression
            # always resolved to the last branch and ignored the index.
            idx = self.test_branch_idx
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[idx],
                    self.paddings[idx],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs

View file

@ -0,0 +1,367 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import CNNEncoder
from .transformer import FeatureTransformer
from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow,
global_correlation_softmax_stereo, local_correlation_softmax_stereo,
correlation_softmax_depth)
from .attention import SelfAttnPropagation
from .geometry import flow_warp, compute_flow_with_depth_pose
from .reg_refine import BasicUpdateBlock
from .utils import normalize_img, feature_add_position, upsample_flow_with_mask
class UniMatch(nn.Module):
    """Unified matching network for optical flow, stereo and depth.

    Pipeline per scale: CNN features -> position encoding -> Transformer ->
    correlation + softmax matching -> self-attention propagation ->
    upsampling, with an optional GRU-based local regression refinement
    (``reg_refine``) at the last scale.
    """

    def __init__(self,
                 num_scales=1,
                 feature_channels=128,
                 upsample_factor=8,
                 num_head=1,
                 ffn_dim_expansion=4,
                 num_transformer_layers=6,
                 reg_refine=False,  # optional local regression refinement
                 task='flow',
                 ):
        super(UniMatch, self).__init__()

        self.feature_channels = feature_channels
        self.num_scales = num_scales
        self.upsample_factor = upsample_factor
        self.reg_refine = reg_refine

        # CNN
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer
        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
                                              d_model=feature_channels,
                                              nhead=num_head,
                                              ffn_dim_expansion=ffn_dim_expansion,
                                              )

        # propagation with self-attn
        self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels)

        if not self.reg_refine or task == 'depth':
            # convex upsampling similar to RAFT
            # concat feature0 and low res flow as input
            self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                           nn.ReLU(inplace=True),
                                           nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))
            # thus far, all the learnable parameters are task-agnostic

        if reg_refine:
            # optional task-specific local regression refinement
            # The projection is applied to feature0, so its input width must
            # follow feature_channels (previously hard-coded to 128).
            self.refine_proj = nn.Conv2d(feature_channels, 256, 1)
            self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2,
                                           downsample_factor=upsample_factor,
                                           flow_dim=2 if task == 'flow' else 1,
                                           bilinear_up=task == 'depth',
                                           )

    def extract_feature(self, img0, img1):
        """Run the backbone on both images at once.

        Returns two lists (one per image) of feature maps, ordered from low
        to high resolution.
        """
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for feature in features:
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
                      is_depth=False):
        """Upsample ``flow`` bilinearly or with the learned convex upsampler.

        In the bilinear path the values are also multiplied by
        ``upsample_factor`` unless ``is_depth`` is set. The convex path uses
        ``self.upsample_factor`` (the ``upsample_factor`` argument is only
        read by the bilinear path) and needs ``feature`` as guidance.
        """
        if bilinear:
            multiplier = 1 if is_depth else upsample_factor
            up_flow = F.interpolate(flow, scale_factor=upsample_factor,
                                    mode='bilinear', align_corners=True) * multiplier
        else:
            concat = torch.cat((flow, feature), dim=1)
            mask = self.upsampler(concat)
            up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor,
                                              is_depth=is_depth)

        return up_flow

    def forward(self, img0, img1,
                attn_type=None,
                attn_splits_list=None,
                corr_radius_list=None,
                prop_radius_list=None,
                num_reg_refine=1,
                pred_bidir_flow=False,
                task='flow',
                intrinsics=None,
                pose=None,  # relative pose transform
                min_depth=1. / 0.5,  # inverse depth range
                max_depth=1. / 10,
                num_depth_candidates=64,
                depth_from_argmax=False,
                pred_bidir_depth=False,
                **kwargs,
                ):
        """Predict flow / disparity / depth between ``img0`` and ``img1``.

        Args:
            img0, img1: input images, [B, 3, H, W]. For ``task='flow'`` they
                are normalized here; other tasks are expected to arrive
                already normalized.
            attn_type: attention mode forwarded to the Transformer.
            attn_splits_list, corr_radius_list, prop_radius_list: one entry
                per scale (lengths are asserted against ``num_scales``).
                ``corr_radius == -1`` selects global matching; a
                ``prop_radius > 0`` selects local-window propagation.
            num_reg_refine: number of refinement iterations (``reg_refine``).
            pred_bidir_flow: also predict the backward flow (flow task only).
            task: ``'flow'``, ``'stereo'`` or ``'depth'``.
            intrinsics, pose: camera intrinsics and relative pose (depth).
            min_depth, max_depth: inverse depth range used for the depth
                candidates and for clamping.
                NOTE(review): the defaults evaluate to 2.0 and 0.1, i.e.
                min > max; confirm that depth callers always override them.
            num_depth_candidates, depth_from_argmax, pred_bidir_depth:
                depth matching options.

        Returns:
            dict with key ``'flow_preds'``: list of predictions. For stereo
            and depth the channel dimension is squeezed, and depth
            predictions are converted from inverse depth to depth.
        """
        if pred_bidir_flow:
            assert task == 'flow'

        if task == 'depth':
            assert self.num_scales == 1  # multi-scale depth model is not supported yet

        results_dict = {}
        flow_preds = []

        if task == 'flow':
            # stereo and depth tasks have normalized img in dataloader
            img0, img1 = normalize_img(img0, img1)  # [B, 3, H, W]

        # list of features, resolution low to high
        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features

        flow = None

        if task != 'depth':
            assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
        else:
            assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            if pred_bidir_flow and scale_idx > 0:
                # predicting bidirectional flow with refinement
                feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)

            # keep the un-warped, un-transformed features for the refinement
            feature0_ori, feature1_ori = feature0, feature1

            upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))

            if task == 'depth':
                # scale intrinsics
                intrinsics_curr = intrinsics.clone()
                intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor

            if scale_idx > 0:
                assert task != 'depth'  # not supported for multi-scale depth model
                flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2

            if flow is not None:
                assert task != 'depth'
                flow = flow.detach()

                if task == 'stereo':
                    # construct flow vector for disparity
                    # flow here is actually disparity
                    zeros = torch.zeros_like(flow)  # [B, 1, H, W]
                    # NOTE: reverse disp, disparity is positive
                    displace = torch.cat((-flow, zeros), dim=1)  # [B, 2, H, W]
                    feature1 = flow_warp(feature1, displace)  # [B, C, H, W]
                elif task == 'flow':
                    feature1 = flow_warp(feature1, flow)  # [B, C, H, W]
                else:
                    raise NotImplementedError

            attn_splits = attn_splits_list[scale_idx]
            if task != 'depth':
                corr_radius = corr_radius_list[scale_idx]
            prop_radius = prop_radius_list[scale_idx]

            # add position to features
            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)

            # Transformer
            feature0, feature1 = self.transformer(feature0, feature1,
                                                  attn_type=attn_type,
                                                  attn_num_splits=attn_splits,
                                                  )

            # correlation and softmax
            if task == 'depth':
                # first generate depth candidates
                b, _, h, w = feature0.size()
                depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0)
                depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h,
                                                                                               w)  # [B, D, H, W]

                flow_pred = correlation_softmax_depth(feature0, feature1,
                                                      intrinsics_curr,
                                                      pose,
                                                      depth_candidates=depth_candidates,
                                                      depth_from_argmax=depth_from_argmax,
                                                      pred_bidir_depth=pred_bidir_depth,
                                                      )[0]

            else:
                if corr_radius == -1:  # global matching
                    if task == 'flow':
                        flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
                    elif task == 'stereo':
                        flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0]
                    else:
                        raise NotImplementedError
                else:  # local matching
                    if task == 'flow':
                        flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]
                    elif task == 'stereo':
                        flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0]
                    else:
                        raise NotImplementedError

            # flow or residual flow
            flow = flow + flow_pred if flow is not None else flow_pred

            if task == 'stereo':
                flow = flow.clamp(min=0)  # positive disparity

            # upsample to the original resolution for supervision at training time only
            if self.training:
                flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor,
                                                   is_depth=task == 'depth')
                flow_preds.append(flow_bilinear)

            # flow propagation with self-attn
            if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0:
                feature0 = torch.cat((feature0, feature1), dim=0)  # [2*B, C, H, W] for propagation

            flow = self.feature_flow_attn(feature0, flow.detach(),
                                          local_window_attn=prop_radius > 0,
                                          local_window_radius=prop_radius,
                                          )

            # bilinear exclude the last one
            if self.training and scale_idx < self.num_scales - 1:
                flow_up = self.upsample_flow(flow, feature0, bilinear=True,
                                             upsample_factor=upsample_factor,
                                             is_depth=task == 'depth')
                flow_preds.append(flow_up)

            if scale_idx == self.num_scales - 1:
                if not self.reg_refine:
                    # upsample to the original image resolution

                    if task == 'stereo':
                        # pad disparity to 2 channels so the flow upsampler can be reused
                        flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1)  # [B, 2, H, W]
                        flow_up_pad = self.upsample_flow(flow_pad, feature0)
                        flow_up = -flow_up_pad[:, :1]  # [B, 1, H, W]
                    elif task == 'depth':
                        depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1)  # [B, 2, H, W]
                        depth_up_pad = self.upsample_flow(depth_pad, feature0,
                                                          is_depth=True).clamp(min=min_depth, max=max_depth)
                        flow_up = depth_up_pad[:, :1]  # [B, 1, H, W]
                    else:
                        flow_up = self.upsample_flow(flow, feature0)

                    flow_preds.append(flow_up)
                else:
                    # task-specific local regression refinement
                    # supervise current flow
                    if self.training:
                        flow_up = self.upsample_flow(flow, feature0, bilinear=True,
                                                     upsample_factor=upsample_factor,
                                                     is_depth=task == 'depth')
                        flow_preds.append(flow_up)

                    assert num_reg_refine > 0
                    for refine_iter_idx in range(num_reg_refine):
                        flow = flow.detach()

                        if task == 'stereo':
                            zeros = torch.zeros_like(flow)  # [B, 1, H, W]
                            # NOTE: reverse disp, disparity is positive
                            displace = torch.cat((-flow, zeros), dim=1)  # [B, 2, H, W]
                            correlation = local_correlation_with_flow(
                                feature0_ori,
                                feature1_ori,
                                flow=displace,
                                local_radius=4,
                            )  # [B, (2R+1)^2, H, W]
                        elif task == 'depth':
                            if pred_bidir_depth and refine_iter_idx == 0:
                                # duplicate the geometry once for the backward direction
                                intrinsics_curr = intrinsics_curr.repeat(2, 1, 1)
                                pose = torch.cat((pose, torch.inverse(pose)), dim=0)

                                feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori),
                                                                       dim=0), torch.cat((feature1_ori,
                                                                                          feature0_ori), dim=0)

                            flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1),
                                                                           intrinsics_curr,
                                                                           extrinsics_rel=pose,
                                                                           )

                            correlation = local_correlation_with_flow(
                                feature0_ori,
                                feature1_ori,
                                flow=flow_from_depth,
                                local_radius=4,
                            )  # [B, (2R+1)^2, H, W]

                        else:
                            correlation = local_correlation_with_flow(
                                feature0_ori,
                                feature1_ori,
                                flow=flow,
                                local_radius=4,
                            )  # [B, (2R+1)^2, H, W]

                        proj = self.refine_proj(feature0)

                        # first half: GRU hidden state, second half: context input
                        net, inp = torch.chunk(proj, chunks=2, dim=1)

                        net = torch.tanh(net)
                        inp = torch.relu(inp)

                        net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(),
                                                                  )

                        if task == 'depth':
                            flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth)
                        else:
                            flow = flow + residual_flow

                        if task == 'stereo':
                            flow = flow.clamp(min=0)  # positive

                        if self.training or refine_iter_idx == num_reg_refine - 1:

                            if task == 'depth':
                                if refine_iter_idx < num_reg_refine - 1:
                                    # bilinear upsampling
                                    flow_up = self.upsample_flow(flow, feature0, bilinear=True,
                                                                 upsample_factor=upsample_factor,
                                                                 is_depth=True)
                                else:
                                    # last one convex upsampling
                                    # NOTE: clamp depth due to the zero padding in the unfold in the convex upsampling
                                    # pad depth to 2 channels as flow
                                    depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1)  # [B, 2, H, W]
                                    depth_up_pad = self.upsample_flow(depth_pad, feature0,
                                                                      is_depth=True).clamp(min=min_depth,
                                                                                           max=max_depth)
                                    flow_up = depth_up_pad[:, :1]  # [B, 1, H, W]

                            else:
                                flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor,
                                                                  is_depth=task == 'depth')

                            flow_preds.append(flow_up)

        if task == 'stereo':
            flow_preds = [pred.squeeze(1) for pred in flow_preds]  # [B, H, W]

        # convert inverse depth to depth
        if task == 'depth':
            flow_preds = [1. / pred.squeeze(1) for pred in flow_preds]  # [B, H, W]

        results_dict.update({'flow_preds': flow_preds})

        return results_dict

View file

@ -0,0 +1,216 @@
import torch
import torch.nn.functional as F
from .position import PositionEmbeddingSine
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Build a regular grid of (x, y) coordinates with shape [H, W, 2].

    ``x`` runs over ``len_w`` evenly spaced values in [w_min, w_max] and
    ``y`` over ``len_h`` values in [h_min, h_max]. ``device`` is mandatory.
    """
    assert device is not None

    xs = torch.linspace(w_min, w_max, len_w, device=device)
    ys = torch.linspace(h_min, h_max, len_h, device=device)
    grid_x, grid_y = torch.meshgrid([xs, ys])  # each [W, H]

    # stack as (x, y) and swap the first two axes to get row-major [H, W, 2]
    coords = torch.stack((grid_x, grid_y), -1)
    return coords.transpose(0, 1).float()
def normalize_coords(coords, h, w):
    """Map pixel coordinates ([B, H, W, 2], ordered (x, y)) to [-1, 1].

    Pixel 0 maps to -1 and pixel ``w - 1`` / ``h - 1`` maps to +1.
    """
    half_w = (w - 1) / 2.
    half_h = (h - 1) / 2.
    center = torch.Tensor([half_w, half_h]).float().to(coords.device)
    shifted = coords - center
    return shifted / center  # [-1, 1]
def normalize_img(img0, img1):
    """Normalize two [0, 255] RGB image batches with ImageNet statistics.

    Each image is scaled to [0, 1], then the per-channel mean is subtracted
    and the result divided by the per-channel std. The statistics are placed
    on ``img1``'s device.
    """
    device = img1.device
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    def _normalize(img):
        return (img / 255. - mean) / std

    return _normalize(img0), _normalize(img1)
def split_feature(feature,
                  num_splits=2,
                  channel_last=False,
                  ):
    """Cut a feature map into ``num_splits x num_splits`` windows.

    The windows are stacked along the batch axis in row-major window order:
    [B, C, H, W] -> [B*K*K, C, H/K, W/K], or with ``channel_last``
    [B, H, W, C] -> [B*K*K, H/K, W/K, C]. H and W must be divisible by K.
    """
    k = num_splits
    if channel_last:
        b, h, w, c = feature.size()
    else:
        b, c, h, w = feature.size()
    assert h % k == 0 and w % k == 0

    win_h, win_w = h // k, w // k
    stacked_b = b * k * k

    if channel_last:
        windows = feature.view(b, k, win_h, k, win_w, c).permute(0, 1, 3, 2, 4, 5)
        return windows.reshape(stacked_b, win_h, win_w, c)  # [B*K*K, H/K, W/K, C]

    windows = feature.view(b, c, k, win_h, k, win_w).permute(0, 2, 4, 1, 3, 5)
    return windows.reshape(stacked_b, c, win_h, win_w)  # [B*K*K, C, H/K, W/K]
def merge_splits(splits,
                 num_splits=2,
                 channel_last=False,
                 ):
    """Reassemble windows stacked along the batch axis into full maps.

    Takes [B*K*K, C, H/K, W/K] (or [B*K*K, H/K, W/K, C] with
    ``channel_last``), with windows in row-major order, and returns
    [B, C, H, W] (or [B, H, W, C]).
    """
    k = num_splits
    if channel_last:  # [B*K*K, H/K, W/K, C]
        stacked_b, win_h, win_w, c = splits.size()
        b = stacked_b // k // k
        tiles = splits.view(b, k, k, win_h, win_w, c).permute(0, 1, 3, 2, 4, 5)
        return tiles.contiguous().view(b, k * win_h, k * win_w, c)  # [B, H, W, C]

    # [B*K*K, C, H/K, W/K]
    stacked_b, c, win_h, win_w = splits.size()
    b = stacked_b // k // k
    tiles = splits.view(b, k, k, c, win_h, win_w).permute(0, 3, 1, 4, 2, 5)
    return tiles.contiguous().view(b, c, k * win_h, k * win_w)  # [B, C, H, W]
def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
                                    shift_size_h, shift_size_w, device=torch.device('cuda')):
    """Build the attention mask for shifted-window (SW-MSA) attention.

    The map is labelled with up to nine region ids; inside each window,
    token pairs from different regions get -100 and pairs from the same
    region get 0.

    ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    """
    height, width = input_resolution
    region_ids = torch.zeros((1, height, width, 1)).to(device)  # 1 H W 1

    row_bands = (slice(0, -window_size_h),
                 slice(-window_size_h, -shift_size_h),
                 slice(-shift_size_h, None))
    col_bands = (slice(0, -window_size_w),
                 slice(-window_size_w, -shift_size_w),
                 slice(-shift_size_w, None))

    # number the 3 x 3 regions in row-major order
    regions = [(rows, cols) for rows in row_bands for cols in col_bands]
    for region_id, (rows, cols) in enumerate(regions):
        region_ids[:, rows, cols, :] = region_id

    windows = split_feature(region_ids, num_splits=width // window_size_w, channel_last=True)
    windows = windows.view(-1, window_size_h * window_size_w)

    # pairwise id difference inside every window: non-zero means "different region"
    id_diff = windows.unsqueeze(1) - windows.unsqueeze(2)
    blocked = id_diff.masked_fill(id_diff != 0, float(-100.0))
    attn_mask = blocked.masked_fill(id_diff == 0, float(0.0))

    return attn_mask
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add the same sinusoidal position encoding to both feature maps.

    With ``attn_splits > 1`` the encoding is computed per attention window
    (features are split, encoded, and merged back); otherwise it is computed
    over the whole map.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits <= 1:
        position = pos_enc(feature0)
        return feature0 + position, feature1 + position

    # add position in splited window
    windows0 = split_feature(feature0, num_splits=attn_splits)
    windows1 = split_feature(feature1, num_splits=attn_splits)

    position = pos_enc(windows0)

    feature0 = merge_splits(windows0 + position, num_splits=attn_splits)
    feature1 = merge_splits(windows1 + position, num_splits=attn_splits)
    return feature0, feature1
def upsample_flow_with_mask(flow, up_mask, upsample_factor,
                            is_depth=False):
    """Convex upsampling following RAFT.

    Every high-resolution pixel is a softmax-weighted combination of the
    3x3 neighbourhood of its low-resolution parent. ``up_mask`` carries the
    9 * K * K weights per low-resolution pixel (K = ``upsample_factor``).
    Values are multiplied by K unless ``is_depth`` is set.
    """
    k = upsample_factor
    b, flow_channel, h, w = flow.shape

    weights = up_mask.view(b, 1, 9, k, k, h, w)  # [B, 1, 9, K, K, H, W]
    weights = torch.softmax(weights, dim=2)  # normalize over the 9 neighbours

    multiplier = 1 if is_depth else k
    neighbours = F.unfold(multiplier * flow, [3, 3], padding=1)
    neighbours = neighbours.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]

    up_flow = torch.sum(weights * neighbours, dim=2)  # [B, 2, K, K, H, W]
    up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, H, K, W, K]
    return up_flow.reshape(b, flow_channel, k * h, k * w)  # [B, 2, K*H, K*W]
def split_feature_1d(feature,
                     num_splits=2,
                     ):
    """Cut [B, W, C] into ``num_splits`` windows along W: [B*K, W/K, C].

    W must be divisible by ``num_splits``.
    """
    # feature: [B, W, C]
    b, w, c = feature.size()
    assert w % num_splits == 0

    win_w = w // num_splits
    windows = feature.view(b, num_splits, win_w, c)
    return windows.view(b * num_splits, win_w, c)  # [B*K, W/K, C]
def merge_splits_1d(splits,
                    h,
                    num_splits=2,
                    ):
    """Merge 1D windows [B*H*K, W/K, C] back into [B, H, W, C].

    ``h`` is the number of rows folded into the leading axis.
    """
    stacked_b, win_w, c = splits.size()
    b = stacked_b // num_splits // h

    tiles = splits.view(b, h, num_splits, win_w, c)
    return tiles.view(b, h, num_splits * win_w, c)  # [B, H, W, C]
def window_partition_1d(x, window_size_w):
    """
    Split a sequence into consecutive, non-overlapping windows.

    Args:
        x: (B, W, C)
        window_size_w (int): window size along W

    Returns:
        windows: (num_windows*B, window_size_w, C)
    """
    batch, width, channels = x.shape
    num_windows = width // window_size_w
    windows = x.view(batch, num_windows, window_size_w, channels)
    return windows.view(-1, window_size_w, channels)
def generate_shift_window_attn_mask_1d(input_w, window_size_w,
                                       shift_size_w, device=torch.device('cuda')):
    """Build the attention mask for shifted-window attention along W.

    The row is labelled with up to three region ids; inside each window,
    token pairs from different regions get -100 and pairs from the same
    region get 0.
    """
    # calculate attention mask for SW-MSA
    region_ids = torch.zeros((1, input_w, 1)).to(device)  # 1 W 1
    col_bands = (slice(0, -window_size_w),
                 slice(-window_size_w, -shift_size_w),
                 slice(-shift_size_w, None))
    for region_id, cols in enumerate(col_bands):
        region_ids[:, cols, :] = region_id

    windows = window_partition_1d(region_ids, window_size_w)  # nW, window_size, 1
    windows = windows.view(-1, window_size_w)

    # pairwise id difference inside every window: non-zero means "different region"
    id_diff = windows.unsqueeze(1) - windows.unsqueeze(2)  # nW, window_size, window_size
    blocked = id_diff.masked_fill(id_diff != 0, float(-100.0))
    attn_mask = blocked.masked_fill(id_diff == 0, float(0.0))

    return attn_mask