-
Notifications
You must be signed in to change notification settings - Fork 137
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CK_TILE] Image to Column kernel (#1532)
* [CK_TILE] Image to Column kernel * Fixes * Vector loads and stores * Fixes * Fixes * change test dir name
- Loading branch information
Showing
19 changed files
with
1,423 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# not using add_example_executable() to add this target, since we don't want this to have | ||
# to be included in "make all/install/check" | ||
add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Image to Column | ||
|
||
This folder contains example for Image to Column using ck_tile tile-programming implementation. | ||
|
||
## build | ||
``` | ||
# in the root of ck_tile | ||
mkdir build && cd build | ||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942... | ||
make tile_example_img2col -j | ||
``` | ||
This will result in an executable `build/bin/tile_example_img2col` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#include <algorithm> | ||
#include <cstring> | ||
|
||
#include "ck_tile/host.hpp" | ||
#include "image_to_column.hpp" | ||
|
||
// Host API implementation | ||
template <> | ||
float image_to_column(const image_to_column_traits& traits, | ||
const image_to_column_args<2>& args, | ||
const ck_tile::stream_config& stream_conf) | ||
{ | ||
if(traits.data_type.compare("fp16") == 0) | ||
{ | ||
constexpr ck_tile::index_t NDimSpatial = 2; | ||
constexpr ck_tile::index_t VectorSize = 8; | ||
|
||
using thread_tile = ck_tile::sequence<8, 8>; | ||
using warp_tile = ck_tile::sequence<64, 64>; | ||
using block_tile = ck_tile::sequence<128, 128>; | ||
|
||
using Shape = ck_tile::TileImageToColumnShape<thread_tile, warp_tile, block_tile>; | ||
|
||
using InDataType = ck_tile::half_t; | ||
using OutDataType = ck_tile::half_t; | ||
|
||
using PipelineProblem = ck_tile::BlockImageToColumnProblem<InDataType, | ||
OutDataType, | ||
Shape, | ||
NDimSpatial, | ||
VectorSize, | ||
VectorSize>; | ||
|
||
using Kernel = ck_tile::ImageToColumn<PipelineProblem>; | ||
|
||
auto kargs = Kernel::MakeKargs(args.p_in, | ||
args.p_out, | ||
args.G, | ||
args.N, | ||
args.C, | ||
args.input_spatial_lengths, | ||
args.filter_spatial_lengths, | ||
args.output_spatial_lengths, | ||
args.image_g_n_c_wis_strides, | ||
args.gemm_g_m_k_strides, | ||
args.conv_filter_strides, | ||
args.conv_filter_dilations, | ||
args.input_left_pads, | ||
args.input_right_pads); | ||
|
||
const dim3 grids = Kernel::GridSize( | ||
args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1], | ||
args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C, | ||
args.G); | ||
constexpr dim3 blocks = Kernel::BlockSize(); | ||
|
||
constexpr ck_tile::index_t kBlockPerCu = 2; | ||
|
||
float ave_time = ck_tile::launch_kernel( | ||
stream_conf, | ||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs)); | ||
|
||
return ave_time; | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
int main(int argc, char* argv[]) | ||
{ | ||
constexpr ck_tile::index_t NDimSpatial = 2; | ||
|
||
ExecutionConfig config; | ||
ck_tile::conv::ConvParam conv_params = DefaultConvParams; | ||
|
||
if(!parse_cmd_args(argc, argv, config, conv_params)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
|
||
if(conv_params.num_dim_spatial_ != NDimSpatial) | ||
{ | ||
std::cerr << "unsupported # of spatial dimensions" << std::endl; | ||
return EXIT_FAILURE; | ||
} | ||
|
||
using InDataType = ck_tile::half_t; | ||
using OutDataType = ck_tile::half_t; | ||
using ImLayout = ck_tile::tensor_layout::convolution::NHWGC; | ||
|
||
const auto G = conv_params.G_; | ||
const auto N = conv_params.N_; | ||
const auto C = conv_params.C_; | ||
|
||
const ck_tile::long_index_t NHoWo = | ||
N * std::accumulate(conv_params.output_spatial_lengths_.begin(), | ||
std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial), | ||
1, | ||
std::multiplies<>()); | ||
|
||
const ck_tile::long_index_t CYX = | ||
C * std::accumulate(conv_params.filter_spatial_lengths_.begin(), | ||
std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial), | ||
1, | ||
std::multiplies<>()); | ||
|
||
const auto in_desc = | ||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params); | ||
const auto out_desc = ck_tile::HostTensorDescriptor({G, NHoWo, CYX}); | ||
|
||
// host verify | ||
ck_tile::HostTensor<InDataType> in(in_desc); | ||
ck_tile::HostTensor<OutDataType> out_device(out_desc); | ||
ck_tile::HostTensor<OutDataType> out_host(out_desc); | ||
|
||
switch(config.init_method) | ||
{ | ||
case 0: break; | ||
case 1: ck_tile::FillUniformDistributionIntegerValue<InDataType>{-5.f, 5.f}(in); break; | ||
default: ck_tile::FillUniformDistribution<InDataType>{-0.5, 0.5}(in); break; | ||
} | ||
|
||
ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes()); | ||
ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes()); | ||
|
||
in_device_buf.ToDevice(in.data()); | ||
|
||
image_to_column_traits traits{"fp16"}; | ||
|
||
image_to_column_args<NDimSpatial> args{ | ||
in_device_buf.GetDeviceBuffer(), | ||
out_device_buf.GetDeviceBuffer(), | ||
G, | ||
N, | ||
C, | ||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_spatial_lengths_), | ||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.filter_spatial_lengths_), | ||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.output_spatial_lengths_), | ||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()), | ||
ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()), | ||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_), | ||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_dilations_), | ||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_), | ||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_)}; | ||
|
||
float ave_time = | ||
image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel}); | ||
|
||
std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType)); | ||
float gb_per_sec = num_btype / 1.E6 / ave_time; | ||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; | ||
|
||
bool pass = true; | ||
|
||
if(config.do_verification) | ||
{ | ||
// reference | ||
ck_tile::reference_im2col<InDataType, OutDataType, NDimSpatial>(in, out_host, conv_params); | ||
|
||
out_device_buf.FromDevice(out_device.data()); | ||
pass = ck_tile::check_err(out_device, out_host); | ||
|
||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl; | ||
} | ||
|
||
return !pass; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#pragma once | ||
|
||
#include "ck_tile/core.hpp" | ||
#include "ck_tile/host/kernel_launch.hpp" | ||
#include "ck_tile/ops/image_to_column.hpp" | ||
#include <string> | ||
|
||
#define DefaultConvParams \ | ||
ck_tile::conv::ConvParam \ | ||
{ \ | ||
2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \ | ||
} | ||
|
||
struct ExecutionConfig final | ||
{ | ||
bool do_verification = true; | ||
int init_method = 1; | ||
bool time_kernel = false; | ||
}; | ||
|
||
inline void print_help_msg() | ||
{ | ||
std::cerr << "arg1: verification (0=no, 1=yes)\n" | ||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" | ||
<< "arg3: time kernel (0=no, 1=yes)\n" | ||
<< ck_tile::conv::get_conv_param_parser_helper_msg() << std::endl; | ||
} | ||
|
||
inline bool parse_cmd_args(int argc, | ||
char* argv[], | ||
ExecutionConfig& config, | ||
ck_tile::conv::ConvParam& conv_params) | ||
{ | ||
constexpr int num_execution_config_args = | ||
3; // arguments for do_verification, init_method, time_kernel | ||
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ | ||
|
||
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; | ||
constexpr int threshold_to_catch_all_args = | ||
threshold_to_catch_partial_args + num_conv_param_leading_args; | ||
|
||
if(argc == 1) | ||
{ | ||
// use default | ||
config = ExecutionConfig{}; | ||
} | ||
// catch only ExecutionConfig arguments | ||
else if(argc == threshold_to_catch_partial_args) | ||
{ | ||
config.do_verification = std::stoi(argv[1]); | ||
config.init_method = std::stoi(argv[2]); | ||
config.time_kernel = std::stoi(argv[3]); | ||
} | ||
// catch both ExecutionConfig & ConvParam arguments | ||
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) | ||
{ | ||
config.do_verification = std::stoi(argv[1]); | ||
config.init_method = std::stoi(argv[2]); | ||
config.time_kernel = std::stoi(argv[3]); | ||
|
||
const ck_tile::index_t num_dim_spatial = std::stoi(argv[4]); | ||
conv_params = | ||
ck_tile::conv::parse_conv_param(num_dim_spatial, threshold_to_catch_partial_args, argv); | ||
} | ||
else | ||
{ | ||
print_help_msg(); | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
struct image_to_column_traits | ||
{ | ||
std::string data_type; | ||
}; | ||
|
||
template <ck_tile::index_t NDimSpatial> | ||
struct image_to_column_args | ||
{ | ||
const void* p_in; | ||
void* p_out; | ||
const ck_tile::long_index_t G; | ||
const ck_tile::long_index_t N; | ||
const ck_tile::long_index_t C; | ||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_spatial_lengths; | ||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> filter_spatial_lengths; | ||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> output_spatial_lengths; | ||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides; | ||
const ck_tile::array<ck_tile::long_index_t, 3> gemm_g_m_k_strides; | ||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_strides; | ||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_dilations; | ||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_left_pads; | ||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_right_pads; | ||
}; | ||
|
||
// host API | ||
template <ck_tile::index_t NDimSpatial> | ||
float image_to_column(const image_to_column_traits&, | ||
const image_to_column_args<NDimSpatial>&, | ||
const ck_tile::stream_config&); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.