def get_grid_quant_scale(grid_shape, view_shape):
    """Return the power-of-two quantization scale for grid coordinates.

    The largest value among ``grid_shape`` and ``view_shape`` decides how many
    integer bits a coordinate needs. The rest of 15 bits (presumably the
    magnitude bits of a signed 16-bit coordinate — confirm against callers)
    become fractional bits. The fractional bit count ``coord_shift`` is
    clamped to ``[0, 8]``.

    Args:
        grid_shape: Iterable of grid dimension sizes.
        view_shape: Iterable of view dimension sizes.

    Returns:
        float: ``1 / 2 ** coord_shift``.

    Raises:
        ValueError: If both ``grid_shape`` and ``view_shape`` are empty.
    """
    dims = (*grid_shape, *view_shape)
    if not dims:
        raise ValueError("grid_shape and view_shape are both empty")
    # Pass one iterable so a lone dimension never reaches max() as a scalar.
    max_coord = max(dims)
    # log2 is exact at powers of two, unlike the generic log(x, 2).
    coord_bit_num = math.ceil(math.log2(max_coord + 1))
    coord_shift = max(min(15 - coord_bit_num, 8), 0)
    return 1.0 / (1 << coord_shift)

def adjust_coords(coords: torch.Tensor, grid_size: Tuple[int, int]) -> torch.Tensor:
    """Turn absolute sampling coords into per-pixel offsets for hnn grid_sample.

    The hnn-style ``grid_sample`` takes a flow field of offsets: the sample
    point of output pixel ``(x, y)`` is ``(x + dx, y + dy)``. This helper
    subtracts each output pixel's own ``(x, y)`` position from ``coords`` to
    produce those offsets.

    Args:
        coords: Absolute sample coordinates. They must broadcast against an
            ``(H, W, 2)`` tensor whose last axis holds ``(x, y)``.
        grid_size: Output grid size as ``(W, H)``. Only the first two entries
            are read.

    Returns:
        ``coords`` minus the base pixel grid. The base grid is float32, so
        standard type promotion applies.
    """
    W = grid_size[0]
    H = grid_size[1]

    # Build the base grid directly on the target device; no CPU round trip.
    xs = torch.arange(W, dtype=torch.float32, device=coords.device)
    ys = torch.arange(H, dtype=torch.float32, device=coords.device)
    # base[i, j] == (j, i), i.e. the (x, y) position of output pixel (row i, col j).
    base = torch.stack(
        [xs.reshape(1, W).expand(H, W), ys.reshape(H, 1).expand(H, W)],
        dim=-1,
    )
    return coords - base

def grid_sample(
    input: Tensor,
    grid: Tensor,
    mode: str,
    padding_mode: str,
    align_corners: bool,
    scale: Tensor,
    zero_point: Tensor,
    dtype: str,
    march: str,
):
    """Sample ``input`` at locations given by a relative flow-field grid.

    Given an input and a flow-field grid, computes the output using input
    values and pixel locations from grid.

    Note that the grid required by this function is DIFFERENT from
    torch.nn.functional.grid_sample: it holds per-pixel offsets, not
    normalized absolute coordinates.

    The gradient of ``grid`` is always 0 for now.

    The offsets are fake-quantized via ``torch.ops.horizon.scale_quanti``
    with qint16 bounds and a scale of ``2 ** -coord_shift``. ``coord_shift``
    is ``15`` minus the bit length of the largest spatial size among
    ``input`` and ``grid``, clamped to ``[0, 8]``.

    Args:
        input (Tensor[N, C, H, W]): Input data.
        grid (Tensor[N, H_out, W_out, (dx, dy)]): Flow-field. The sample
            point of output point (x, y) is (x + dx, y + dy).
        mode (str): Interpolation mode to calculate output values.
            Only "bilinear" is supported now.
        padding_mode (str): Padding mode for outside grid values.
            Only "zeros" is supported now.
        align_corners (bool): Since the grid format is different from
            torch.nn.functional.grid_sample, this param does not have any
            effect now.
        scale (Tensor): Forwarded to ``quanti_grid_sample``. Presumably the
            quantization scale of the data; TODO confirm.
        zero_point (Tensor): Not used by this implementation.
        dtype (str): Not used by this implementation.
        march (str): Forwarded to both ``horizon`` ops. Presumably a
            target-architecture identifier; TODO confirm.

    Returns:
        Tensor: The result of ``torch.ops.horizon.quanti_grid_sample``.
    """
    # Convert from xy to yx.
    grid_yx = torch.stack((grid[..., 1], grid[..., 0]), dim=-1)

    # Fractional bit count for the quantized coordinates. Reserve enough
    # integer bits for the largest spatial extent, give the rest of 15 bits
    # to the fraction, and clamp to [0, 8]. For a non-negative int m,
    # m.bit_length() == ceil(log2(m + 1)); it is computed exactly, with no
    # floating-point logarithm.
    max_coord = max(input.size(2), input.size(3), grid.size(1), grid.size(2))
    coord_bit_num = int(max_coord).bit_length()
    coord_shift = max(min(15 - coord_bit_num, 8), 0)

    # Coord int16 quantization with scale 2 ** -coord_shift and zero point 0.
    grid_scale = torch.tensor(
        1.0 / (1 << coord_shift), dtype=torch.float, device=grid.device
    ).reshape(1)
    quant_info = qinfo("qint16")
    grid_yx = torch.ops.horizon.scale_quanti(
        grid_yx.float(),
        grid_scale,
        torch.zeros_like(grid_scale, dtype=torch.long),
        -1,
        quant_info.min,
        quant_info.max,
        True,
        False,
        "bpu_round",
        march,
    )

    # Convert to absolute grid: base_coord[b, i, j] == (i, j), the (y, x)
    # position of each output pixel, matching grid_yx's channel order.
    n, h, w, _ = grid.shape
    base_coord = torch.stack(
        [
            torch.arange(h, device=grid.device)
            .reshape(1, h, 1)
            .expand(n, h, w),
            torch.arange(w, device=grid.device)
            .reshape(1, 1, w)
            .expand(n, h, w),
        ],
        dim=-1,
    )
    absolute_grid = grid_yx + base_coord
    # (N, H_out, W_out, 2) -> (N, 1, 2, H_out, W_out), the layout passed to
    # quanti_grid_sample.
    absolute_grid = absolute_grid.permute(0, 3, 1, 2).unsqueeze(1)

    return torch.ops.horizon.quanti_grid_sample(
        input.float(),
        absolute_grid,
        scale,
        mode,
        padding_mode,
        align_corners,
        coord_shift,
        march,
    )

