diff --git a/convbench/conv_utils.py b/convbench/conv_utils.py index 3dd4a65..0cd8be4 100644 --- a/convbench/conv_utils.py +++ b/convbench/conv_utils.py @@ -59,14 +59,16 @@ def get_kernel_shape(self) -> str: def get_out_shape(self) -> str: padding = 0 - h_out = (self.H + 2 * padding - self.P) // self.S + 1 - w_out = (self.W + 2 * padding - self.Q) // self.S + 1 + in_h = self.H * self.S + self.P - 1 + in_w = self.W * self.S + self.Q - 1 + h_out = (in_h + 2 * padding - self.P) // self.S + 1 + w_out = (in_w + 2 * padding - self.Q) // self.S + 1 n = self.N nf = self.F if "nhwc" in self.OP: - return str(n) + "x" + str(h_out) + "x" + str(w_out) + "x" + str(nf) + "x" + self.input_dtype + return str(n) + "x" + str(h_out) + "x" + str(w_out) + "x" + str(nf) + "x" + self.output_dtype if "nchw" in self.OP: - return str(n) + "x" + str(nf) + "x" + str(h_out) + "x" + str(w_out) + "x" + self.input_dtype + return str(n) + "x" + str(nf) + "x" + str(h_out) + "x" + str(w_out) + "x" + self.output_dtype def get_byte_count(self) -> int: dtype_bits_map = {