
Commit 563109d
Merge pull request magicleap#85 from johnwlambert/patch-1
Add type hints to superglue.py
romachalm committed Feb 15, 2022
2 parents c0626d5 + 710352a commit 563109d
Showing 1 changed file with 12 additions and 10 deletions.
models/superglue.py: 12 additions & 10 deletions
@@ -42,11 +42,13 @@

from copy import deepcopy
from pathlib import Path
+from typing import List, Tuple

import torch
from torch import nn


-def MLP(channels: list, do_bn=True):
+def MLP(channels: List[int], do_bn: bool = True) -> nn.Module:
    """ Multi-layer perceptron """
    n = len(channels)
    layers = []
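The body hidden by the hunk builds the MLP from kernel-size-1 Conv1d (plus BatchNorm/ReLU) layers, so it operates on (batch, channels, num_keypoints) tensors. A quick usage sketch of the newly hinted signature; the channel list and shapes are illustrative, and it assumes models/superglue.py is importable:

```python
import torch
from models.superglue import MLP

mlp = MLP([3, 32, 64])      # channels: List[int] -> nn.Module
x = torch.randn(2, 3, 100)  # (batch, channels, num_keypoints)
print(mlp(x).shape)         # torch.Size([2, 64, 100])
```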
@@ -72,7 +74,7 @@ def normalize_keypoints(kpts, image_shape):

class KeypointEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
-    def __init__(self, feature_dim, layers):
+    def __init__(self, feature_dim: int, layers: List[int]) -> None:
        super().__init__()
        self.encoder = MLP([3] + layers + [feature_dim])
        nn.init.constant_(self.encoder[-1].bias, 0.0)
@@ -82,7 +84,7 @@ def forward(self, kpts, scores):
        return self.encoder(torch.cat(inputs, dim=1))
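An illustrative use of the hinted constructor; the layer widths and keypoint count are made up, and the input layout (keypoints as (batch, num_keypoints, 2), scores as (batch, num_keypoints)) is assumed from the forward body hidden by the hunk:

```python
import torch
from models.superglue import KeypointEncoder

enc = KeypointEncoder(feature_dim=256, layers=[32, 64, 128])
kpts = torch.rand(1, 100, 2)    # normalized keypoint locations
scores = torch.rand(1, 100)     # detector confidence per keypoint
print(enc(kpts, scores).shape)  # torch.Size([1, 256, 100])
```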


-def attention(query, key, value):
+def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    prob = torch.nn.functional.softmax(scores, dim=-1)
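The hunk hides the function's tail; a self-contained sketch in which the closing einsum that applies prob to value is assumed from the Tuple[torch.Tensor,torch.Tensor] return hint:

```python
import torch
from typing import Tuple

def attention(query: torch.Tensor, key: torch.Tensor,
              value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    dim = query.shape[1]
    # (b, d, h, n) x (b, d, h, m) -> (b, h, n, m) scaled dot-product scores
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    prob = torch.nn.functional.softmax(scores, dim=-1)
    # assumed final line: weight the values by the attention distribution
    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob

q = k = v = torch.randn(1, 64, 4, 10)  # (batch, dim, heads, keypoints)
out, prob = attention(q, k, v)
assert out.shape == q.shape and prob.shape == (1, 4, 10, 10)
```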
@@ -99,7 +101,7 @@ def __init__(self, num_heads: int, d_model: int):
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

-    def forward(self, query, key, value):
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
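The forward body is cut off after the per-head reshape; a sketch of the whole module, where the lines hidden by the hunk (the per-head dim bookkeeping in __init__, the attention call, and the closing head merge) are assumed:

```python
from copy import deepcopy

import torch
from torch import nn

from models.superglue import attention  # defined earlier in this same file

class MultiHeadedAttention(nn.Module):
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads  # per-head feature size (assumed)
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor) -> torch.Tensor:
        batch_dim = query.size(0)
        # project q/k/v with 1x1 convs, then split d_model into num_heads slices
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, _ = attention(query, key, value)
        # flatten the heads back to d_model and mix them with the merge conv
        return self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))
```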
@@ -114,20 +116,20 @@ def __init__(self, feature_dim: int, num_heads: int):
        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
        nn.init.constant_(self.mlp[-1].bias, 0.0)

-    def forward(self, x, source):
+    def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
        message = self.attn(x, source, source)
        return self.mlp(torch.cat([x, message], dim=1))
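One propagation step with hypothetical shapes; source is the same image's descriptors on a 'self' edge and the other image's descriptors on a 'cross' edge:

```python
import torch
from models.superglue import AttentionalPropagation

prop = AttentionalPropagation(feature_dim=256, num_heads=4)
x = torch.randn(1, 256, 100)      # descriptors of image 0
source = torch.randn(1, 256, 80)  # descriptors of image 1 (a 'cross' edge)
print(prop(x, source).shape)      # torch.Size([1, 256, 100]) residual message
```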


class AttentionalGNN(nn.Module):
-    def __init__(self, feature_dim: int, layer_names: list):
+    def __init__(self, feature_dim: int, layer_names: List[str]) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalPropagation(feature_dim, 4)
            for _ in range(len(layer_names))])
        self.names = layer_names

-    def forward(self, desc0, desc1):
+    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
        for layer, name in zip(self.layers, self.names):
            if name == 'cross':
                src0, src1 = desc1, desc0
@@ -138,7 +140,7 @@ def forward(self, desc0, desc1):
        return desc0, desc1
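Driving the GNN with alternating layer names; the nine-fold 'self'/'cross' repetition is the paper's default configuration, assumed here for illustration:

```python
import torch
from models.superglue import AttentionalGNN

gnn = AttentionalGNN(feature_dim=256, layer_names=['self', 'cross'] * 9)
desc0 = torch.randn(1, 256, 100)  # image 0 descriptors
desc1 = torch.randn(1, 256, 80)   # image 1 descriptors
desc0, desc1 = gnn(desc0, desc1)  # Tuple[torch.Tensor, torch.Tensor]
```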


-def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
+def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
@@ -147,7 +149,7 @@ def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    return Z + u.unsqueeze(2) + v.unsqueeze(1)
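A standalone sketch of the routine; the two update lines inside the loop are hidden by the hunk and reconstructed here from the standard log-domain Sinkhorn updates, consistent with the visible return expression:

```python
import math

import torch

def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor,
                            log_nu: torch.Tensor, iters: int) -> torch.Tensor:
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        # alternately rescale rows and columns so the marginals match mu, nu
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)

Z = torch.randn(1, 5, 5)
log_mu = torch.full((1, 5), -math.log(5))  # uniform row marginals
log_nu = torch.full((1, 5), -math.log(5))  # uniform column marginals
P = log_sinkhorn_iterations(Z, log_mu, log_nu, iters=100).exp()
print(P.sum(2), P.sum(1))  # both approach 0.2 per entry
```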


-def log_optimal_transport(scores, alpha, iters: int):
+def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
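A hedged usage sketch: alpha is the learned scalar dustbin score (an nn.Parameter in the full file), and the returned log-coupling matrix gains one dustbin row and column; the shapes are illustrative:

```python
import torch
from models.superglue import log_optimal_transport

scores = torch.randn(1, 100, 80)               # pairwise matching scores
alpha = torch.nn.Parameter(torch.tensor(1.0))  # stand-in for the learned bin score
P = log_optimal_transport(scores, alpha, iters=100)
print(P.shape)                                 # torch.Size([1, 101, 81])
```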
@@ -209,7 +211,7 @@ def __init__(self, config):
            self.config['descriptor_dim'], self.config['keypoint_encoder'])

        self.gnn = AttentionalGNN(
-            self.config['descriptor_dim'], self.config['GNN_layers'])
+            feature_dim=self.config['descriptor_dim'], layer_names=self.config['GNN_layers'])

        self.final_proj = nn.Conv1d(
            self.config['descriptor_dim'], self.config['descriptor_dim'],
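The keyword arguments above make the config wiring explicit. An illustrative construction using the keys visible in this diff; the values are assumed repository defaults, and instantiating SuperGlue also loads pretrained weights from models/weights/, so those files must be present:

```python
from models.superglue import SuperGlue

superglue = SuperGlue({
    'descriptor_dim': 256,
    'keypoint_encoder': [32, 64, 128, 256],
    'GNN_layers': ['self', 'cross'] * 9,
})
```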
