import torch
import torch.nn as nn
import torch.nn.functional as F


def _clamp_index(idx, upper):
    """Clamp the integer index tensor *idx* into ``[0, upper]``.

    Built from two ``torch.where`` calls rather than ``torch.clamp``:
    clamping an int64 tensor exports to ONNX ``Clip``, which only accepts
    integer inputs from opset 12, while this file targets opset 11.
    """
    # Constants live on idx's device so CPU and CUDA inputs both work.
    lower_t = torch.tensor(0, device=idx.device)
    upper_t = torch.tensor(upper, device=idx.device)
    idx = torch.where(idx < 0, lower_t, idx)
    return torch.where(idx > upper, upper_t, idx)


class CustomModel(nn.Module):
    """Gather-based bilinear grid sampling, exportable to ONNX opset 11.

    ``bilinear_grid_sample`` re-implements ``F.grid_sample`` in bilinear mode
    with zero padding. It uses only pad/floor/gather/arithmetic ops, because
    the ONNX ``GridSample`` operator only exists from opset 16.
    """

    def __init__(self, num_points=15360):
        """Create the model.

        Args:
            num_points: number of sample points in the random grid that
                ``forward`` builds when no grid is passed (default 15360).
        """
        super().__init__()
        self.num_points = num_points

    def bilinear_grid_sample(self, input, grid, align_corners=False):
        """Bilinearly sample ``input`` at the normalized locations in ``grid``.

        Args:
            input: tensor of shape (N, C, H, W).
            grid: tensor of shape (N, Hg, Wg, 2). The last dimension holds
                (x, y) in normalized [-1, 1] coordinates, as for
                ``F.grid_sample``.
            align_corners: if True, -1/+1 map to the centres of the corner
                pixels; otherwise to their outer edges.

        Returns:
            Tensor of shape (N, C, Hg, Wg). Locations outside the image get
            zero contributions (zero padding).

        Raises:
            AssertionError: if ``input`` and ``grid`` batch sizes differ.
        """
        n, c, h, w = input.shape
        gn, gh, gw, _ = grid.shape
        assert n == gn, f"batch mismatch: input has {n}, grid has {gn}"

        x = grid[:, :, :, 0]
        y = grid[:, :, :, 1]

        # Map normalized coordinates to pixel coordinates.
        if align_corners:
            x = ((x + 1) / 2) * (w - 1)
            y = ((y + 1) / 2) * (h - 1)
        else:
            x = ((x + 1) * w - 1) / 2
            y = ((y + 1) * h - 1) / 2

        x = x.contiguous().view(n, -1)
        y = y.contiguous().view(n, -1)

        # Integer corners surrounding each sample point.
        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        # Bilinear weights from the unclamped corners, shape (N, 1, P).
        wa = ((x1 - x) * (y1 - y)).unsqueeze(1)  # corner (x0, y0)
        wb = ((x1 - x) * (y - y0)).unsqueeze(1)  # corner (x0, y1)
        wc = ((x - x0) * (y1 - y)).unsqueeze(1)  # corner (x1, y0)
        wd = ((x - x0) * (y - y0)).unsqueeze(1)  # corner (x1, y1)

        # Add a 1-pixel zero border. Corners are shifted into padded
        # coordinates and clamped, so any out-of-range corner lands on the
        # border and reads 0.
        im_padded = F.pad(input, pad=[1, 1, 1, 1], mode="constant", value=0)
        padded_h = h + 2
        padded_w = w + 2
        x0 = _clamp_index(x0 + 1, padded_w - 1)
        x1 = _clamp_index(x1 + 1, padded_w - 1)
        y0 = _clamp_index(y0 + 1, padded_h - 1)
        y1 = _clamp_index(y1 + 1, padded_h - 1)

        im_padded = im_padded.view(n, c, -1)

        def _gather(xi, yi):
            # Flatten (x, y) into the padded image and fetch all C channels.
            flat = (xi + yi * padded_w).unsqueeze(1).expand(-1, c, -1)
            return torch.gather(im_padded, 2, flat)

        Ia = _gather(x0, y0)
        Ib = _gather(x0, y1)
        Ic = _gather(x1, y0)
        Id = _gather(x1, y1)

        return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)

    def forward(self, x, grid=None):
        """Sample ``x`` at ``grid`` with ``align_corners=True``.

        Args:
            x: tensor of shape (N, C, H, W).
            grid: optional (N, Hg, Wg, 2) sampling grid. If omitted, a
                standard-normal grid of shape (N, num_points, 1, 2) is drawn
                on ``x``'s device and dtype. This placeholder grid has many
                points outside [-1, 1], which sample the zero padding.

        Returns:
            Tensor of shape (N, C, Hg, Wg).
        """
        if grid is None:
            grid = torch.randn(
                x.shape[0], self.num_points, 1, 2,
                device=x.device, dtype=x.dtype,
            )
        return self.bilinear_grid_sample(x, grid, True)


def main():
    """Export ``CustomModel`` to ``custom_model.onnx`` (opset 11)."""
    model = CustomModel()
    # Example input tensor.
    input_tensor = torch.randn(2, 64, 16, 30)

    onnx_file = "custom_model.onnx"
    torch.onnx.export(
        model,
        input_tensor,
        onnx_file,
        export_params=True,
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
    )
    print(f"Model exported to {onnx_file}")


if __name__ == "__main__":
    main()