From cd1a3f1344020aad0a5f5e1d9be7a747b964b5dc Mon Sep 17 00:00:00 2001 From: Vasilii Filippov Date: Fri, 15 Nov 2024 21:06:38 +0100 Subject: [PATCH] Revert "Fixed incorrect transpose in find 2.0 (#3285)" This reverts commit 2d69aebdcf1b22df1357a38b68e1b8ddefd8b424. --- src/problem.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/problem.cpp b/src/problem.cpp index ba84856850..fed48dfe88 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -179,11 +179,7 @@ Problem::FindSolutions(Handle& handle, const FindOptions& options, std::size_t m auto ret = std::visit( boost::hof::match( [&](const ConvolutionDescriptor& op_desc) { - if(op_desc.mode == miopenTranspose) - return MakeTransposed().FindSolutionsImpl( - handle, options, max_solutions, buffers, op_desc); - else - return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc); + return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc); }, [&](const SoftmaxDescriptor& op_desc) { return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc); @@ -481,17 +477,21 @@ std::vector Problem::FindSolutionsImpl(Handle& handle, const auto& w = buffers.at(miopenTensorConvolutionW); auto y = buffers.at(miopenTensorConvolutionY); - if(conv_desc.mode == miopenTranspose) - std::swap(x, y); - - const auto conv_problem = AsConvolution(); - - ValidateGroupCount(x_desc, w_desc, conv_desc); + const auto conv_problem = + conv_desc.mode == miopenTranspose ? MakeTransposed().AsConvolution() : AsConvolution(); std::size_t workspace_size; Allocator::ManageDataPtr owned_workspace; Data_t workspace; + if(conv_desc.mode == miopenTranspose) + { + std::swap(x, y); + std::swap(x_desc, y_desc); + } + + ValidateGroupCount(x_desc, w_desc, conv_desc); + if(options.preallocated_workspace) { workspace = options.preallocated_workspace->buffer;