Transformer P6 位置编码层优化版代码实现
上节课当中,带大家用最直观的方法实现了 Transformer 中的位置编码,在实现过程中,用到两层 for 循环,去逐个修改矩阵中各个元素的值,效率很低,所以这节课给大家补充一种更高效的实现方法。
代码示例
1、张量乘法自动广播
a = torch.tensor([ [1], [2], [3], ]) b = torch.tensor([4, 5]) print(a*b)
PyTorch 会自动广播张量 b,扩展为(3, 2),然后每一行和 a 的对应行相乘。
2、直接套公式计算角度值
d_model = 8 # 位置 position = torch.arange(0, 10).unsqueeze(1) # 除法,看成乘以除数分之一 i_2 = torch.arange(0, d_model, 2) div_term = 1 / 10000 ** (i_2 / d_model) # 角度值 angle = position * div_term print(angle)
3、公式变形
div_term = 1 / 10000 ** (i_2 / d_model) div_term = torch.exp(math.log(1 / 10000 ** (i_2 / d_model))) div_term = torch.exp(-math.log(10000 ** (i_2 / d_model))) div_term = torch.exp(-math.log(10000) * i_2 / d_model) div_term = torch.exp(i_2 * -math.log(10000) / d_model) div_term = torch.exp(torch.arange(0, d_model, 2) * -math.log(10000) / d_model)
4、封装位置编码层
class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): super().__init__() self.dropout = nn.Dropout(dropout) pe = torch.zeros(max_len, d_model) # 位置和除数 position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * -math.log(10000) / d_model) # 修改pe矩阵的值 pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) # 扩展 batch 维度 pe = pe.unsqueeze(0) # 存储为不需要计算梯度的参数 self.register_buffer('pe', pe) def forward(self, x): x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False) return self.dropout(x)
以上这个类的封装代码,来源于 annotated-transformer 开源项目,利用张量乘法,避免了两层 for 循环操作,提高了运算效率。大家在自己项目中,涉及多层嵌套的结构,也可以参考这个方法。
本文链接:http://ichenhua.cn/edu/note/653
版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!