注册机制

注册机制是辅助构建 config 的重要模块,也是HAT的重要组成部分。

本小节通过自定义模块的例子,为您说明如何在注册机制下在增加新的模块并在 config 中正常使用。

自定义模块

backbone 为例,这里展示一下如何开发以 mobilenet 为例的新模块。

定义一个新的backbone(如MobileNet):

新建一个新文件:hat/models/backbones/mobilenet.py

import torch.nn as nn from hat.registry import OBJECT_REGISTRY __all__ = ["MobileNet"] @OBJECT_REGISTRY.register class MobileNet(nn.Module): def __init__(self, args1, args2): pass def forward(self, x): pass

导入新定义的模块

可以在 hat/models/backbones/__init__.py 中增加导入模块的行。

from .mobilenet import MobileNet

config中使用新的backbone

model = dict( ... backbone=dict( type="MobileNet", args1=xxx, args2=xxx, ) ... )

以此类推,其他任何可注册的模块,都可以使用这种方法来完成开发和使用。