diff --git a/clang/lib/DPCT/RuleInfra/CallExprRewriter.h b/clang/lib/DPCT/RuleInfra/CallExprRewriter.h index 14c51cf23a5a..cd35aed144aa 100644 --- a/clang/lib/DPCT/RuleInfra/CallExprRewriter.h +++ b/clang/lib/DPCT/RuleInfra/CallExprRewriter.h @@ -10,6 +10,7 @@ #define CALL_EXPR_REWRITER_H #include "Diagnostics/Diagnostics.h" +#include "RuleInfra/ExprAnalysis.h" namespace clang { namespace dpct { @@ -1676,26 +1677,14 @@ class UserDefinedRewriter : public CallExprRewriter { case (OutputBuilder::Kind::MethodBase): { if (auto *MCE = llvm::dyn_cast(Call)) { if (auto *Callee = llvm::dyn_cast(MCE->getCallee())) { - auto &SM = DpctGlobalInfo::getSourceManager(); - auto &Context = DpctGlobalInfo::getContext(); - auto CallRange = - getDefinitionRange(Call->getBeginLoc(), Call->getEndLoc()); - auto LastTokenLength = Lexer::MeasureTokenLength( - CallRange.getEnd(), SM, Context.getLangOpts()); - auto Base = Callee->getBase(); - std::string BaseStr; - if (isa(Callee->getBase())) { - BaseStr = "this"; - } else { - BaseStr = getStringInRange( - Base->getSourceRange(), CallRange.getBegin(), - CallRange.getEnd().getLocWithOffset(LastTokenLength)); - } - OS << BaseStr; - if (Callee->isArrow()) { - OS << "->"; - } else { - OS << "."; + const Expr *Base = Callee->getBase(); + if (!isa(Base)) { + ExprAnalysis EA(Base); + OS << EA.getReplacedString(); + if (Callee->isArrow()) + OS << "->"; + else + OS << "."; } } } diff --git a/clang/test/dpct/user_defined_rule.cu b/clang/test/dpct/user_defined_rule.cu index 8879f8322803..1573e3277194 100644 --- a/clang/test/dpct/user_defined_rule.cu +++ b/clang/test/dpct/user_defined_rule.cu @@ -50,11 +50,13 @@ public: int fieldC; int methodA(int i, int j){return 0;}; }; +ClassA getClassA() { return ClassA(); }; class ClassB{ public: int fieldB; int methodB(int i){return 0;}; }; +ClassB getClassB() { return ClassB(); }; enum Fruit{ apple, @@ -80,12 +82,14 @@ void foo2(){ //CHECK: ClassB a; //CHECK-NEXT: a.fieldD = 3; //CHECK-NEXT: a.methodB(2); + //CHECK-NEXT: getClassB().methodB(2); //CHECK-NEXT: a.set_a(3); //CHECK-NEXT: int k = a.get_a(); //CHECK-NEXT: Fruit f = pineapple; ClassA a; a.fieldC = 3; a.methodA(1,2); + getClassA().methodA(1,2); a.fieldA = 3; int k = a.fieldA; Fruit f = Fruit::apple; diff --git a/clang/test/dpct/user_defined_rule.yaml b/clang/test/dpct/user_defined_rule.yaml index f54a7f69da2d..2e5f6c29659f 100644 --- a/clang/test/dpct/user_defined_rule.yaml +++ b/clang/test/dpct/user_defined_rule.yaml @@ -68,6 +68,11 @@ Out: $method_base methodB($2) - In: methodC Out: methodD +- Rule: rule_getClassA + Kind: API + Priority: Takeover + In: getClassA + Out: getClassB() - Rule: rule_Fruit Kind: Enum Priority: Takeover