nn.MultiScaleDeformableAttention
class horizon_plugin_pytorch.nn.MultiScaleDeformableAttention(embed_dims: int = 256, num_heads: int = 8, num_levels: int = 4, num_points: int = 4, im2col_step: int = 64, dropout: float = 0.1, batch_first: bool = False, value_proj_ratio: float = 1.0)
An attention module used in Deformable DETR, proposed in
"Deformable DETR: Deformable Transformers for End-to-End Object Detection".
- Parameters:
- embed_dims – The embedding dimension of Attention.
Default: 256.
- num_heads – Parallel attention heads. Default: 8.
- num_levels – The number of feature maps (levels) used in
Attention. Default: 4.
- num_points – The number of sampling points for
each query in each head. Default: 4.
- im2col_step – The step used in image_to_column.
Default: 64.
- dropout – Dropout probability applied to the attention output
before it is added to identity. Default: 0.1.
- batch_first – If True, key, query and value have shape
(batch, n, embed_dim); otherwise they have shape
(n, batch, embed_dim). Default: False.
- value_proj_ratio – The expansion ratio of value_proj.
Default: 1.0.
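To show how the constructor arguments interact, here is a small, hypothetical helper (`deformable_attn_sizes` is not part of horizon_plugin_pytorch) that computes the output widths of the linear projections a Deformable-DETR-style layer typically uses. The layout it assumes (separate sampling-offset, attention-weight and value projections, with embed_dims divisible by num_heads) comes from the Deformable DETR design and is not documented on this page.

```python
def deformable_attn_sizes(
    embed_dims: int = 256,
    num_heads: int = 8,
    num_levels: int = 4,
    num_points: int = 4,
    value_proj_ratio: float = 1.0,
) -> dict:
    """Hypothetical helper: projection widths for a Deformable-DETR-style
    attention layer. The layer layout is an assumption, not this API."""
    # Assumed constraint, as in standard multi-head attention.
    if embed_dims % num_heads != 0:
        raise ValueError("embed_dims must be divisible by num_heads")
    # Each query samples num_points locations per head, per level.
    samples_per_query = num_heads * num_levels * num_points
    return {
        "samples_per_query": samples_per_query,
        # One 2-D offset per sampling location.
        "sampling_offsets_out": samples_per_query * 2,
        # One scalar weight per sampling location.
        "attention_weights_out": samples_per_query,
        # value_proj rescales the value width by value_proj_ratio.
        "value_proj_out": int(embed_dims * value_proj_ratio),
    }


sizes = deformable_attn_sizes()
```

With the defaults, each query draws 8 * 4 * 4 = 128 samples, so the sampling-offset projection would output 256 values (two per sample) and the attention-weight projection 128.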
forward(query: Tensor | QTensor, key: Tensor | QTensor | None = None, value: Tensor | QTensor | None = None, identity: Tensor | QTensor | None = None, query_pos: Tensor | QTensor | None = None, key_padding_mask: Tensor | None = None, reference_points: Tensor | QTensor | None = None, spatial_shapes: Tensor | None = None)
Forward function of MultiScaleDeformableAttention.
- Parameters:
- query – Query of Transformer with shape
(num_query, bs, embed_dims).
- key – The key tensor with shape
(num_key, bs, embed_dims).
- value – The value tensor with shape
(num_key, bs, embed_dims).
- identity – The tensor used for the residual addition, with the
same shape as query. Default: None. If None,
query is used.
- query_pos – The positional encoding for query.
Default: None.
- key_padding_mask – ByteTensor masking the keys, with
shape (bs, num_key). Default: None.
- reference_points – The normalized reference
points, either with shape (bs, num_query, num_levels, 2),
where all elements lie in the range [0, 1], with the top-left
corner at (0, 0) and the bottom-right corner at (1, 1),
including the padding area; or with shape
(bs, num_query, num_levels, 4), where the two additional
elements are (w, h) and form reference boxes.
- spatial_shapes – Spatial shapes of the feature maps at the
different levels: an int tensor with shape (num_levels, 2)
whose last dimension is (h, w).
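A minimal sketch, in plain PyTorch, of building forward inputs with the shapes listed above for the default batch_first=False layout. The sizes are illustrative only. The rule that num_key equals the sum of h * w over all levels is an assumption carried over from the Deformable DETR design (this page does not state it), as is the reading that nonzero entries of key_padding_mask mark padded keys.

```python
import torch

# Illustrative sizes only; embed_dims and num_levels must match the module.
bs, num_query, embed_dims, num_levels = 2, 100, 256, 4

# Per-level (h, w) as an int tensor of shape (num_levels, 2).
spatial_shapes = torch.tensor(
    [[64, 64], [32, 32], [16, 16], [8, 8]], dtype=torch.long
)

# Assumption (Deformable DETR design): value concatenates every level's
# flattened h * w positions, so num_key is the sum of h * w over levels.
num_key = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())

# Default batch_first=False layout: (n, bs, embed_dims).
query = torch.randn(num_query, bs, embed_dims)
query_pos = torch.randn(num_query, bs, embed_dims)
value = torch.randn(num_key, bs, embed_dims)

# Normalized reference points in [0, 1): one 2-D point per query per level.
reference_points = torch.rand(bs, num_query, num_levels, 2)

# Box form: append a normalized (w, h) to each point, giving a last dim of 4.
wh = torch.full((bs, num_query, num_levels, 2), 0.1)
reference_boxes = torch.cat([reference_points, wh], dim=-1)

# ByteTensor mask of shape (bs, num_key); all zeros means no padded keys
# (assumed convention: a nonzero entry marks a padded position).
key_padding_mask = torch.zeros(bs, num_key, dtype=torch.uint8)
```

These tensors line up with the forward signature, e.g. `attn(query, value=value, query_pos=query_pos, key_padding_mask=key_padding_mask, reference_points=reference_points, spatial_shapes=spatial_shapes)` for a hypothetical instance `attn` built with matching embed_dims and num_levels. Passing reference_boxes instead of reference_points selects the box form.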
- Returns:
A tensor with the same shape as query.
- Return type:
Tensor
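The forward shapes above follow the default batch_first=False layout. A minimal sketch, in plain PyTorch, of converting a query tensor to the batch-first layout expected when the module is built with batch_first=True. It assumes, following the batch_first description, that only query, key and value change layout, while reference_points and key_padding_mask keep bs as their leading dimension in both modes.

```python
import torch

num_query, bs, embed_dims = 100, 2, 256

# Default layout (batch_first=False): (num_query, bs, embed_dims).
query = torch.randn(num_query, bs, embed_dims)

# Layout for batch_first=True: (bs, num_query, embed_dims).
query_bf = query.transpose(0, 1).contiguous()

# Swapping the first two dimensions back recovers the original tensor.
restored = query_bf.transpose(0, 1)
```

Because the return value has the same shape as query, its layout follows whichever convention query uses.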