Skip to content

Commit

Permalink
Add support for 3D and 2D grouped convolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah committed Nov 18, 2024
1 parent 32fb07d commit a31f546
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 38 deletions.
28 changes: 28 additions & 0 deletions convbench/conv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# MLIR text template for the `linalg.conv_2d_<CONV_TYPE>_q` op. generate_mlir
# picks this template when the operation name contains "q" and fills it in
# with str.format, supplying CONV_TYPE, STRIDE, INPUT_TYPE, FILTER_TYPE and
# OUTPUT_TYPE. The constant %c0_i32 is passed for both extra i32 operands
# (presumably the input/filter zero points -- confirm against the linalg op).
# The MLIR attribute dictionary is written with doubled braces so that
# str.format emits literal `{` / `}` instead of parsing a replacement field.
CONV_Q = r"""%c0_i32 = arith.constant 0 : i32
%11 = linalg.conv_2d_{CONV_TYPE}_q {{dilations = dense<1> : vector<2xi64>, strides = dense<{STRIDE}> : vector<2xi64>}} ins(%arg0, %arg1, %c0_i32, %c0_i32 : tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>, i32, i32) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""

# MLIR text template for the `linalg.conv_3d_<CONV_TYPE>` op, filled in with
# str.format using CONV_TYPE, STRIDE, INPUT_TYPE, FILTER_TYPE and OUTPUT_TYPE.
# The attribute dictionary braces are doubled so str.format emits them as
# literal `{` / `}`; with single braces the whole dictionary is parsed as a
# replacement field and .format() raises KeyError. The filter operand type is
# closed with exactly one `>` so the generated MLIR is well-formed.
CONV_3D = r"""%11 = linalg.conv_3d_{CONV_TYPE} {{dilations = dense<1> : tensor<3xi64>, strides = dense<{STRIDE}> : tensor<3xi64>}} ins (%arg0, %arg1: tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""

TEST = r"""util.func public @{FUNC_NAME}({FUNC_ARGS}) -> tensor<{OUT_TYPE}> {{{CONSTANT_INPUTS}
%cst = arith.constant {ZERO} : {OUT_ELEM_TYPE}
%9 = tensor.empty() : tensor<{OUT_TYPE}>
Expand All @@ -33,6 +35,12 @@ class ConvConfig:
Q: int
F: int
S: int
G: int # group count
D: int # input depth
R: int # filter depth
P_D: int # padding along depth
S_D: int # stride along depth
D_D: int # dilation along depth
OP: str
input_dtype: str
output_dtype: str
Expand Down Expand Up @@ -109,11 +117,18 @@ def generate_mlir(config: ConvConfig):
q = config.Q
f = config.F
stride = config.S
g = config.G
d = config.D
r = config.R
p_d = config.P_D
s_d = config.S_D
d_d = config.D_D
operation = config.OP
dtypes = f"{config.input_dtype}x{config.input_dtype}x{config.output_dtype}"
elem_types = dtypes.split("x")
in_h = str(int(h) * int(stride) + int(p) - 1)
in_w = str(int(w) * int(stride) + int(q) - 1)
in_d = str(int(d) * int(s_d) + int(r) - 1)
if "nhwc" in operation:
conv_type = "nhwc_hwcf"
lhs = str(n) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(c) + "x" + str(elem_types[0])
Expand All @@ -124,6 +139,17 @@ def generate_mlir(config: ConvConfig):
lhs = str(n) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
if "ngchw" in operation:
conv_type = "ngchw_fgchw"
lhs = str(n) + "x" + str(g) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(g) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(g) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
if "ncdhw" in operation:
conv_type = "ncdhw_fcdhw"
lhs = str(n) + "x" + str(c) + "x" + str(in_d) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(c) + "x" + str(r) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(f) + "x" + str(d) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])

one = "1"
zero = "0"
if (elem_types[0][0] == "f"):
Expand All @@ -132,6 +158,8 @@ def generate_mlir(config: ConvConfig):
conv_template = CONV
if "q" in operation:
conv_template = CONV_Q
if "ncdhw" in operation:
conv_template = CONV_3D
operation = conv_template.format(
INPUT_TYPE=lhs,
FILTER_TYPE=rhs,
Expand Down
76 changes: 38 additions & 38 deletions convbench/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,49 @@
def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8]:
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
return configs

def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8]:
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
return configs

def get_conv_configs() -> list[tuple[str, ConvConfig]]:
Expand Down

0 comments on commit a31f546

Please sign in to comment.