Skip to content

Commit

Permalink
Ck tile gemm padding dim (#1516)
Browse files Browse the repository at this point in the history
* Support the N dimension padding

* Finished the padding feature for different dimensions of K
  • Loading branch information
ThomasNing authored Sep 18, 2024
1 parent e84adec commit 694c300
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
10 changes: 5 additions & 5 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,

std::cout << "The overall perfomance of the GEMM with "
<< "[" << data_type << "]"
<< "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
<< "is: \n";
std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
<< "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K
<< " is: \n";
std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n"
<< std::flush;

return ave_time;
Expand Down Expand Up @@ -235,7 +235,7 @@ int main(int argc, char* argv[])
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadC = false;
constexpr bool kPadC = true;

// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
Expand Down Expand Up @@ -348,7 +348,7 @@ int main(int argc, char* argv[])

pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);

std::cout << "The GPU veification result is:" << (pass_gpu ? "correct" : "fail")
std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
<< std::flush;
}

Expand Down
26 changes: 21 additions & 5 deletions include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,26 @@ struct GemmKernel
}
}();

auto ABlockWindow = make_tile_window(
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadA ? 1 : 0 > {});

auto ABlockWindow = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});

auto BBlockWindow = make_tile_window(
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadB ? 1 : 0 > {});

auto BBlockWindow = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});

// allocate LDS
Expand Down Expand Up @@ -163,12 +175,16 @@ struct GemmKernel
}
}();

auto CBlockWindow = make_tile_window(
auto c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence < 0,
GemmPipeline::kPadC ? 1 : 0 > {});
auto CBlockWindow_pad = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
// epilogue.
EpiloguePipeline{}(CBlockWindow, acc);
EpiloguePipeline{}(CBlockWindow_pad, acc);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static constexpr index_t AlignmentB = Problem::AlignmentB;
static constexpr index_t AlignmentC = Problem::AlignmentC;

static constexpr bool kPadA = Problem::kPadA;
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;

CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
return ck_tile::integer_divide_ceil(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;

static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1;
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
};

} // namespace ck_tile

0 comments on commit 694c300

Please sign in to comment.