10.3.5. 注册机制¶
注册机制是辅助构建 config 的重要模块,也是HAT的重要组成部分。
本小节通过自定义模块的例子,为您说明如何在注册机制下在增加新的模块并在 config 中正常使用。
10.3.5.1. 自定义模块¶
以 backbone 为例,这里展示一下如何开发以 mobilenet 为例的新模块。
10.3.5.1.1. 定义一个新的backbone(如MobileNet):¶
新建一个新文件:hat/models/backbones/mobilenet.py。
import torch.nn as nn
from hat.registry import OBJECT_REGISTRY
__all__ = ["MobileNet"]
@OBJECT_REGISTRY.register
class MobileNet(nn.Module):
def __init__(self, args1, args2):
pass
def forward(self, x):
pass
10.3.5.1.2. 导入新定义的模块¶
可以在 hat/models/backbones/__init__.py 中增加导入模块的行。
from .mobilenet import MobileNet
10.3.5.1.3. 在config中使用新的backbone¶
model = dict(
...
backbone=dict(
type="MobileNet",
args1=xxx,
args2=xxx,
)
...
)
以此类推,其他任何可注册的模块,都可以使用这种方法来完成开发和使用。