forked from alibaba/MNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Pool.cpp
80 lines (71 loc) · 2.3 KB
/
Pool.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
//
// Pool.cpp
// MNNConverter
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "OpConverter.hpp"
#include "logkit.h"
/**
 * Caffe -> MNN converter for "Pooling" layers.
 *
 * opType() reports MNN::OpType_Pooling as the produced op and type() reports
 * MNN::OpParameter_Pool as the kind of parameter stored in the op's union.
 * run() builds the actual MNN::PoolT from the Caffe layer description.
 */
class Pool : public OpConverter {
public:
    // Fills dstOp from the Caffe layer; defined out of line below.
    virtual void run(MNN::OpT* dstOp, const caffe::LayerParameter& parameters, const caffe::LayerParameter& weight);

    Pool() = default;
    virtual ~Pool() = default;

    // Op type emitted for every Caffe "Pooling" layer.
    virtual MNN::OpType opType() {
        return MNN::OpType_Pooling;
    }

    // Tag of the parameter union member that run() populates.
    virtual MNN::OpParameter type() {
        return MNN::OpParameter_Pool;
    }
};
/**
 * Translate a Caffe PoolingParameter into an MNN::PoolT and attach it to dstOp.
 *
 * @param dstOp      Op being built; dstOp->main.value is set to a freshly
 *                   allocated MNN::PoolT (presumably owned by dstOp->main from
 *                   then on — NOTE(review): confirm against the OpT union API).
 * @param parameters Caffe layer; only pooling_param() and name() are read.
 * @param weight     Not used by this converter.
 *
 * Spatial fields follow Caffe's convention: a shared value (kernel_size,
 * stride, pad) applies to both axes and is overridden per axis by the *_w
 * field (-> X) and the *_h field (-> Y). When no field is present the
 * defaults are kernel 1, stride 1, pad 0.
 */
void Pool::run(MNN::OpT* dstOp, const caffe::LayerParameter& parameters, const caffe::LayerParameter& weight) {
    const caffe::PoolingParameter& poolParam = parameters.pooling_param();

    auto poolT        = new MNN::PoolT;
    dstOp->main.value = poolT;

    switch (poolParam.pool()) {
        case caffe::PoolingParameter::MAX:
            poolT->type = MNN::PoolType_MAXPOOL;
            break;
        case caffe::PoolingParameter::AVE:
            poolT->type = MNN::PoolType_AVEPOOL;
            break;
        default: {
            // NOTE(review): if DLOG is a debug-only macro, release builds continue
            // past this point with poolT->type left at its default — confirm in logkit.h.
            DLOG(FATAL) << "Pool type not support! ==> " << parameters.name();
            break;
        }
    }

    // Resolve one axis: the per-axis field wins when set, otherwise the shared
    // field, otherwise `dflt`.
    auto axisValue = [](bool hasAxis, int axisVal, bool hasShared, int sharedVal, int dflt) {
        if (hasAxis) {
            return axisVal;
        }
        return hasShared ? sharedVal : dflt;
    };

    // X pairs with the Caffe *_w fields, Y with the *_h fields.
    poolT->kernelX = axisValue(poolParam.has_kernel_w(), poolParam.kernel_w(),
                               poolParam.has_kernel_size(), poolParam.kernel_size(), 1);
    poolT->kernelY = axisValue(poolParam.has_kernel_h(), poolParam.kernel_h(),
                               poolParam.has_kernel_size(), poolParam.kernel_size(), 1);

    poolT->strideX = axisValue(poolParam.has_stride_w(), poolParam.stride_w(),
                               poolParam.has_stride(), poolParam.stride(), 1);
    poolT->strideY = axisValue(poolParam.has_stride_h(), poolParam.stride_h(),
                               poolParam.has_stride(), poolParam.stride(), 1);

    poolT->padX = axisValue(poolParam.has_pad_w(), poolParam.pad_w(),
                            poolParam.has_pad(), poolParam.pad(), 0);
    poolT->padY = axisValue(poolParam.has_pad_h(), poolParam.pad_h(),
                            poolParam.has_pad(), poolParam.pad(), 0);

    // Global pooling only when the layer explicitly sets global_pooling.
    poolT->isGlobal = poolParam.has_global_pooling() ? poolParam.global_pooling() : false;
}
// File-local registration object: constructing OpConverterRegister<Pool> with the
// key "Pooling" presumably maps Caffe's "Pooling" layer type to this converter.
static OpConverterRegister<Pool> g_poolingConverterRegister("Pooling");