"""Export a small one-hot binning module to ONNX.

Defines ``CustomModel``, which turns a tensor of real values into a one-hot
encoding over ``DC`` integer bins, and, when run as a script, exports an
instance of it to ``custom_model.onnx``.
"""

import torch
import torch.nn as nn


class CustomModel(nn.Module):
    """One-hot encode floored, clamped input values over ``DC`` bins.

    Args:
        BN: Size of the first dimension of the reshaped output.
        num_query: Size of the second dimension of the reshaped output.
        DC: Number of bins; also the size of the last output dimension.
    """

    def __init__(self, BN: int, num_query: int, DC: int) -> None:
        super().__init__()
        self.BN = BN
        self.num_query = num_query
        self.DC = DC

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the one-hot encoding of ``x`` reshaped to ``(BN, num_query, DC)``.

        Args:
            x: Values to bin. The example input used by this script has shape
                ``(1, 1, BN, num_query, 1)``; the five-argument ``repeat``
                below and the final ``view`` presumably rely on that layout
                (5-D, trailing dimension of size 1) -- confirm before feeding
                other shapes.

        Returns:
            A tensor of shape ``(BN, num_query, DC)`` holding a 1 at the bin
            index of each input value and 0 elsewhere.
        """
        # Floor to whole numbers, then clamp into [0, DC - 1] so that every
        # value is a valid index along the last axis for the scatter below.
        bins = torch.clip(torch.floor(x), 0, self.DC - 1)

        # Tile the zero tensor DC times along the last axis to make room for
        # one slot per bin.
        out = torch.zeros_like(bins).repeat(1, 1, 1, 1, self.DC)

        # scatter_ needs integer indices; write a 1 at each value's bin.
        out.scatter_(-1, bins.to(torch.long), 1)

        return out.view(self.BN, self.num_query, self.DC)


# Model parameters
BN = 3
DC = 16
num_query = 15360  # Derived from input shape

# Destination of the exported ONNX graph
onnx_file = "custom_model.onnx"


def main() -> None:
    """Build the model, trace it on a random example input, and export it."""
    # Initialize the model
    model = CustomModel(BN=BN, num_query=num_query, DC=DC)

    # Example input tensor
    input_tensor = torch.randn(1, 1, BN, num_query, 1)

    # Export the model to ONNX
    torch.onnx.export(
        model,
        input_tensor,
        onnx_file,
        export_params=True,
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
    )
    print(f"Model exported to {onnx_file}")


# Guarded so that importing this module does not trigger an export.
if __name__ == "__main__":
    main()