Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: intwanghao <[email protected]>
  • Loading branch information
intwanghao committed Dec 16, 2024
1 parent a79a89b commit 3cf4e2c
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 102 deletions.
13 changes: 2 additions & 11 deletions clang/lib/DPCT/AnalysisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1615,11 +1615,6 @@ void DpctGlobalInfo::buildReplacements() {
}
}

for (auto &Repls : WrapperRegisterMap) {
addReplacement(Repls.second.first);
addReplacement(Repls.second.second);
}

for (auto &File : FileMap)
File.second->buildReplacements();

Expand Down Expand Up @@ -2508,9 +2503,6 @@ std::unordered_map<std::string, std::unordered_map<clang::tooling::UnifiedPath,
std::unordered_map<clang::tooling::UnifiedPath,
std::vector<clang::tooling::UnifiedPath>>
DpctGlobalInfo::MainSourceFileMap;
std::unordered_map<std::string, std::pair<std::shared_ptr<ExtReplacement>,
std::shared_ptr<ExtReplacement>>>
DpctGlobalInfo::WrapperRegisterMap;
std::unordered_map<std::string, bool> DpctGlobalInfo::MallocHostInfoMap;
std::map<std::shared_ptr<TextModification>, bool>
DpctGlobalInfo::ConstantReplProcessedFlagMap;
Expand Down Expand Up @@ -5384,8 +5376,7 @@ DeviceFunctionInfo::DeviceFunctionInfo(size_t ParamsNum,
}

