import torch
import torch.nn.functional as F
from torch import nn
class Mish(nn.Module):
    """Applies the Mish activation function element-wise.

    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self) -> None:
        """Init method (Mish is stateless; nothing to configure)."""
        super(Mish, self).__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass: apply mish element-wise, preserving shape.

        F.softplus is a numerically stable ln(1 + exp(x)), so this
        avoids overflow for large positive inputs.
        """
        return input * torch.tanh(F.softplus(input))
if __name__ == "__main__":
    # torchvision is imported here (not at module top) so that importing
    # this module for the Mish class does not require torchvision or
    # trigger the pretrained-weight download below.
    from torchvision import models

    # Downloads ImageNet-pretrained weights on first use.
    model = models.resnet50(pretrained=True, progress=True)
    print("active", model.relu)
    # In torchvision's ResNet-50, swapping the single top-level `relu`
    # module is enough to replace the stem activation with Mish.
    model.relu = Mish()
    print("active", model.relu)
# 深度学习--新的激活函数Mish (Deep learning — the new Mish activation function)
# 2019-11-11