nn.Correlation
class horizon_plugin_pytorch.nn.Correlation(kernel_size: int = 1, max_displacement: int = 1, stride1: int = 1, stride2: int = 1, pad_size: int = 0, is_multiply: bool = True)
Perform multiplicative patch comparisons between two feature maps.
Given two multi-channel feature maps $f_1$ and $f_2$, with $w$, $h$, and $c$
being their width, height, and number of channels, the correlation
layer lets the network compare each patch from $f_1$
with each patch from $f_2$.
For now we consider only a single comparison of two patches.
The ‘correlation’ of two patches centered at $x_1$
in the first map and $x_2$ in the second map is then defined as:
$$c(x_1, x_2) = \sum_{o \in [-k,k] \times [-k,k]} \langle f_1(x_1 + o), f_2(x_2 + o) \rangle$$
for a square patch of size $K := 2k + 1$.
Note that the equation above is identical to one step of a convolution in
neural networks, but instead of convolving data with a filter, it convolves
data with other data. For this reason, it has no trainable weights.
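The single-patch definition above can be sketched in plain Python. This is only an illustration of the formula as written, not the layer's implementation; the function name `patch_correlation` and the nested-list layout `f[channel][row][col]` are assumptions made for the example.

```python
def patch_correlation(f1, f2, x1, x2, k):
    """Correlate the (2k+1) x (2k+1) patches centered at x1 in f1 and x2 in f2.

    f1, f2: nested lists indexed as f[channel][row][col] (assumed layout).
    x1, x2: (row, col) patch centers; both patches must lie inside the maps.
    """
    (r1, c1), (r2, c2) = x1, x2
    total = 0.0
    for dr in range(-k, k + 1):
        for dc in range(-k, k + 1):
            # Inner product over channels at the offset o = (dr, dc).
            for ch in range(len(f1)):
                total += f1[ch][r1 + dr][c1 + dc] * f2[ch][r2 + dr][c2 + dc]
    return total


# Two single-channel 3x3 maps; K = 3 (k = 1), both patches centered at (1, 1).
f1 = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
f2 = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
result = patch_correlation(f1, f2, (1, 1), (1, 1), k=1)
```

Both operands are data, so the function takes no weight argument, which matches the note that the operation has nothing to train.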
Computing $c(x_1, x_2)$ involves $c \cdot K^2$
multiplications. Comparing all patch combinations involves
$w^2 \cdot h^2$ such computations.
Given a maximum displacement $d$, for each location $x_1$
the layer computes correlations $c(x_1, x_2)$ only in a neighborhood
of size $D := 2d + 1$, by limiting the range of $x_2$.
The strides $s_1$ and $s_2$ quantize $x_1$ globally
and quantize $x_2$ within the neighborhood centered
around $x_1$.
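A small worked count shows why the displacement limit matters. The sketch below assumes both strides are 1, so that each location in the first map is compared against D × D positions in the second map; the helper name `correlation_counts` and the example sizes are hypothetical.

```python
def correlation_counts(w, h, c, K, d):
    """Return operation counts for the quantities described in the text."""
    mults_per_patch = c * K * K      # c * K^2 multiplications per c(x1, x2)
    all_pairs = (w * h) ** 2         # w^2 * h^2 comparisons with no limit
    D = 2 * d + 1                    # neighborhood size per axis
    limited = w * h * D * D          # D^2 comparisons per location (stride 1 assumed)
    return mults_per_patch, all_pairs, limited


per_patch, all_pairs, limited = correlation_counts(w=64, h=48, c=256, K=1, d=4)
```

For a 64 × 48 map with 256 channels, K = 1 and d = 4, this gives 256 multiplications per comparison, 9,437,184 unrestricted comparisons, and 248,832 comparisons once the neighborhood is limited.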
The final output is defined by the following expression:
$$out[n, q, i, j] = c(x_{i,j}, x_q)$$
where $i$ and $j$ enumerate spatial locations in $f_1$,
and $q$ denotes the $q$-th neighborhood of $x_{i,j}$.
- Parameters:
- kernel_size – kernel size for Correlation; must be an odd number
- max_displacement – maximum displacement of Correlation
- stride1 – stride used to quantize data1 globally
- stride2 – stride used to quantize data2 within the neighborhood
centered around data1
- pad_size – padding size for Correlation
- is_multiply – operation type, either multiplication
or subtraction; only True (multiplication) is supported for now
forward(data1: Tensor | QTensor, data2: Tensor | QTensor)
Forward for Horizon Correlation.
- Parameters:
- data1 – shape of [N,C,H,W]
- data2 – shape of [N,C,H,W]
- Returns:
output
- Return type:
Tensor
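The shape of the returned tensor is not stated here. As a rough guide, the sketch below applies the shape arithmetic used by FlowNet-style correlation layers, which this description closely resembles. Whether horizon_plugin_pytorch follows the same rule is an assumption and should be checked against the actual output; the helper name `correlation_output_shape` is hypothetical and is not part of the library.

```python
import math


def correlation_output_shape(n, h, w, kernel_size=1, max_displacement=1,
                             stride1=1, stride2=1, pad_size=0):
    """Estimate [N, C_out, H_out, W_out] under FlowNet-style conventions (assumed)."""
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size must be an odd number")
    kernel_radius = (kernel_size - 1) // 2
    # Border that keeps every displaced patch inside the padded input.
    border = max_displacement + kernel_radius
    padded_h, padded_w = h + 2 * pad_size, w + 2 * pad_size
    out_h = math.ceil((padded_h - 2 * border) / stride1)
    out_w = math.ceil((padded_w - 2 * border) / stride1)
    # stride2 thins out the displacements sampled inside the neighborhood.
    grid_radius = max_displacement // stride2
    out_c = (2 * grid_radius + 1) ** 2
    return (n, out_c, out_h, out_w)


shape = correlation_output_shape(1, 48, 64, kernel_size=1, max_displacement=4,
                                 stride1=1, stride2=1, pad_size=4)
```

Under these assumptions, choosing pad_size equal to max_displacement with kernel_size = 1 and stride1 = 1 preserves the spatial size, and the channel count equals the number of sampled displacements.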