Pytorch 中的 ModuleList & Sequential
ModuleList & Sequential 都是 torch.nn中重要的容器类,是为了方便定义结构化的可复用的网络结构而产生,但是两者的功能又有略微不同。
不同点
场景不同,ModuleList 可拓展性更强,sequential更方便。
ModuleList需要可以自定义运算顺序,Sequential必须按照定义的顺序依次计算
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| class SeqNet(nn.Module): def __init__(self): super().__init__() self.encoder = nn.Sequential( nn.Linear(10,128), nn.ReLU(), nn.Linear(128,50) ) def forward(self,x): return self.encoder(x) class ModNet(nn.Module): def __init__(self): super().__init__() self.encoder = nn.ModuleList([ nn.Linear(10,128), nn.ReLU(), nn.Linear(128,50) ]) def forward(self,x): for block in self.encoder: x = block(x) return x
|
- 可以看到ModuleList拓展性更强,如何运算完全取决于我们如何定义。
- 在一些结构化很强的模型中使用Sequential更方便,但是对于一些灵活的模型ModuleList更好
- 两者初始化的参数也有略微不同,Sequential是多个module对象,ModList是一个数组