import torch import torch.nn as nn import torch.onnx # 定义一个简单的模型 class SimpleModel(nn.Module): def __init__(self, input_size, hidden_size): super(SimpleModel, self).__init__() self.weight = nn.Parameter(torch.randn(hidden_size, input_size)) self.bias = nn.Parameter(torch.randn(hidden_size)) def forward(self, x): _, indices0 = x.max(dim=-1) x = torch.matmul(x, self.weight.t()) + self.bias # 取最大值的索引 _, indices1 = x.max(dim=-1) return torch.cat((indices0, indices1), dim=0) # 模型参数 input_size = 10 # 输入特征维度 hidden_size = 5 # 隐藏层维度 batch_size = 1 # 批大小 # 创建模型实例 model = SimpleModel(input_size, hidden_size) # 创建一个随机输入张量 dummy_input = torch.randn(batch_size, input_size) # 计算模型输出 output = model(dummy_input) print("PyTorch模型输出:", output) # 导出ONNX模型 onnx_file_path = "/data/simple_model.onnx" torch.onnx.export( model, # 要导出的模型 dummy_input, # 模型的输入 onnx_file_path, # 导出的ONNX模型文件路径 input_names=["input"],# 输入名称 output_names=["output"],# 输出名称 dynamic_axes={ "input": {0: "batch_size"}, # 批大小是动态的 "output": {0: "batch_size"} }, opset_version=16 # ONNX opset版本 ) print("ONNX模型已导出到:", onnx_file_path)