Skip to content

Commit

Permalink
Adding plugin to migrate scalatest
Browse files Browse the repository at this point in the history
  • Loading branch information
ketkarameya committed Aug 14, 2023
1 parent b16fa19 commit 5b18385
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 0 deletions.
23 changes: 23 additions & 0 deletions plugins/scala_test/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import argparse
from update_imports import update_imports




def _parse_args():
parser = argparse.ArgumentParser(description="Migrates scala tests!!!")
parser.add_argument(
"--path_to_codebase",
required=True,
help="Path to the codebase directory.",
)

args = parser.parse_args()
return args

def main():
args = _parse_args()
update_imports(args.path_to_codebase, dry_run=True)

if __name__ == "__main__":
main()
83 changes: 83 additions & 0 deletions plugins/scala_test/recipes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from polyglot_piranha import Rule, OutgoingEdges, RuleGraph, PiranhaArguments, execute_piranha

def replace_imports(
target_new_types: dict[str, str], search_heuristic: str, path_to_codebase: str,
dry_run = False
):
find_relevant_files = Rule(
name="find_relevant_files",
query="((identifier) @x (#eq? @x \"@search_heuristic\"))",
holes={"search_heuristic"},
)
e1 = OutgoingEdges("find_relevant_files", to=[f"update_import"], scope="File")

rules = [find_relevant_files]
edges = [e1]

for target_type, new_type in target_new_types.items():
rs, es = replace_import_rules_edges(target_type, new_type)
rules.extend(rs)
edges.extend(es)

rule_graph = RuleGraph(rules=rules, edges=edges)

args= PiranhaArguments(
language="scala",
path_to_codebase=path_to_codebase,
rule_graph=rule_graph,
substitutions={"search_heuristic": f"{search_heuristic}"},
dry_run=dry_run
)

return execute_piranha(args)



def replace_import_rules_edges(
target_qualified_type_name: str, new_qualified_type_name: str
) -> (list[Rule], list[OutgoingEdges]):

name_components = target_qualified_type_name.split(".")
type_name = name_components[-1]

qualifier_predicate = "\n".join(
[f'(#match? @import_decl "{n}")' for n in name_components[:-1]]
)

delete_nested_import = Rule(
name=f"delete_nested_import_{type_name}",
query=f"""(
(import_declaration (namespace_selectors (_) @tn )) @import_decl
(#eq? @tn "{type_name}")
{qualifier_predicate}
)""",
replace_node="tn",
replace="",
is_seed_rule=False,
groups={"update_import"},
)

update_simple_import = Rule(
name=f"update_simple_import_{type_name}",
query=f"cs import {target_qualified_type_name}",
replace_node="*",
replace=f"import {new_qualified_type_name}",
is_seed_rule=False,
groups={"update_import"},
)

insert_import = Rule(
name=f"insert_import_{type_name}",
query="(import_declaration) @import_decl",
replace_node="import_decl",
replace=f"@import_decl\nimport {new_qualified_type_name}\n",
is_seed_rule=False,
)

e2 = OutgoingEdges(
f"delete_nested_import_{type_name}",
to=[f"insert_import_{type_name}"],
scope="Parent",
)

return [delete_nested_import, update_simple_import, insert_import], [e2]
Empty file.
8 changes: 8 additions & 0 deletions plugins/scala_test/tests/resources/expected/sample.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.scala.piranha

import com.uber.michelangelo.AbstractSparkSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter}
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
7 changes: 7 additions & 0 deletions plugins/scala_test/tests/resources/input/sample.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.scala.piranha

import com.uber.michelangelo.AbstractSparkSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.mock.MockitoSugar
40 changes: 40 additions & 0 deletions plugins/scala_test/tests/test_update_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from logging import debug, error
from pathlib import Path

from os.path import join, basename
from os import listdir

from update_imports import update_imports
# from update_imports import update_imports

def test_update_imports():
summary = update_imports("plugins/scala_test/tests/resources/input/", dry_run=True)
assert is_as_expected("plugins/scala_test/tests/resources/", summary)

def is_as_expected(path_to_scenario, output_summary):
expected_output = join(path_to_scenario, "expected")
print("Summary", output_summary)
input_dir = join(path_to_scenario, "input")
for file_name in listdir(expected_output):
with open(join(expected_output, file_name), "r") as f:
file_content = f.read()
expected_content = "".join(file_content.split())

# Search for the file in the output summary
updated_content = [
"".join(o.content.split())
for o in output_summary
if basename(o.path) == file_name
]
print(file_name)
# Check if the file was rewritten
if updated_content:
if expected_content != updated_content[0]:
error("----update" + updated_content[0] )
return False
else:
# The scenario where the file is not expected to be rewritten
original_content= Path(join(input_dir, file_name)).read_text()
if expected_content != "".join(original_content.split()):
return False
return True
15 changes: 15 additions & 0 deletions plugins/scala_test/update_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from recipes import replace_imports


IMPORT_MAPPING = {
"org.scalatest.Matchers": "org.scalatest.matchers.should.Matchers",
"org.scalatest.mock.MockitoSugar": "org.scalatestplus.mockito.MockitoSugar",
# Todo write test scenarios for these
"org.scalatest.FunSuite":"org.scalatest.funsuite.AnyFunSuite",
"org.scalatest.junit.JUnitRunner":"org.scalatestplus.junit.JUnitRunner",
"org.scalatest.FlatSpec": "org.scalatest.flatspec.AnyFlatSpec",
"org.scalatest.junit.AssertionsForJUnit": "org.scalatestplus.junit.AssertionsForJUnit",
}

def update_imports(path_to_codebase: str, dry_run = False):
return replace_imports(IMPORT_MAPPING, "scalatest", path_to_codebase, dry_run)
39 changes: 39 additions & 0 deletions plugins/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from setuptools import find_packages, setup

setup(
name="scala_test",
version="0.0.1",
description="Rules to migrate `scaletest`",
# long_description=open("README.md").read(),
# long_description_content_type="text/markdown",
# url="https://github.com/uber/piranha",
packages=find_packages(),
include_package_data=True,
install_requires=[
# "polyglot-piranha",
"pytest",
],
entry_points={
"console_scripts": ["scala_test = scala_test.main:main"]
},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
python_requires=">=3.9",
tests_require=["pytest"],
# Define the test suite
test_suite="tests",
)

0 comments on commit 5b18385

Please sign in to comment.