void DeviceFunctionInfo::collectInfoForWrapper(const FunctionDecl *FD) {
if (!WrapperInfoCollected) {
WrapperInfoCollected = true;
if (!DFInfoForWrapper) {
DFInfoForWrapper = std::make_shared<DeviceFunctionInfoForWrapper>();
auto LocInfo = DpctGlobalInfo::getLocInfo(FD->getBeginLoc());
auto &TemplateParametersInfo = DFInfoForWrapper->TemplateParametersInfo;
Expand Down Expand Up @@ -6175,7 +6166,7 @@ std::shared_ptr<KernelCallExpr> KernelCallExpr::buildFromCudaLaunchKernel(
CE);
Kernel->buildNeedBracesInfo(CE);
const FunctionDecl *FD = nullptr;
if (auto Callee = getAddressedRef(CE->getArg(0), &FD)) {
if (auto Callee = getAddressedRef(CE->getArg(0), true, &FD)) {
Kernel->buildCalleeInfo(Callee, std::nullopt);
auto FuncInfo = Kernel->getFuncInfo();
if (FD && FuncInfo) {
Expand Down
18 changes: 8 additions & 10 deletions clang/lib/DPCT/AnalysisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1005,10 +1005,14 @@ class DpctGlobalInfo {
return Cur.get<TargetTy>();
});
}
template <class TargetTy, class NodeTy>
template <class TargetTy, class NodeTy, class... SkipNodeTy>
static auto findParent(const NodeTy *Node) {
return findAncestor<TargetTy>(
Node, [](const DynTypedNode &Cur) -> bool { return true; });
return findAncestor<TargetTy>(Node, [](const DynTypedNode &Cur) -> bool {
if ((... || Cur.get<SkipNodeTy>())) {
return false;
}
return true;
});
}

template <typename TargetTy, typename NodeTy>
Expand Down Expand Up @@ -1466,7 +1470,6 @@ class DpctGlobalInfo {
return ConstantReplProcessedFlagMap;
}
static IncludeMapSetTy &getIncludeMapSet() { return IncludeMapSet; }
static auto &getWrapperRegisterMap() { return WrapperRegisterMap; }
static auto &getCodePinTypeInfoVec() { return CodePinTypeInfoMap; }
static auto &getCodePinTemplateTypeInfoVec() {
return CodePinTemplateTypeInfoMap;
Expand Down Expand Up @@ -1691,10 +1694,6 @@ class DpctGlobalInfo {
static std::map<std::shared_ptr<TextModification>, bool>
ConstantReplProcessedFlagMap;
static IncludeMapSetTy IncludeMapSet;
static std::unordered_map<std::string,
std::pair<std::shared_ptr<ExtReplacement>,
std::shared_ptr<ExtReplacement>>>
WrapperRegisterMap;
static std::vector<std::pair<std::string, VarInfoForCodePin>>
CodePinTypeInfoMap;
static std::vector<std::pair<std::string, VarInfoForCodePin>>
Expand Down Expand Up @@ -2772,9 +2771,8 @@ class DeviceFunctionInfo {
bool CallGroupFunctionInControlFlow = false;
bool HasCheckedCallGroupFunctionInControlFlow = false;
OverloadedOperatorKind OO_Kind = OverloadedOperatorKind::OO_None;
bool WrapperInfoCollected = false;
bool ModuleUsed = false;
std::shared_ptr<DeviceFunctionInfoForWrapper> DFInfoForWrapper;
std::shared_ptr<DeviceFunctionInfoForWrapper> DFInfoForWrapper = nullptr;
};

class KernelCallExpr : public CallFunctionExpr {
Expand Down
166 changes: 105 additions & 61 deletions clang/lib/DPCT/RulesLang/RulesLang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4442,38 +4442,112 @@ void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) {
this);
}

std::string KernelCallRefRule::getTypeRepl(const Expr *E) {
std::string TypeRef;
if (auto BO = DpctGlobalInfo::findParent<BinaryOperator, Expr,
ImplicitCastExpr, ParenExpr>(E)) {
TypeRef = "decltype(" + ExprAnalysis::ref(BO->getLHS()) + ")";
} else if (auto VD =
DpctGlobalInfo::findParent<VarDecl, Expr, ImplicitCastExpr,
ParenExpr>(E)) {
TypeRef = "decltype(" + VD->getNameAsString() + ")";
} else if (auto RS =
DpctGlobalInfo::findParent<ReturnStmt, Expr, ImplicitCastExpr,
ParenExpr>(E)) {
auto FD = DpctGlobalInfo::findAncestor<FunctionDecl>(RS);
TypeRef =
getStmtSpelling(FD->getReturnTypeSourceRange(), FD->getSourceRange());
} else if (auto CE =
DpctGlobalInfo::findParent<CallExpr, Expr, ImplicitCastExpr,
ParenExpr>(E)) {
size_t N = 0;
for (auto Arg : CE->arguments()) {
if (Arg == E) {
break;
}
N++;
}
TypeRef = "typename " + MapNames::getDpctNamespace() +
"nth_argument_type<decltype(" +
ExprAnalysis::ref(CE->getCallee()) + "), " + std::to_string(N) +
">::type";
}
if (!TypeRef.empty()) {
return "<" + TypeRef + ">";
}
return TypeRef;
}

template <typename T>
void KernelCallRefRule::insertWrapperPostfix(const T *Node,
std::string &&TypeRepl,
bool isInsertWrapperRegister) {
auto NLoc = DpctGlobalInfo::getSourceManager().getSpellingLoc(
Node->getNameInfo().getBeginLoc());
emplaceTransformation(new InsertText(
NLoc.getLocWithOffset(Node->getNameInfo().getAsString().length()),
"_wrapper"));

if (!isInsertWrapperRegister) {
return;
}
const Expr *E = Node;
if (auto UO = DpctGlobalInfo::findParent<UnaryOperator, T, ImplicitCastExpr,
ParenExpr>(Node)) {
if (UO->getOpcode() == UO_AddrOf) {
E = UO;
}
} else if (auto COC = DpctGlobalInfo::findParent<CXXOperatorCallExpr, T,
ImplicitCastExpr, ParenExpr>(
Node)) {
if (COC->getOperator() == clang::OO_Amp) {
E = COC;
}
}
emplaceTransformation(new InsertBeforeStmt(
E, MapNames::getDpctNamespace() + "wrapper_register" + TypeRepl + "("));
emplaceTransformation(new InsertAfterStmt(E, ")"));
}

void KernelCallRefRule::runRule(
const ast_matchers::MatchFinder::MatchResult &Result) {
if (auto DRE = getAssistNodeAsType<DeclRefExpr>(Result, "kernelRef")) {
if (auto ParentCE = dpct::DpctGlobalInfo::findAncestor<CallExpr>(DRE)) {
if (auto ParentCE = DpctGlobalInfo::findAncestor<CallExpr>(DRE)) {
if (auto Callee = ParentCE->getDirectCallee()) {
if (dpct::DpctGlobalInfo::isInCudaPath(Callee->getBeginLoc())) {
return;
}
}
}
if (auto FD = dyn_cast<FunctionDecl>(DRE->getDecl())) {
const FunctionDecl *FD = dyn_cast<FunctionDecl>(DRE->getDecl());
bool IsTemplateRelated = false;
int TemplateParamNum = 0;
if (FD) {
if (FD->getTemplatedKind() !=
FunctionDecl::TemplatedKind::TK_NonTemplate) {
IsTemplateRelated = true;
}

if (auto DFI = DeviceFunctionDecl::LinkRedecls(FD)) {
DFI->collectInfoForWrapper(FD);
}
}

auto NLoc = DpctGlobalInfo::getSourceManager().getSpellingLoc(
DRE->getNameInfo().getBeginLoc());
emplaceTransformation(new InsertText(
NLoc.getLocWithOffset(DRE->getNameInfo().getAsString().length()),
"_wrapper"));
if (DpctGlobalInfo::isCVersionCUDALaunchUsed()) {
auto &SM = DpctGlobalInfo::getSourceManager();
auto &Map = DpctGlobalInfo::getWrapperRegisterMap();
auto Key = getStrFromLoc(SM.getSpellingLoc(DRE->getBeginLoc()));
if (!Map.count(Key)) {
Map.insert({getStrFromLoc(SM.getSpellingLoc(DRE->getBeginLoc())),
{InsertBeforeStmt(DRE, MapNames::getDpctNamespace() +
"wrapper_register(")
.getReplacement(DpctGlobalInfo::getContext()),
InsertAfterStmt(DRE, ").get()")
.getReplacement(DpctGlobalInfo::getContext())}});
std::cout << IsTemplateRelated << std::endl;
std::cout <<DRE->hasExplicitTemplateArgs() << std::endl;
if (auto *OuterFD = DpctGlobalInfo::findAncestor<FunctionDecl>(DRE)) {
if ((OuterFD->getTemplatedKind() ==
FunctionDecl::TemplatedKind::TK_NonTemplate) ||
(OuterFD->getTemplatedKind() ==
FunctionDecl::TemplatedKind::TK_FunctionTemplate)) {
std::string TypeRepl;
if (DpctGlobalInfo::isCVersionCUDALaunchUsed() && IsTemplateRelated &&
!DRE->hasExplicitTemplateArgs()) {
TypeRepl = getTypeRepl(DRE);
}
std::cout << TypeRepl << std::endl;
insertWrapperPostfix<DeclRefExpr>(
DRE, std::move(TypeRepl),
DpctGlobalInfo::isCVersionCUDALaunchUsed());
}
}
}
Expand Down Expand Up @@ -4513,39 +4587,7 @@ void KernelCallRefRule::runRule(
}
}
}
std::string TypeRef;
if (auto BO = DpctGlobalInfo::findParent<BinaryOperator>(ULE)) {
TypeRef = "decltype(" + ExprAnalysis::ref(BO->getLHS()) + ")";
} else if (auto VD = DpctGlobalInfo::findParent<VarDecl>(ULE)) {
TypeRef = "decltype(" + VD->getNameAsString() + ")";
} else if (auto RS = DpctGlobalInfo::findParent<ReturnStmt>(ULE)) {
auto FD = DpctGlobalInfo::findAncestor<FunctionDecl>(RS);
TypeRef =
getStmtSpelling(FD->getReturnTypeSourceRange(), FD->getSourceRange());
} else if (auto CE = DpctGlobalInfo::findParent<CallExpr>(ULE)) {
size_t N = 0;
for (auto Arg : CE->arguments()) {
if (Arg == ULE) {
break;
}
N++;
}
TypeRef = "typename " + MapNames::getDpctNamespace() +
"nth_argument_type<decltype(" +
ExprAnalysis::ref(CE->getCallee()) + "), " + std::to_string(N) +
">::type";
}
if (!TypeRef.empty()) {
auto &SM = DpctGlobalInfo::getSourceManager();
auto &Map = DpctGlobalInfo::getWrapperRegisterMap();
Map.insert(
{getStrFromLoc(SM.getSpellingLoc(ULE->getBeginLoc())),
{InsertBeforeStmt(ULE, MapNames::getDpctNamespace() +
"wrapper_register<" + TypeRef + ">(")
.getReplacement(DpctGlobalInfo::getContext()),
InsertAfterStmt(ULE, ").get()")
.getReplacement(DpctGlobalInfo::getContext())}});
}
insertWrapperPostfix<UnresolvedLookupExpr>(ULE, getTypeRepl(ULE), true);
}
}

Expand Down Expand Up @@ -4679,6 +4721,8 @@ void KernelCallRule::runRule(
if (auto Arg = KCall->getArg(i)) {
if (!Arg->isDefaultArgument()) {
OS << ", " << ExprAnalysis::ref(Arg);
} else {
break;
}
}
}
Expand Down Expand Up @@ -4775,18 +4819,18 @@ void KernelCallRule::runRule(
const Expr *CalleeDRE = LaunchKernelCall->getArg(0);
bool IsFuncTypeErased = true;
auto QT = CalleeDRE->getType();

if (QT->isPointerType()) {
QT = QT->getPointeeType();
}
if (QT->isFunctionType()) {
IsFuncTypeErased = false;
} else if (QT->isPointerType()) {
const Type *PointeeType = QT->getPointeeType().getTypePtr();
if (PointeeType->isFunctionType()) {
IsFuncTypeErased = false;
}
}
if (IsFuncTypeErased) {
DpctGlobalInfo::setCVersionCUDALaunchUsed();
}

if (!getAddressedRef(CalleeDRE)) {
if (IsFuncTypeErased) {
DpctGlobalInfo::setCVersionCUDALaunchUsed();
}
std::string ReplStr;
llvm::raw_string_ostream OS(ReplStr);
if (IsAssigned) {
Expand All @@ -4797,7 +4841,7 @@ void KernelCallRule::runRule(
for (size_t i = 0; i < ArgsNum; i++) {
if (auto Arg = LaunchKernelCall->getArg(i)) {
if (i == 0) {
if (auto E = getAddressedRef(CalleeDRE, nullptr, false)) {
if (auto E = getAddressedRef(CalleeDRE, false, nullptr)) {
OS << ExprAnalysis::ref(E);
} else {
OS << ExprAnalysis::ref(Arg);
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/DPCT/RulesLang/RulesLang.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ class KernelCallRule : public NamedMigrationRule<KernelCallRule> {
};

class KernelCallRefRule : public NamedMigrationRule<KernelCallRefRule> {
std::string getTypeRepl(const Expr *E);
template <typename T>
void insertWrapperPostfix(const T *Node, std::string &&TypeRepl,
bool isInsertWrapperRegister);

public:
void registerMatcher(ast_matchers::MatchFinder &MF) override;
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
Expand Down
14 changes: 7 additions & 7 deletions clang/lib/DPCT/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4998,8 +4998,8 @@ int isArgumentInitialized(

return DeclsRequireInit.empty();
}
const Expr *getAddressedRef(const Expr *E, const FunctionDecl **FuncDecl,
bool IsCheckFunctionDecl) {
const Expr *getAddressedRef(const Expr *E, bool IsCheckFunctionDecl,
const FunctionDecl **FuncDecl) {
E = E->IgnoreImplicitAsWritten();
if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
if (IsCheckFunctionDecl) {
Expand Down Expand Up @@ -5033,17 +5033,17 @@ const Expr *getAddressedRef(const Expr *E, const FunctionDecl **FuncDecl,
return ULE;
}
} else if (auto Paren = dyn_cast<ParenExpr>(E)) {
return getAddressedRef(Paren->getSubExpr(), FuncDecl, IsCheckFunctionDecl);
return getAddressedRef(Paren->getSubExpr(), IsCheckFunctionDecl, FuncDecl);
} else if (auto Cast = dyn_cast<CastExpr>(E)) {
return getAddressedRef(Cast->getSubExprAsWritten(), FuncDecl,
IsCheckFunctionDecl);
return getAddressedRef(Cast->getSubExprAsWritten(), IsCheckFunctionDecl,
FuncDecl);
} else if (auto UO = dyn_cast<UnaryOperator>(E)) {
if (UO->getOpcode() == UO_AddrOf) {
return getAddressedRef(UO->getSubExpr(), FuncDecl, IsCheckFunctionDecl);
return getAddressedRef(UO->getSubExpr(), IsCheckFunctionDecl, FuncDecl);
}
} else if (auto COC = dyn_cast<CXXOperatorCallExpr>(E)) {
if (COC->getOperator() == clang::OO_Amp) {
return getAddressedRef(COC->getArg(0), FuncDecl, IsCheckFunctionDecl);
return getAddressedRef(COC->getArg(0), IsCheckFunctionDecl, FuncDecl);
}
}
if (FuncDecl) {
Expand Down
5 changes: 2 additions & 3 deletions clang/lib/DPCT/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,8 @@ findTheOuterMostCompoundStmtUntilMeetControlFlowNodes(
const clang::NamedDecl *getNamedDecl(const clang::Type *TypePtr);
const clang::LambdaExpr *
getImmediateOuterLambdaExpr(const clang::FunctionDecl *FuncDecl);
const Expr *getAddressedRef(const Expr *E,
const FunctionDecl **FuncDecl = nullptr,
bool IsCheckFunctionDecl = true);
const Expr *getAddressedRef(const Expr *E, bool IsCheckFunctionDecl = true,
const FunctionDecl **FuncDecl = nullptr);
const clang::FunctionDecl *findTheOuterMostFunctionDecl(const clang::Decl *D);

// Source Range & location, offset.
Expand Down
Loading

0 comments on commit 3cf4e2c

Please sign in to comment.