# my_conv2d_v2.py
"""
Online lecture: Basics of PyTorch autograd
Demonstrate custom implementation #2 of forward and backward propagation of Conv2d
"""
import torch
from torch.autograd import Function


class MyConv2d_v2(Function):
    """
    Version #2 of our own custom autograd Function, MyConv2d, built by subclassing
    torch.autograd.Function and overriding the forward and backward passes
    """
    @staticmethod
    def forward(ctx, inX, in_weight, in_bias=None, convparam=None):
        """ Override the forward function """
        # note: for demo purposes, assume dilation=1 and padding_mode='zeros';
        # also assume the padding and the stride are each the same for ROWS and COLS
        if convparam is not None:
            padding, stride = convparam
        else:
            padding, stride = 0, 1
        nOutCh, nInCh, nKnRows, nKnCols = in_weight.shape
        nImgSamples, nInCh, nInImgRows, nInImgCols = inX.shape
        # zero-pad the input symmetrically along the row and column dimensions
        paddedX = torch.zeros((nImgSamples, nInCh, nInImgRows + 2 * padding, nInImgCols + 2 * padding),
                              dtype=inX.dtype, device=inX.device)
        paddedX[:, :, padding:nInImgRows + padding, padding:nInImgCols + padding] = inX
        # determine the output shape
        nOutRows = (nInImgRows + 2 * padding - nKnRows) // stride + 1
        nOutCols = (nInImgCols + 2 * padding - nKnCols) // stride + 1
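        # e.g. an 8x8 input with padding=1, a 3x3 kernel, and stride=2 yields
        # (8 + 2*1 - 3) // 2 + 1 = 4 output rows, and likewise 4 output columns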
        out = torch.zeros(nImgSamples, nOutCh, nOutRows, nOutCols, dtype=inX.dtype, device=inX.device)
        # slide the kernel over the padded input; each output pixel is the
        # windowed cross-correlation summed over input channels, rows, and cols
        for outCh in range(nOutCh):
            for iRow in range(nOutRows):
                startRow = iRow * stride
                for iCol in range(nOutCols):
                    startCol = iCol * stride
                    out[:, outCh, iRow, iCol] = \
                        (paddedX[:, :, startRow:startRow + nKnRows, startCol:startCol + nKnCols]
                         * in_weight[outCh, :, 0:nKnRows, 0:nKnCols]).sum(axis=(1, 2, 3))
        if in_bias is not None:
            out += in_bias.view(1, -1, 1, 1)
        # stash the hyperparameters and the tensors needed by the backward pass
        ctx.parameters = (padding, stride)
        ctx.save_for_backward(paddedX, in_weight, in_bias)
        return out

    @staticmethod
    def backward(ctx, grad_from_upstream):
        """
        Override the backward function. It receives a Tensor containing the gradient of the loss
        with respect to the output of the custom forward pass, and computes the gradient of the loss
        with respect to each of the inputs of the custom forward pass.
        """
        print('Performing custom backward of MyConv2d_v2')
        padding, stride = ctx.parameters
        paddedX, in_weight, in_bias = ctx.saved_tensors
        nImgSamples, nInCh, nPadImgRows, nPadImgCols = paddedX.shape
        nOutCh, nInCh, nKnRows, nKnCols = in_weight.shape
        nImgSamples, nOutCh, nOutRows, nOutCols = grad_from_upstream.shape
        grad_padX = torch.zeros_like(paddedX)
        grad_weight = torch.zeros_like(in_weight)
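        # Chain rule for out[n,o,i,j] = sum_{c,u,v} paddedX[n,c,i*stride+u,j*stride+v] * W[o,c,u,v] + b[o]:
        #   dL/dpaddedX[n,c,i*stride+u,j*stride+v] += dL/dout[n,o,i,j] * W[o,c,u,v]
        #   dL/dW[o,c,u,v] += sum_n dL/dout[n,o,i,j] * paddedX[n,c,i*stride+u,j*stride+v]
        # The loop below accumulates exactly these two sums, window by window.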
        for outCh in range(nOutCh):
            for iRow in range(nOutRows):
                startRow = iRow * stride
                for iCol in range(nOutCols):
                    startCol = iCol * stride
                    # input gradient: upstream gradient times the kernel weights
                    grad_padX[:, :, startRow:startRow + nKnRows, startCol:startCol + nKnCols] += \
                        grad_from_upstream[:, outCh, iRow, iCol].reshape(-1, 1, 1, 1) * \
                        in_weight[outCh, :, 0:nKnRows, 0:nKnCols]
                    # weight gradient: input window times the upstream gradient, summed over samples
                    grad_weight[outCh, :, 0:nKnRows, 0:nKnCols] += \
                        (paddedX[:, :, startRow:startRow + nKnRows, startCol:startCol + nKnCols] *
                         grad_from_upstream[:, outCh, iRow, iCol].reshape(-1, 1, 1, 1)).sum(axis=0)
        # strip the zero padding to recover the gradient w.r.t. the original input
        grad_inputX = grad_padX[:, :, padding:nPadImgRows - padding, padding:nPadImgCols - padding]
        if in_bias is not None:
            # the bias is added to every output pixel, so its gradient sums over samples, rows, and cols
            grad_bias = grad_from_upstream.sum(axis=(0, 2, 3))
        else:
            grad_bias = None
        # one gradient per forward input; convparam is a non-tensor, so it gets None
        return grad_inputX, grad_weight, grad_bias, None
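

# ---------------------------------------------------------------------------
# A minimal self-check sketch (not part of the lecture code above): it assumes
# float64 inputs so torch.autograd.gradcheck's numerical comparison is tight,
# and compares the custom forward against torch.nn.functional.conv2d.
if __name__ == '__main__':
    torch.manual_seed(0)
    convparam = (1, 2)  # (padding, stride)
    inX = torch.randn(2, 3, 8, 8, dtype=torch.float64, requires_grad=True)
    weight = torch.randn(4, 3, 3, 3, dtype=torch.float64, requires_grad=True)
    bias = torch.randn(4, dtype=torch.float64, requires_grad=True)

    out = MyConv2d_v2.apply(inX, weight, bias, convparam)
    ref = torch.nn.functional.conv2d(inX, weight, bias, stride=2, padding=1)
    print('forward max abs diff:', (out - ref).abs().max().item())

    # gradcheck exercises the custom backward pass and compares it with
    # numerically estimated gradients (non-tensor args are passed through)
    print('gradcheck passed:', torch.autograd.gradcheck(
        MyConv2d_v2.apply, (inX, weight, bias, convparam)))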