Model weight-initialization code is common boilerplate, so I am recording it here to avoid retyping it every time in the future.
def _init_weights(self) -> None:
    """Re-initialize the weights and biases of conv / linear layers in place.

    Walks ``self.modules()`` and, for every module whose exact type is one of
    ``nn.Conv2d``, ``nn.Conv3d``, ``nn.ConvTranspose2d``,
    ``nn.ConvTranspose3d`` or ``nn.Linear``:

    * fills ``weight`` with Kaiming-normal values
      (``mode='fan_out'``, ``nonlinearity='relu'``);
    * if the layer has a bias, fills it uniformly from ``[-bound, bound]``
      where ``bound = 1 / sqrt(fan_out)``.

    Modules of any other type are left untouched.
    """
    init_types = {
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.Linear,
    }
    for module in self.modules():
        # Exact type match via set lookup: subclasses of the listed layers
        # are not re-initialized.
        if type(module) not in init_types:
            continue

        # nn.init functions already run under no_grad, so the parameter can
        # be passed directly instead of going through ``.data``.
        nn.init.kaiming_normal_(
            module.weight, mode='fan_out', nonlinearity='relu'
        )

        if module.bias is None:
            continue

        _, fan_out = nn.init._calculate_fan_in_and_fan_out(module.weight)
        # Guard against zero-width layers, where the bound would divide by 0.
        if fan_out > 0:
            bound = 1 / math.sqrt(fan_out)
            # uniform_ takes (low, high). normal_ would read the same two
            # arguments as (mean, std) and shift every bias to a negative
            # mean of -bound.
            nn.init.uniform_(module.bias, -bound, bound)

    # NOTE: model-specific overrides (e.g. a constant bias on a final output
    # layer) can be applied by the owning model after this loop.