文章目录
6-1P:推导RNN反向传播算法BPTT.
总结 心得体会
- 6-1P:推导RNN反向传播算法BPTT.
- 6-2P:设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试.
- 总结
- 心得体会
- 参考链接
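6-1P:推导RNN反向传播算法BPTT.
对于单层tanh RNN,前向传播为 $z_t = W_{ih}x_t + b_{ih} + W_{hh}h_{t-1} + b_{hh}$,$h_t = \tanh(z_t)$。记 $\delta_t = \frac{\partial L}{\partial z_t}$,并用 $\frac{\partial L}{\partial h_t}$ 表示第 $t$ 步损失直接传给 $h_t$ 的梯度。利用 $\tanh'(z) = 1 - \tanh^2(z)$,按时间从后往前递推可得:
$$\delta_t = \left(\frac{\partial L}{\partial h_t} + W_{hh}^{\top}\delta_{t+1}\right) \odot \left(1 - h_t \odot h_t\right), \qquad \delta_{T+1} = 0$$
$$\frac{\partial L}{\partial x_t} = W_{ih}^{\top}\delta_t, \quad \frac{\partial L}{\partial W_{ih}} = \sum_{t} \delta_t x_t^{\top}, \quad \frac{\partial L}{\partial W_{hh}} = \sum_{t} \delta_t h_{t-1}^{\top}, \quad \frac{\partial L}{\partial b_{ih}} = \frac{\partial L}{\partial b_{hh}} = \sum_{t} \delta_t$$
6-2P:设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试.
代码如下: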
import os

# Allow more than one copy of the Intel OpenMP runtime to be loaded. This is
# presumably a workaround for the "libiomp already initialized" abort seen
# when torch and numpy each ship their own copy; it must be set before torch
# is imported.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import numpy as np  # noqa: E402
import torch  # noqa: E402


class RNNCell:
    """Single-layer tanh RNN cell with a hand-written BPTT backward pass.

    One forward step (rows are batch samples):

        h_t = tanh(x_t @ W_ih.T + h_{t-1} @ W_hh.T + b_ih + b_hh)

    Every ``__call__`` caches its inputs/outputs on stacks. ``backward`` must
    then be called once per forward step, in *reverse* time order, with the
    upstream gradient dL/dh_t of that step. Results are collected in:

    * ``dx_list``      -- dL/dx_t for each step, kept in forward time order;
    * ``dw_ih_stack``, ``dw_hh_stack`` -- per-step weight gradients (reverse
      time order), to be summed over steps;
    * ``db_ih_stack``, ``db_hh_stack`` -- per-step, per-sample bias
      gradients, to be summed over steps and batch.
    """

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
        """Store the parameters, laid out like ``torch.nn.RNN``'s.

        weight_ih: (hidden, input); weight_hh: (hidden, hidden);
        bias_ih, bias_hh: (hidden,).
        """
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh
        # Forward-pass caches; backward pops them, so they are consumed LIFO.
        self.x_stack = []
        self.prev_hidden_stack = []
        self.next_hidden_stack = []
        # Gradient outputs (see class docstring for their ordering).
        self.dx_list = []
        self.dw_ih_stack = []
        self.dw_hh_stack = []
        self.db_ih_stack = []
        self.db_hh_stack = []
        # Gradient reaching h_t from step t+1 through W_hh (temporary cache).
        self.prev_dh = None

    def __call__(self, x, prev_hidden):
        """Run one forward step and cache what ``backward`` needs.

        x: (batch, input); prev_hidden: (batch, hidden).
        Returns the new hidden state h_t with shape (batch, hidden).
        """
        self.x_stack.append(x)
        next_h = np.tanh(
            np.dot(x, self.weight_ih.T)
            + np.dot(prev_hidden, self.weight_hh.T)
            + self.bias_ih
            + self.bias_hh
        )
        self.prev_hidden_stack.append(prev_hidden)
        self.next_hidden_stack.append(next_h)
        # The last step receives no gradient from the future, so BPTT must
        # start from zero; resetting on every forward step guarantees that.
        self.prev_dh = np.zeros(next_h.shape)
        return next_h

    def backward(self, dh):
        """Back-propagate through one time step (call in reverse time order).

        dh: dL/dh_t coming directly from the loss at this step, (batch, hidden).
        Returns ``dx_list`` (all input gradients computed so far, time order).
        """
        x = self.x_stack.pop()
        prev_hidden = self.prev_hidden_stack.pop()
        next_hidden = self.next_hidden_stack.pop()
        # delta_t = (dL/dh_t + W_hh^T delta_{t+1}) * tanh'(z_t),
        # where tanh'(z_t) = 1 - tanh(z_t)^2 = 1 - h_t^2.
        d_tanh = (dh + self.prev_dh) * (1 - next_hidden ** 2)
        # Gradient handed to h_{t-1} through the recurrent weights.
        self.prev_dh = np.dot(d_tanh, self.weight_hh)
        dx = np.dot(d_tanh, self.weight_ih)
        # Steps arrive newest-first; prepend so dx_list stays in time order.
        self.dx_list.insert(0, dx)
        self.dw_ih_stack.append(np.dot(d_tanh.T, x))
        self.dw_hh_stack.append(np.dot(d_tanh.T, prev_hidden))
        # Both biases enter the pre-activation additively, so each receives
        # delta_t unchanged.
        self.db_ih_stack.append(d_tanh)
        self.db_hh_stack.append(d_tanh)
        return self.dx_list


def main():
    """Numerically check ``RNNCell`` against ``torch.nn.RNN`` on random data.

    Prints every NumPy/PyTorch pair (hidden states and all gradients) with a
    ``match`` flag, and returns True when every pair agrees per
    ``np.allclose``.
    """
    np.random.seed(123)
    torch.random.manual_seed(123)
    np.set_printoptions(precision=6, suppress=True)

    rnn_torch = torch.nn.RNN(4, 5).double()
    # all_weights[0] = [weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0]
    torch_params = rnn_torch.all_weights[0]
    rnn_numpy = RNNCell(*(p.data.numpy() for p in torch_params))

    nums = 3  # sequence length
    x3_numpy = np.random.random((nums, 3, 4))  # (seq_len, batch, input)
    x3_tensor = torch.tensor(x3_numpy, requires_grad=True)
    h3_numpy = np.random.random((1, 3, 5))  # (num_layers, batch, hidden)
    h3_tensor = torch.tensor(h3_numpy, requires_grad=True)
    dh_numpy = np.random.random((nums, 3, 5))  # upstream dL/dh_t per step
    # The upstream gradient is a constant; it does not need requires_grad.
    dh_tensor = torch.tensor(dh_numpy)

    # torch.nn.RNN returns (outputs for every step, final hidden state); bind
    # them to their own names instead of overwriting the initial-state tensor.
    output, _ = rnn_torch(x3_tensor, h3_tensor)
    output.backward(dh_tensor)

    h_numpy_list = []
    h_numpy = h3_numpy[0]
    for x_t in x3_numpy:
        h_numpy = rnn_numpy(x_t, h_numpy)
        h_numpy_list.append(h_numpy)
    # BPTT: feed the per-step upstream gradients newest-first.
    for dh_t in dh_numpy[::-1]:
        rnn_numpy.backward(dh_t)

    # Per-step parameter gradients are summed over time (and over the batch
    # for the biases) to match PyTorch's accumulated .grad values.
    comparisons = [
        ("numpy_hidden", "torch_hidden",
         np.array(h_numpy_list), output.data.numpy()),
        ("dx_numpy", "dx_torch",
         np.array(rnn_numpy.dx_list), x3_tensor.grad.data.numpy()),
        ("dw_ih_numpy", "dw_ih_torch",
         np.sum(rnn_numpy.dw_ih_stack, axis=0),
         torch_params[0].grad.data.numpy()),
        ("dw_hh_numpy", "dw_hh_torch",
         np.sum(rnn_numpy.dw_hh_stack, axis=0),
         torch_params[1].grad.data.numpy()),
        ("db_ih_numpy", "db_ih_torch",
         np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)),
         torch_params[2].grad.data.numpy()),
        ("db_hh_numpy", "db_hh_torch",
         np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)),
         torch_params[3].grad.data.numpy()),
    ]

    results = []
    for numpy_label, torch_label, numpy_value, torch_value in comparisons:
        match = np.allclose(numpy_value, torch_value)
        results.append(match)
        print(f"{numpy_label} :\n", numpy_value)
        print(f"{torch_label} :\n", torch_value)
        print("match:", match)
        print("-" * 48)
    return all(results)


if __name__ == '__main__':
    main()
运行结果:
numpy_hidden : [[[ 0.4686 -0.298203 0.741399 -0.446474 0.019391] [ 0.365172 -0.361254 0.426838 -0.448951 0.331553] [ 0.589187 -0.188248 0.684941 -0.45859 0.190099]] [[ 0.146213 -0.306517 0.297109 0.370957 -0.040084] [-0.009201 -0.365735 0.333659 0.486789 0.061897] [ 0.030064 -0.282985 0.42643 0.025871 0.026388]] [[ 0.225432 -0.015057 0.116555 0.080901 0.260097] [ 0.368327 0.258664 0.357446 0.177961 0.55928 ] [ 0.103317 -0.029123 0.182535 0.216085 0.264766]]] torch_hidden : [[[ 0.4686 -0.298203 0.741399 -0.446474 0.019391] [ 0.365172 -0.361254 0.426838 -0.448951 0.331553] [ 0.589187 -0.188248 0.684941 -0.45859 0.190099]] [[ 0.146213 -0.306517 0.297109 0.370957 -0.040084] [-0.009201 -0.365735 0.333659 0.486789 0.061897] [ 0.030064 -0.282985 0.42643 0.025871 0.026388]] [[ 0.225432 -0.015057 0.116555 0.080901 0.260097] [ 0.368327 0.258664 0.357446 0.177961 0.55928 ] [ 0.103317 -0.029123 0.182535 0.216085 0.264766]]] ----------------------------------------------- dx_numpy : [[[-0.643965 0.215931 -0.476378 0.072387] [-1.221727 0.221325 -0.757251 0.092991] [-0.59872 -0.065826 -0.390795 0.037424]] [[-0.537631 -0.303022 -0.364839 0.214627] [-0.815198 0.392338 -0.564135 0.217464] [-0.931365 -0.254144 -0.561227 0.164795]] [[-1.055966 0.249554 -0.623127 0.009784] [-0.45858 0.108994 -0.240168 0.117779] [-0.957469 0.315386 -0.616814 0.205634]]] dx_torch : [[[-0.643965 0.215931 -0.476378 0.072387] [-1.221727 0.221325 -0.757251 0.092991] [-0.59872 -0.065826 -0.390795 0.037424]] [[-0.537631 -0.303022 -0.364839 0.214627] [-0.815198 0.392338 -0.564135 0.217464] [-0.931365 -0.254144 -0.561227 0.164795]] [[-1.055966 0.249554 -0.623127 0.009784] [-0.45858 0.108994 -0.240168 0.117779] [-0.957469 0.315386 -0.616814 0.205634]]] ------------------------------------------------ dw_ih_numpy : [[3.918335 2.958509 3.725173 4.157478] [1.261197 0.812825 1.10621 0.97753 ] [2.216469 1.718251 2.366936 2.324907] [3.85458 3.052212 3.643157 3.845696] [1.806807 1.50062 1.615917 1.521762]] dw_ih_torch : 
[[3.918335 2.958509 3.725173 4.157478] [1.261197 0.812825 1.10621 0.97753 ] [2.216469 1.718251 2.366936 2.324907] [3.85458 3.052212 3.643157 3.845696] [1.806807 1.50062 1.615917 1.521762]] ------------------------------------------------ dw_hh_numpy : [[ 2.450078 0.243735 4.269672 0.577224 1.46911 ] [ 0.421015 0.372353 0.994656 0.962406 0.518992] [ 1.079054 0.042843 2.12169 0.863083 0.757618] [ 2.225794 0.188735 3.682347 0.934932 0.955984] [ 0.660546 -0.321076 1.554888 0.833449 0.605201]] dw_hh_torch : [[ 2.450078 0.243735 4.269672 0.577224 1.46911 ] [ 0.421015 0.372353 0.994656 0.962406 0.518992] [ 1.079054 0.042843 2.12169 0.863083 0.757618] [ 2.225794 0.188735 3.682347 0.934932 0.955984] [ 0.660546 -0.321076 1.554888 0.833449 0.605201]] ------------------------------------------------ db_ih_numpy : [7.568411 2.175445 4.335336 6.820628 3.51003 ] db_ih_torch : [7.568411 2.175445 4.335336 6.820628 3.51003 ] ----------------------------------------------- db_hh_numpy : [7.568411 2.175445 4.335336 6.820628 3.51003 ] db_hh_torch : [7.568411 2.175445 4.335336 6.820628 3.51003 ] Process finished with exit code 0
总结 心得体会
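本次作业手动推导了一遍BPTT,并分别用NumPy和PyTorch实现了RNN的前向与反向传播。参考老师的PPT和给出的示例代码,很快就完成了作业。从运行结果可以看到,NumPy实现得到的隐状态以及 dx、dW_ih、dW_hh、db_ih、db_hh 与PyTorch自动求导的结果完全一致,说明手推的BPTT公式是正确的。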
参考链接
- L5W1作业1:手把手实现循环神经网络
- NNDL 作业9:分别使用numpy和pytorch实现BPTT