From c61de7956b2e334be6449bfdc1299b1169316e46 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 12 Dec 2024 18:05:59 +0100 Subject: [PATCH] fix Signed-off-by: Ivan Butygin --- convbench/conv_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/convbench/conv_utils.py b/convbench/conv_utils.py index 3dd4a65..0cd8be4 100644 --- a/convbench/conv_utils.py +++ b/convbench/conv_utils.py @@ -59,14 +59,16 @@ def get_kernel_shape(self) -> str: def get_out_shape(self) -> str: padding = 0 - h_out = (self.H + 2 * padding - self.P) // self.S + 1 - w_out = (self.W + 2 * padding - self.Q) // self.S + 1 + in_h = self.H * self.S + self.P - 1 + in_w = self.W * self.S + self.Q - 1 + h_out = (in_h + 2 * padding - self.P) // self.S + 1 + w_out = (in_w + 2 * padding - self.Q) // self.S + 1 n = self.N nf = self.F if "nhwc" in self.OP: - return str(n) + "x" + str(h_out) + "x" + str(w_out) + "x" + str(nf) + "x" + self.input_dtype + return str(n) + "x" + str(h_out) + "x" + str(w_out) + "x" + str(nf) + "x" + self.output_dtype if "nchw" in self.OP: - return str(n) + "x" + str(nf) + "x" + str(h_out) + "x" + str(w_out) + "x" + self.input_dtype + return str(n) + "x" + str(nf) + "x" + str(h_out) + "x" + str(w_out) + "x" + self.output_dtype def get_byte_count(self) -> int: dtype_bits_map = {