From 6b2495cc243f2d8e829523b700f32db1f5d50f78 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 15 Feb 2022 21:40:35 +0800 Subject: [PATCH] add reshape before and after pooling 123d with no batch dimension (#3566) --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/ir.cpp | 9 ++ tools/pnnx/src/ir.h | 2 + tools/pnnx/src/pass_ncnn.cpp | 3 + .../src/pass_ncnn/insert_reshape_pooling.cpp | 106 ++++++++++++++++++ .../src/pass_ncnn/insert_reshape_pooling.h | 25 +++++ tools/pnnx/tests/ncnn/test_F_avg_pool2d.py | 17 ++- tools/pnnx/tests/ncnn/test_F_avg_pool3d.py | 17 ++- tools/pnnx/tests/ncnn/test_F_max_pool1d.py | 16 ++- tools/pnnx/tests/ncnn/test_F_max_pool2d.py | 16 ++- tools/pnnx/tests/ncnn/test_F_max_pool3d.py | 16 ++- tools/pnnx/tests/ncnn/test_nn_AvgPool2d.py | 16 ++- tools/pnnx/tests/ncnn/test_nn_AvgPool3d.py | 16 ++- tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py | 17 ++- tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py | 17 ++- tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py | 17 ++- tools/pnnx/tests/test_F_avg_pool2d.py | 24 +++- tools/pnnx/tests/test_F_avg_pool3d.py | 24 +++- tools/pnnx/tests/test_F_max_pool1d.py | 26 +++-- tools/pnnx/tests/test_F_max_pool2d.py | 25 +++-- tools/pnnx/tests/test_F_max_pool3d.py | 25 +++-- tools/pnnx/tests/test_nn_AvgPool2d.py | 24 +++- tools/pnnx/tests/test_nn_AvgPool3d.py | 24 +++- tools/pnnx/tests/test_nn_MaxPool1d.py | 26 +++-- tools/pnnx/tests/test_nn_MaxPool2d.py | 25 +++-- tools/pnnx/tests/test_nn_MaxPool3d.py | 25 +++-- 26 files changed, 473 insertions(+), 86 deletions(-) create mode 100644 tools/pnnx/src/pass_ncnn/insert_reshape_pooling.cpp create mode 100644 tools/pnnx/src/pass_ncnn/insert_reshape_pooling.h diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index f1acc01d845f..a85057a8f690 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -279,6 +279,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp pass_ncnn/fuse_innerproduct_activation.cpp pass_ncnn/fuse_transpose_matmul.cpp + pass_ncnn/insert_reshape_pooling.cpp pass_ncnn/F_adaptive_avg_pool1d.cpp pass_ncnn/F_adaptive_avg_pool2d.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 880f59032f24..8b47b6932f7b 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2419,6 +2419,15 @@ Operator* Graph::new_operator_before(const std::string& type, const std::string& return op; } +Operator* Graph::new_operator_after(const std::string& type, const std::string& name, const Operator* cur) +{ + Operator* op = new Operator; + op->type = type; + op->name = name; + ops.insert(std::find(ops.begin(), ops.end(), cur) + 1, op); + return op; +} + Operand* Graph::new_operand(const torch::jit::Value* v) { Operand* r = new Operand; diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 07e5259dc4c3..9a86176c05cf 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -218,6 +218,8 @@ class Graph Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur); + Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur); + Operand* new_operand(const torch::jit::Value* v); Operand* new_operand(const std::string& name); diff --git a/tools/pnnx/src/pass_ncnn.cpp b/tools/pnnx/src/pass_ncnn.cpp index 49aabc536524..924d9b0cb992 100644 --- a/tools/pnnx/src/pass_ncnn.cpp +++ b/tools/pnnx/src/pass_ncnn.cpp @@ -37,6 +37,7 @@ #include "pass_ncnn/fuse_deconvolutiondepthwise_activation.h" #include "pass_ncnn/fuse_innerproduct_activation.h" #include "pass_ncnn/fuse_transpose_matmul.h" +#include "pass_ncnn/insert_reshape_pooling.h" #include "pass_level4/dead_code_elimination.h" #include "pass_level4/canonicalize.h" @@ -73,6 +74,8 @@ void pass_ncnn(Graph& g) ncnn::chain_multi_output(g); + ncnn::insert_reshape_pooling(g); + ncnn::solve_batch_index(g); ncnn::convert_half_to_float(g); diff --git a/tools/pnnx/src/pass_ncnn/insert_reshape_pooling.cpp b/tools/pnnx/src/pass_ncnn/insert_reshape_pooling.cpp new file mode 100644 index 000000000000..3fb25a67cc48 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/insert_reshape_pooling.cpp @@ -0,0 +1,106 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "insert_reshape_pooling.h" +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void insert_reshape_pooling(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "nn.MaxPool1d" && op->type != "nn.MaxPool2d" && op->type != "nn.MaxPool3d") + continue; + + int input_rank = op->inputs[0]->shape.size(); + if (input_rank == 0) + continue; + + fprintf(stderr, "insert_reshape_pooling %d\n", input_rank); + + // nn.MaxPool1d 2d-3d-2d + // nn.MaxPool2d 3d-4d-3d + // nn.MaxPool3d 4d-5d-4d + bool insert_reshape = false; + if ((op->type == "nn.MaxPool1d" && input_rank == 2) + || (op->type == "nn.MaxPool2d" && input_rank == 3) + || (op->type == "nn.MaxPool3d" && input_rank == 4)) + { + insert_reshape = true; + } + + if (!insert_reshape) + continue; + + matched = true; + + Operand* pooling_in = op->inputs[0]; + Operand* pooling_out = op->outputs[0]; + + Operator* reshape0 = graph.new_operator_before("Tensor.reshape", op->name + "_ncnnreshape0", op); + Operator* reshape1 = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape1", op); + + Operand* reshape0_out = graph.new_operand(op->name + "_ncnnreshape0_out"); + Operand* reshape1_in = graph.new_operand(op->name + "_ncnnreshape1_in"); + + reshape0->inputs.push_back(pooling_in); + reshape0->outputs.push_back(reshape0_out); + reshape1->inputs.push_back(reshape1_in); + reshape1->outputs.push_back(pooling_out); + + for (size_t j = 0; j < pooling_in->consumers.size(); j++) + { + if (pooling_in->consumers[j] == op) + { + pooling_in->consumers[j] = reshape0; + break; + } + } + pooling_out->producer = reshape1; + + op->inputs[0] = reshape0_out; + op->outputs[0] = reshape1_in; + + reshape0_out->producer = reshape0; + reshape0_out->consumers.push_back(op); + reshape1_in->producer = op; + reshape1_in->consumers.push_back(reshape1); + + std::vector reshape0_shape = pooling_in->shape; + reshape0_shape.insert(reshape0_shape.begin(), 1); + std::vector reshape1_shape = pooling_out->shape; + + reshape0->params["shape"] = reshape0_shape; + reshape1->params["shape"] = reshape1_shape; + + break; + } + + if (!matched) + break; + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/insert_reshape_pooling.h b/tools/pnnx/src/pass_ncnn/insert_reshape_pooling.h new file mode 100644 index 000000000000..45f42da7bf28 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/insert_reshape_pooling.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void insert_reshape_pooling(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/test_F_avg_pool2d.py b/tools/pnnx/tests/ncnn/test_F_avg_pool2d.py index 83777732362e..22085f146114 100644 --- a/tools/pnnx/tests/ncnn/test_F_avg_pool2d.py +++ b/tools/pnnx/tests/ncnn/test_F_avg_pool2d.py @@ -21,6 +21,8 @@ def __init__(self): super(Model, self).__init__() def forward(self, x): + y = x.reshape(12, 128, 127) + x = F.avg_pool2d(x, kernel_size=3) x = F.avg_pool2d(x, kernel_size=4, stride=2, padding=2) x = F.avg_pool2d(x, kernel_size=(1,3), stride=1, padding=(0,1), ceil_mode=False, count_include_pad=True) @@ -28,7 +30,15 @@ def forward(self, x): x = F.avg_pool2d(x, kernel_size=(5,3), stride=(2,1), padding=1, ceil_mode=False, count_include_pad=True) x = F.avg_pool2d(x, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) #x = F.avg_pool2d(x, kernel_size=(5,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=18) - return x + + y = F.avg_pool2d(y, kernel_size=3) + y = F.avg_pool2d(y, kernel_size=4, stride=2, padding=2) + y = F.avg_pool2d(y, kernel_size=(1,3), stride=1, padding=(0,1), ceil_mode=False, count_include_pad=True) + y = F.avg_pool2d(y, kernel_size=(4,5), stride=(1,2), padding=(1,2), ceil_mode=True, count_include_pad=False) + y = F.avg_pool2d(y, kernel_size=(5,3), stride=(2,1), padding=1, ceil_mode=False, count_include_pad=True) + y = F.avg_pool2d(y, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + #y = F.avg_pool2d(y, kernel_size=(5,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=18) + return x, y def test(): net = Model() @@ -51,7 +61,10 @@ def test(): import test_F_avg_pool2d_ncnn b = test_F_avg_pool2d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_F_avg_pool3d.py b/tools/pnnx/tests/ncnn/test_F_avg_pool3d.py index 50e775d87d5d..31f370882a8e 100644 --- a/tools/pnnx/tests/ncnn/test_F_avg_pool3d.py +++ b/tools/pnnx/tests/ncnn/test_F_avg_pool3d.py @@ -21,6 +21,8 @@ def __init__(self): super(Model, self).__init__() def forward(self, x): + y = x.reshape(12, 96, 128, 128) + x = F.avg_pool3d(x, kernel_size=3) x = F.avg_pool3d(x, kernel_size=4, stride=2, padding=2) x = F.avg_pool3d(x, kernel_size=(1,2,3), stride=1, padding=(0,1,1), ceil_mode=False, count_include_pad=True) @@ -28,7 +30,15 @@ def forward(self, x): x = F.avg_pool3d(x, kernel_size=(5,4,3), stride=(2,1,1), padding=1, ceil_mode=False, count_include_pad=True) x = F.avg_pool3d(x, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) #x = F.avg_pool3d(x, kernel_size=(5,4,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=77) - return x + + y = F.avg_pool3d(y, kernel_size=3) + y = F.avg_pool3d(y, kernel_size=4, stride=2, padding=2) + y = F.avg_pool3d(y, kernel_size=(1,2,3), stride=1, padding=(0,1,1), ceil_mode=False, count_include_pad=True) + y = F.avg_pool3d(y, kernel_size=(3,4,5), stride=(1,2,2), padding=(1,1,2), ceil_mode=True, count_include_pad=False) + y = F.avg_pool3d(y, kernel_size=(5,4,3), stride=(2,1,1), padding=1, ceil_mode=False, count_include_pad=True) + y = F.avg_pool3d(y, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + #y = F.avg_pool3d(y, kernel_size=(5,4,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=77) + return x, y def test(): net = Model() @@ -51,7 +61,10 @@ def test(): import test_F_avg_pool3d_ncnn b = test_F_avg_pool3d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_F_max_pool1d.py b/tools/pnnx/tests/ncnn/test_F_max_pool1d.py index ee0948d487c9..50eca9c1ce49 100644 --- a/tools/pnnx/tests/ncnn/test_F_max_pool1d.py +++ b/tools/pnnx/tests/ncnn/test_F_max_pool1d.py @@ -21,13 +21,22 @@ def __init__(self): super(Model, self).__init__() def forward(self, x): + y = x.reshape(12, 128) + x = F.max_pool1d(x, kernel_size=3) x = F.max_pool1d(x, kernel_size=4, stride=2, padding=2, dilation=1) x = F.max_pool1d(x, kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) x = F.max_pool1d(x, kernel_size=5, stride=2, padding=2, dilation=1, return_indices=False, ceil_mode=True) x = F.max_pool1d(x, kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) x = F.max_pool1d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) - return x + + y = F.max_pool1d(y, kernel_size=3) + y = F.max_pool1d(y, kernel_size=4, stride=2, padding=2, dilation=1) + y = F.max_pool1d(y, kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) + y = F.max_pool1d(y, kernel_size=5, stride=2, padding=2, dilation=1, return_indices=False, ceil_mode=True) + y = F.max_pool1d(y, kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) + y = F.max_pool1d(y, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + return x, y #x, indices1 = F.max_pool1d(x, kernel_size=2, padding=1, dilation=1, return_indices=True, ceil_mode=False) #x, indices2 = F.max_pool1d(x, kernel_size=5, stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=True) #return x, indices1, indices2 @@ -53,7 +62,10 @@ def test(): import test_F_max_pool1d_ncnn b = test_F_max_pool1d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_F_max_pool2d.py b/tools/pnnx/tests/ncnn/test_F_max_pool2d.py index ff4413ef24c4..b3ad4c8d17aa 100644 --- a/tools/pnnx/tests/ncnn/test_F_max_pool2d.py +++ b/tools/pnnx/tests/ncnn/test_F_max_pool2d.py @@ -21,13 +21,22 @@ def __init__(self): super(Model, self).__init__() def forward(self, x): + y = x.reshape(12, 128, 127) + x = F.max_pool2d(x, kernel_size=3) x = F.max_pool2d(x, kernel_size=4, stride=2, padding=2, dilation=1) x = F.max_pool2d(x, kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, return_indices=False, ceil_mode=False) x = F.max_pool2d(x, kernel_size=(4,5), stride=(1,2), padding=(1,2), dilation=1, return_indices=False, ceil_mode=True) x = F.max_pool2d(x, kernel_size=(2,3), stride=1, padding=1, dilation=(1,1), return_indices=False, ceil_mode=False) x = F.max_pool2d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) - return x + + y = F.max_pool2d(y, kernel_size=3) + y = F.max_pool2d(y, kernel_size=4, stride=2, padding=2, dilation=1) + y = F.max_pool2d(y, kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, return_indices=False, ceil_mode=False) + y = F.max_pool2d(y, kernel_size=(4,5), stride=(1,2), padding=(1,2), dilation=1, return_indices=False, ceil_mode=True) + y = F.max_pool2d(y, kernel_size=(2,3), stride=1, padding=1, dilation=(1,1), return_indices=False, ceil_mode=False) + y = F.max_pool2d(y, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + return x, y #x, indices1 = F.max_pool2d(x, kernel_size=2, padding=1, dilation=1, return_indices=True, ceil_mode=False) #x, indices2 = F.max_pool2d(x, kernel_size=(5,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) #return x, indices1, indices2 @@ -53,7 +62,10 @@ def test(): import test_F_max_pool2d_ncnn b = test_F_max_pool2d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_F_max_pool3d.py b/tools/pnnx/tests/ncnn/test_F_max_pool3d.py index ab0c35eb22a0..7708ffe9a865 100644 --- a/tools/pnnx/tests/ncnn/test_F_max_pool3d.py +++ b/tools/pnnx/tests/ncnn/test_F_max_pool3d.py @@ -21,13 +21,22 @@ def __init__(self): super(Model, self).__init__() def forward(self, x): + y = x.reshape(12, 96, 128, 128) + x = F.max_pool3d(x, kernel_size=3) x = F.max_pool3d(x, kernel_size=4, stride=2, padding=2, dilation=1) x = F.max_pool3d(x, kernel_size=(1,2,3), stride=1, padding=(0,0,1), dilation=1, return_indices=False, ceil_mode=False) x = F.max_pool3d(x, kernel_size=(3,4,5), stride=(1,2,2), padding=(1,2,2), dilation=1, return_indices=False, ceil_mode=True) x = F.max_pool3d(x, kernel_size=(2,3,3), stride=1, padding=1, dilation=(1,1,1), return_indices=False, ceil_mode=False) x = F.max_pool3d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) - return x + + y = F.max_pool3d(y, kernel_size=3) + y = F.max_pool3d(y, kernel_size=4, stride=2, padding=2, dilation=1) + y = F.max_pool3d(y, kernel_size=(1,2,3), stride=1, padding=(0,0,1), dilation=1, return_indices=False, ceil_mode=False) + y = F.max_pool3d(y, kernel_size=(3,4,5), stride=(1,2,2), padding=(1,2,2), dilation=1, return_indices=False, ceil_mode=True) + y = F.max_pool3d(y, kernel_size=(2,3,3), stride=1, padding=1, dilation=(1,1,1), return_indices=False, ceil_mode=False) + y = F.max_pool3d(y, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + return x, y #x, indices = F.max_pool3d(x, kernel_size=(5,4,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) #return x, indices @@ -52,7 +61,10 @@ def test(): import test_F_max_pool3d_ncnn b = test_F_max_pool3d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_nn_AvgPool2d.py b/tools/pnnx/tests/ncnn/test_nn_AvgPool2d.py index 609afea93629..0040d3e76075 100644 --- a/tools/pnnx/tests/ncnn/test_nn_AvgPool2d.py +++ b/tools/pnnx/tests/ncnn/test_nn_AvgPool2d.py @@ -28,13 +28,22 @@ def __init__(self): self.pool_5 = nn.AvgPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) def forward(self, x): + y = x.reshape(12, 128, 128) + x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) x = self.pool_3(x) x = self.pool_4(x) x = self.pool_5(x) - return x + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + return x, y def test(): net = Model() @@ -57,7 +66,10 @@ def test(): import test_nn_AvgPool2d_ncnn b = test_nn_AvgPool2d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_nn_AvgPool3d.py b/tools/pnnx/tests/ncnn/test_nn_AvgPool3d.py index 207101215875..3cf42ba6ac36 100644 --- a/tools/pnnx/tests/ncnn/test_nn_AvgPool3d.py +++ b/tools/pnnx/tests/ncnn/test_nn_AvgPool3d.py @@ -28,13 +28,22 @@ def __init__(self): self.pool_5 = nn.AvgPool3d(kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) def forward(self, x): + y = x.reshape(12, 96, 128, 128) + x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) x = self.pool_3(x) x = self.pool_4(x) x = self.pool_5(x) - return x + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + return x, y def test(): net = Model() @@ -57,7 +66,10 @@ def test(): import test_nn_AvgPool3d_ncnn b = test_nn_AvgPool3d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py b/tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py index 36a141fb5865..f41552d97bba 100644 --- a/tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py +++ b/tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py @@ -29,6 +29,8 @@ def __init__(self): self.pool_6 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2, dilation=1, return_indices=False, ceil_mode=False) def forward(self, x): + y = x.reshape(12, 64) + x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) @@ -36,7 +38,15 @@ def forward(self, x): x = self.pool_4(x) x = self.pool_5(x) x = self.pool_6(x) - return x + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + y = self.pool_6(y) + return x, y def test(): net = Model() @@ -59,7 +69,10 @@ def test(): import test_nn_MaxPool1d_ncnn b = test_nn_MaxPool1d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py b/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py index d424a939ee7a..b1c5df0595ae 100644 --- a/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py +++ b/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py @@ -29,6 +29,8 @@ def __init__(self): self.pool_6 = nn.MaxPool2d(kernel_size=(5,4), stride=1, padding=2, dilation=1, ceil_mode=False) def forward(self, x): + y = x.reshape(12, 64, 64) + x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) @@ -36,7 +38,15 @@ def forward(self, x): x = self.pool_4(x) x = self.pool_5(x) x = self.pool_6(x) - return x + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + y = self.pool_6(y) + return x, y def test(): net = Model() @@ -59,7 +69,10 @@ def test(): import test_nn_MaxPool2d_ncnn b = test_nn_MaxPool2d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py b/tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py index 271604eb607f..deb9cdb64ec6 100644 --- a/tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py +++ b/tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py @@ -29,6 +29,8 @@ def __init__(self): self.pool_6 = nn.MaxPool3d(kernel_size=(5,4,4), stride=1, padding=2, dilation=1, return_indices=False, ceil_mode=False) def forward(self, x): + y = x.reshape(12, 64, 64, 64) + x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) @@ -36,7 +38,15 @@ def forward(self, x): x = self.pool_4(x) x = self.pool_5(x) x = self.pool_6(x) - return x + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + y = self.pool_6(y) + return x, y def test(): net = Model() @@ -59,7 +69,10 @@ def test(): import test_nn_MaxPool3d_ncnn b = test_nn_MaxPool3d_ncnn.test_inference() - return torch.allclose(a, b, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_F_avg_pool2d.py b/tools/pnnx/tests/test_F_avg_pool2d.py index d00b2ecad79e..95c6ed24baa5 100644 --- a/tools/pnnx/tests/test_F_avg_pool2d.py +++ b/tools/pnnx/tests/test_F_avg_pool2d.py @@ -20,7 +20,7 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x): + def forward(self, x, y): x = F.avg_pool2d(x, kernel_size=3) x = F.avg_pool2d(x, kernel_size=4, stride=2, padding=2) x = F.avg_pool2d(x, kernel_size=(1,3), stride=1, padding=(0,1), ceil_mode=False, count_include_pad=True) @@ -28,7 +28,15 @@ def forward(self, x): x = F.avg_pool2d(x, kernel_size=(5,3), stride=(2,1), padding=1, ceil_mode=False, count_include_pad=True) x = F.avg_pool2d(x, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) x = F.avg_pool2d(x, kernel_size=(5,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=18) - return x + + y = F.avg_pool2d(y, kernel_size=3) + y = F.avg_pool2d(y, kernel_size=4, stride=2, padding=2) + y = F.avg_pool2d(y, kernel_size=(1,3), stride=1, padding=(0,1), ceil_mode=False, count_include_pad=True) + y = F.avg_pool2d(y, kernel_size=(4,5), stride=(1,2), padding=(1,2), ceil_mode=True, count_include_pad=False) + y = F.avg_pool2d(y, kernel_size=(5,3), stride=(2,1), padding=1, ceil_mode=False, count_include_pad=True) + y = F.avg_pool2d(y, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + y = F.avg_pool2d(y, kernel_size=(5,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=18) + return x, y def test(): net = Model() @@ -36,22 +44,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 128, 127) + y = torch.rand(12, 128, 127) - a = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_F_avg_pool2d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_F_avg_pool2d.pt inputshape=[1,12,128,127]") + os.system("../src/pnnx test_F_avg_pool2d.pt inputshape=[1,12,128,127],[12,128,127]") # pnnx inference import test_F_avg_pool2d_pnnx b = test_F_avg_pool2d_pnnx.test_inference() - return torch.equal(a, b) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_F_avg_pool3d.py b/tools/pnnx/tests/test_F_avg_pool3d.py index 5867a502dfa0..7d29cc4c1a9b 100644 --- a/tools/pnnx/tests/test_F_avg_pool3d.py +++ b/tools/pnnx/tests/test_F_avg_pool3d.py @@ -20,7 +20,7 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x): + def forward(self, x, y): x = F.avg_pool3d(x, kernel_size=3) x = F.avg_pool3d(x, kernel_size=4, stride=2, padding=2) x = F.avg_pool3d(x, kernel_size=(1,2,3), stride=1, padding=(0,1,1), ceil_mode=False, count_include_pad=True) @@ -28,7 +28,15 @@ def forward(self, x): x = F.avg_pool3d(x, kernel_size=(5,4,3), stride=(2,1,1), padding=1, ceil_mode=False, count_include_pad=True) x = F.avg_pool3d(x, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) x = F.avg_pool3d(x, kernel_size=(5,4,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=77) - return x + + y = F.avg_pool3d(y, kernel_size=3) + y = F.avg_pool3d(y, kernel_size=4, stride=2, padding=2) + y = F.avg_pool3d(y, kernel_size=(1,2,3), stride=1, padding=(0,1,1), ceil_mode=False, count_include_pad=True) + y = F.avg_pool3d(y, kernel_size=(3,4,5), stride=(1,2,2), padding=(1,1,2), ceil_mode=True, count_include_pad=False) + y = F.avg_pool3d(y, kernel_size=(5,4,3), stride=(2,1,1), padding=1, ceil_mode=False, count_include_pad=True) + y = F.avg_pool3d(y, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + y = F.avg_pool3d(y, kernel_size=(5,4,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=77) + return x, y def test(): net = Model() @@ -36,22 +44,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 96, 128, 128) + y = torch.rand(12, 96, 128, 128) - a = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_F_avg_pool3d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_F_avg_pool3d.pt inputshape=[1,12,96,128,128]") + os.system("../src/pnnx test_F_avg_pool3d.pt inputshape=[1,12,96,128,128],[12,96,128,128]") # pnnx inference import test_F_avg_pool3d_pnnx b = test_F_avg_pool3d_pnnx.test_inference() - return torch.equal(a, b) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_F_max_pool1d.py b/tools/pnnx/tests/test_F_max_pool1d.py index c3e3b02527ea..92e604f39714 100644 --- a/tools/pnnx/tests/test_F_max_pool1d.py +++ b/tools/pnnx/tests/test_F_max_pool1d.py @@ -20,7 +20,7 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x): + def forward(self, x, y): x = F.max_pool1d(x, kernel_size=3) x = F.max_pool1d(x, kernel_size=4, stride=2, padding=2, dilation=1) x = F.max_pool1d(x, kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) @@ -29,7 +29,15 @@ def forward(self, x): x = F.max_pool1d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) x, indices1 = F.max_pool1d(x, kernel_size=2, padding=1, dilation=1, return_indices=True, ceil_mode=False) x, indices2 = F.max_pool1d(x, kernel_size=5, stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=True) - return x, indices1, indices2 + + y = F.max_pool1d(y, kernel_size=3) + y = F.max_pool1d(y, kernel_size=4, stride=2, padding=2, dilation=1) + y = F.max_pool1d(y, kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) + y = F.max_pool1d(y, kernel_size=5, stride=2, padding=2, dilation=1, return_indices=False, ceil_mode=True) + y = F.max_pool1d(y, kernel_size=3, stride=1, padding=1, dilation=2, return_indices=False, ceil_mode=False) + y = F.max_pool1d(y, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + + return x, indices1, indices2, y def test(): net = Model() @@ -37,22 +45,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 128) + y = torch.rand(12, 128) - a0, a1, a2 = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_F_max_pool1d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_F_max_pool1d.pt inputshape=[1,12,128]") + os.system("../src/pnnx test_F_max_pool1d.pt inputshape=[1,12,128],[12,128]") # pnnx inference import test_F_max_pool1d_pnnx - b0, b1, b2 = test_F_max_pool1d_pnnx.test_inference() + b = test_F_max_pool1d_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_F_max_pool2d.py b/tools/pnnx/tests/test_F_max_pool2d.py index 5b7f9722d88b..a6a6e162a8a6 100644 --- a/tools/pnnx/tests/test_F_max_pool2d.py +++ b/tools/pnnx/tests/test_F_max_pool2d.py @@ -20,7 +20,7 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x): + def forward(self, x, y): x = F.max_pool2d(x, kernel_size=3) x = F.max_pool2d(x, kernel_size=4, stride=2, padding=2, dilation=1) x = F.max_pool2d(x, kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, return_indices=False, ceil_mode=False) @@ -29,7 +29,14 @@ def forward(self, x): x = F.max_pool2d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) x, indices1 = F.max_pool2d(x, kernel_size=2, padding=1, dilation=1, return_indices=True, ceil_mode=False) x, indices2 = F.max_pool2d(x, kernel_size=(5,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) - return x, indices1, indices2 + + y = F.max_pool2d(y, kernel_size=3) + y = F.max_pool2d(y, kernel_size=4, stride=2, padding=2, dilation=1) + y = F.max_pool2d(y, kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, return_indices=False, ceil_mode=False) + y = F.max_pool2d(y, kernel_size=(4,5), stride=(1,2), padding=(1,2), dilation=1, return_indices=False, ceil_mode=True) + y = F.max_pool2d(y, kernel_size=(2,3), stride=1, padding=1, dilation=(1,2), return_indices=False, ceil_mode=False) + y = F.max_pool2d(y, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + return x, indices1, indices2, y def test(): net = Model() @@ -37,22 +44,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 128, 127) + y = torch.rand(12, 128, 127) - a0, a1, a2 = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_F_max_pool2d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_F_max_pool2d.pt inputshape=[1,12,128,127]") + os.system("../src/pnnx test_F_max_pool2d.pt inputshape=[1,12,128,127],[12,128,127]") # pnnx inference import test_F_max_pool2d_pnnx - b0, b1, b2 = test_F_max_pool2d_pnnx.test_inference() + b = test_F_max_pool2d_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_F_max_pool3d.py b/tools/pnnx/tests/test_F_max_pool3d.py index d82087f00c64..42e52b014a25 100644 --- a/tools/pnnx/tests/test_F_max_pool3d.py +++ b/tools/pnnx/tests/test_F_max_pool3d.py @@ -20,7 +20,7 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x): + def forward(self, x, y): x = F.max_pool3d(x, kernel_size=3) x = F.max_pool3d(x, kernel_size=4, stride=2, padding=2, dilation=1) x = F.max_pool3d(x, kernel_size=(1,2,3), stride=1, padding=(0,0,1), dilation=1, return_indices=False, ceil_mode=False) @@ -28,7 +28,14 @@ def forward(self, x): x = F.max_pool3d(x, kernel_size=(2,3,3), stride=1, padding=1, dilation=(1,2,2), return_indices=False, ceil_mode=False) x = F.max_pool3d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) x, indices = F.max_pool3d(x, kernel_size=(5,4,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) - return x, indices + + y = F.max_pool3d(y, kernel_size=3) + y = F.max_pool3d(y, kernel_size=4, stride=2, padding=2, dilation=1) + y = F.max_pool3d(y, kernel_size=(1,2,3), stride=1, padding=(0,0,1), dilation=1, return_indices=False, ceil_mode=False) + y = F.max_pool3d(y, kernel_size=(3,4,5), stride=(1,2,2), padding=(1,2,2), dilation=1, return_indices=False, ceil_mode=True) + y = F.max_pool3d(y, kernel_size=(2,3,3), stride=1, padding=1, dilation=(1,2,2), return_indices=False, ceil_mode=False) + y = F.max_pool3d(y, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + return x, indices, y def test(): net = Model() @@ -36,22 +43,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 96, 128, 128) + y = torch.rand(12, 96, 128, 128) - a0, a1 = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_F_max_pool3d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_F_max_pool3d.pt inputshape=[1,12,96,128,128]") + os.system("../src/pnnx test_F_max_pool3d.pt inputshape=[1,12,96,128,128],[12,96,128,128]") # pnnx inference import test_F_max_pool3d_pnnx - b0, b1 = test_F_max_pool3d_pnnx.test_inference() + b = test_F_max_pool3d_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_nn_AvgPool2d.py b/tools/pnnx/tests/test_nn_AvgPool2d.py index 458463a3d438..4f642560bec3 100644 --- a/tools/pnnx/tests/test_nn_AvgPool2d.py +++ b/tools/pnnx/tests/test_nn_AvgPool2d.py @@ -28,7 +28,7 @@ def __init__(self): self.pool_5 = nn.AvgPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) self.pool_6 = nn.AvgPool2d(kernel_size=(5,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=18) - def forward(self, x): + def forward(self, x, y): x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) @@ -36,7 +36,15 @@ def forward(self, x): x = self.pool_4(x) x = self.pool_5(x) x = self.pool_6(x) - return x + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + y = self.pool_6(y) + return x, y def test(): net = Model() @@ -44,22 +52,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 128, 128) + y = torch.rand(12, 128, 128) - a = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_nn_AvgPool2d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_nn_AvgPool2d.pt inputshape=[1,12,128,128]") + os.system("../src/pnnx test_nn_AvgPool2d.pt inputshape=[1,12,128,128],[12,128,128]") # pnnx inference import test_nn_AvgPool2d_pnnx b = test_nn_AvgPool2d_pnnx.test_inference() - return torch.equal(a, b) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_nn_AvgPool3d.py b/tools/pnnx/tests/test_nn_AvgPool3d.py index 486f0e603742..639f222561e1 100644 --- a/tools/pnnx/tests/test_nn_AvgPool3d.py +++ b/tools/pnnx/tests/test_nn_AvgPool3d.py @@ -28,7 +28,7 @@ def __init__(self): self.pool_5 = nn.AvgPool3d(kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) self.pool_6 = nn.AvgPool3d(kernel_size=(5,4,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=77) - def forward(self, x): + def forward(self, x, y): x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) @@ -36,7 +36,15 @@ def forward(self, x): x = self.pool_4(x) x = self.pool_5(x) x = self.pool_6(x) - return x + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + y = self.pool_6(y) + return x, y def test(): net = Model() @@ -44,22 +52,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 96, 128, 128) + y = torch.rand(12, 96, 128, 128) - a = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_nn_AvgPool3d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_nn_AvgPool3d.pt inputshape=[1,12,96,128,128]") + os.system("../src/pnnx test_nn_AvgPool3d.pt inputshape=[1,12,96,128,128],[12,96,128,128]") # pnnx inference import test_nn_AvgPool3d_pnnx b = test_nn_AvgPool3d_pnnx.test_inference() - return torch.equal(a, b) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_nn_MaxPool1d.py b/tools/pnnx/tests/test_nn_MaxPool1d.py index 8dbcd3d858c6..6e2c05974fca 100644 --- a/tools/pnnx/tests/test_nn_MaxPool1d.py +++ b/tools/pnnx/tests/test_nn_MaxPool1d.py @@ -28,7 +28,7 @@ def __init__(self): self.pool_5 = nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) self.pool_6 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) - def forward(self, x): + def forward(self, x, y): x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) @@ -36,7 +36,15 @@ def forward(self, x): x = self.pool_4(x) x = self.pool_5(x) x, indices = self.pool_6(x) - return x, indices + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + + return x, indices, y def test(): net = Model() @@ -44,22 +52,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 64) + y = torch.rand(12, 64) - a0, a1 = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_nn_MaxPool1d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_nn_MaxPool1d.pt inputshape=[1,12,64]") + os.system("../src/pnnx test_nn_MaxPool1d.pt inputshape=[1,12,64],[12,64]") # pnnx inference import test_nn_MaxPool1d_pnnx - b0, b1 = test_nn_MaxPool1d_pnnx.test_inference() + b = test_nn_MaxPool1d_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_nn_MaxPool2d.py b/tools/pnnx/tests/test_nn_MaxPool2d.py index 497f05a6aeb4..f171d4967e56 100644 --- a/tools/pnnx/tests/test_nn_MaxPool2d.py +++ b/tools/pnnx/tests/test_nn_MaxPool2d.py @@ -28,7 +28,7 @@ def __init__(self): self.pool_5 = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) self.pool_6 = nn.MaxPool2d(kernel_size=(5,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) - def forward(self, x): + def forward(self, x, y): x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) @@ -36,7 +36,14 @@ def forward(self, x): x = self.pool_4(x) x = self.pool_5(x) x, indices = self.pool_6(x) - return x, indices + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + return x, indices, y def test(): net = Model() @@ -44,22 +51,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 64, 64) + y = torch.rand(12, 64, 64) - a0, a1 = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_nn_MaxPool2d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_nn_MaxPool2d.pt inputshape=[1,12,64,64]") + os.system("../src/pnnx test_nn_MaxPool2d.pt inputshape=[1,12,64,64],[12,64,64]") # pnnx inference import test_nn_MaxPool2d_pnnx - b0, b1 = test_nn_MaxPool2d_pnnx.test_inference() + b = test_nn_MaxPool2d_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_nn_MaxPool3d.py b/tools/pnnx/tests/test_nn_MaxPool3d.py index 22918594e47a..a0b7063a8ddf 100644 --- a/tools/pnnx/tests/test_nn_MaxPool3d.py +++ b/tools/pnnx/tests/test_nn_MaxPool3d.py @@ -28,7 +28,7 @@ def __init__(self): self.pool_5 = nn.MaxPool3d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) self.pool_6 = nn.MaxPool3d(kernel_size=(5,4,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) - def forward(self, x): + def forward(self, x, y): x = self.pool_0(x) x = self.pool_1(x) x = self.pool_2(x) @@ -36,7 +36,14 @@ def forward(self, x): x = self.pool_4(x) x = self.pool_5(x) x, indices = self.pool_6(x) - return x, indices + + y = self.pool_0(y) + y = self.pool_1(y) + y = self.pool_2(y) + y = self.pool_3(y) + y = self.pool_4(y) + y = self.pool_5(y) + return x, indices, y def test(): net = Model() @@ -44,22 +51,26 @@ def test(): torch.manual_seed(0) x = torch.rand(1, 12, 64, 64, 64) + y = torch.rand(12, 64, 64, 64) - a0, a1 = net(x) + a = net(x, y) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y)) mod.save("test_nn_MaxPool3d.pt") # torchscript to pnnx import os - os.system("../src/pnnx test_nn_MaxPool3d.pt inputshape=[1,12,64,64,64]") + os.system("../src/pnnx test_nn_MaxPool3d.pt inputshape=[1,12,64,64,64],[12,64,64,64]") # pnnx inference import test_nn_MaxPool3d_pnnx - b0, b1 = test_nn_MaxPool3d_pnnx.test_inference() + b = test_nn_MaxPool3d_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test():