forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fused_layers.py
77 lines (66 loc) Β· 3.08 KB
/
fused_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _C_ops
from paddle.framework import core
def is_fused_matmul_bias_supported():
    """Report whether the fused matmul+bias (GEMM epilogue) path is usable.

    The path is considered only on CUDA builds that are not ROCm builds, or on
    XPU builds. On such builds the answer is whether the eager legacy op table
    exposes ``fused_gemm_epilogue``; on any other build it is ``False``.

    Returns:
        bool: ``True`` if the fused kernel is available, ``False`` otherwise.
    """
    # Grouping made explicit: (CUDA and not ROCm) or XPU.
    cuda_without_rocm = paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
    if not (cuda_without_rocm or paddle.is_compiled_with_xpu()):
        return False
    return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
# Linear implementation used by FusedLinearWithGradAdd.forward. It is bound once at
# import time, so later monkey-patching of paddle's functionals by mock_layers()
# does not change what this name refers to (and forward cannot recurse into itself).
origin_linear = (
    paddle.incubate.nn.functional.fused_linear
    if is_fused_matmul_bias_supported()
    else paddle.nn.functional.linear
)
class FusedLinearWithGradAdd(paddle.autograd.PyLayer):
    """Linear op whose backward produces parameter gradients via a fused kernel.

    ``forward`` delegates to ``origin_linear``. ``backward`` computes the input
    gradient with a matmul and obtains the weight (and, if present, bias)
    gradients from ``_C_ops.fused_linear_param_grad_add``.

    Where the gradients end up:
    - If the parameters have a ``main_grad`` attribute, or the weight already has
      a non-None ``.grad``, the results are stored directly back onto those
      attributes, and the matching slots in the returned tuple are ``None``.
    - Otherwise the freshly computed gradients are returned.
    """

    @staticmethod
    def forward(ctx, x, weight, bias=None, name=None):
        """Apply ``origin_linear(x, weight, bias)`` and stash the inputs for backward.

        ``name`` is accepted so the signature matches ``paddle.nn.functional.linear``;
        it is not used here.
        """
        out = origin_linear(x, weight, bias)
        ctx.save_for_backward(x, weight, bias)
        return out

    @staticmethod
    def backward(ctx, y_grad):
        """Return gradients for the tensor inputs of ``forward``.

        Returns ``(x_grad, weight_grad)`` when ``bias`` was ``None``, otherwise
        ``(x_grad, weight_grad, bias_grad)``. A parameter slot is ``None`` whenever
        that gradient was written onto the parameter (``main_grad`` or ``grad``)
        instead of being returned.
        """
        x, weight, bias = ctx.saved_tensor()
        no_bias = bias is None
        # Input gradient: y_grad times the transposed weight.
        x_grad = paddle.matmul(y_grad, weight, transpose_y=True)

        # Positional arguments of the fused op:
        #   (x, y_grad, dw, db, multi_precision, has_bias)
        # The with-bias calls below do not pass has_bias, so the op's default is used.
        grad_add = _C_ops.fused_linear_param_grad_add

        # Route 1: the weight (and the bias, when there is one) has a ``main_grad``
        # attribute -> feed it to the op with multi_precision=True and store the
        # result back on ``main_grad``.
        if hasattr(weight, "main_grad") and (no_bias or hasattr(bias, "main_grad")):
            if no_bias:
                weight.main_grad, _ = grad_add(x, y_grad, weight.main_grad, None, True, False)
                return x_grad, None
            weight.main_grad, bias.main_grad = grad_add(x, y_grad, weight.main_grad, bias.main_grad, True)
            return x_grad, None, None

        # Route 2: the weight already holds a ``.grad`` -> feed the existing
        # gradient(s) to the op and store the result back on ``.grad``.
        if weight.grad is not None:
            if no_bias:
                weight.grad, _ = grad_add(x, y_grad, weight.grad, None, False, False)
                return x_grad, None
            # Invariant: a bias paired with a gradient-bearing weight has a gradient too.
            assert bias.grad is not None
            weight.grad, bias.grad = grad_add(x, y_grad, weight.grad, bias.grad, False)
            return x_grad, None, None

        # Route 3: no existing gradient storage -> compute fresh gradients and return them.
        if no_bias:
            weight_grad, _ = grad_add(x, y_grad, None, None, False, False)
            return x_grad, weight_grad
        weight_grad, bias_grad = grad_add(x, y_grad, None, None, False)
        return x_grad, weight_grad, bias_grad
def mock_layers():
    """Route paddle's linear functionals through ``FusedLinearWithGradAdd.apply``.

    Replaces ``paddle.nn.functional.linear`` unconditionally. When
    ``is_fused_matmul_bias_supported()`` is true, it also replaces
    ``paddle.incubate.nn.functional.fused_linear``. The patch is made on the shared
    paddle modules; nothing in this function undoes it.
    """
    targets = [(paddle.nn.functional, "linear")]
    if is_fused_matmul_bias_supported():
        targets.append((paddle.incubate.nn.functional, "fused_linear"))
    for module, attr_name in targets:
        setattr(module, attr_name, FusedLinearWithGradAdd.apply)