From 9554921b42212680981a9bf7ed5f74bd94cd082a Mon Sep 17 00:00:00 2001 From: Anton Sviridov Date: Wed, 10 Jan 2024 21:27:11 +0000 Subject: [PATCH] Don't render `all` package and exports in multi-file mode (#258) --- .../src/main/scala/render/TypeImports.scala | 19 +++-- .../src/main/scala/render/binding.scala | 74 +++++++++---------- .../src/main/scala/render/function.scala | 22 +++++- .../basic/multi-file/src/main/scala/run.scala | 2 +- .../multi-file/src/test/scala/MyTest.scala | 2 +- .../test/resources/scala-native/multi_file.c | 4 + .../test/resources/scala-native/multi_file.h | 2 + .../src/test/scalajvm/TestInterface.scala | 5 -- 8 files changed, 71 insertions(+), 59 deletions(-) diff --git a/modules/bindgen/src/main/scala/render/TypeImports.scala b/modules/bindgen/src/main/scala/render/TypeImports.scala index f92830a0..5370caf9 100644 --- a/modules/bindgen/src/main/scala/render/TypeImports.scala +++ b/modules/bindgen/src/main/scala/render/TypeImports.scala @@ -11,17 +11,16 @@ case class TypeImports( structs: Boolean, unions: Boolean ): - def render(out: LineBuilder, multiFile: Boolean)(using Config, Context) = - var any = false - val imp = (s: String) => - any = true - to(out)(s"import _root_.$packageName.$s.*") + def render(out: LineBuilder)(using Config, Context) = + + val addImport = (s: String) => to(out)(s"import _root_.$packageName.$s.*") + if enums then - imp("enumerations") - if !multiFile then imp("predef") - if aliases then imp("aliases") - if structs then imp("structs") - if unions then imp("unions") + addImport("enumerations") + addImport("predef") + if aliases then addImport("aliases") + if structs then addImport("structs") + if unions then addImport("unions") end render end TypeImports diff --git a/modules/bindgen/src/main/scala/render/binding.scala b/modules/bindgen/src/main/scala/render/binding.scala index 80ffa44f..50f53741 100644 --- a/modules/bindgen/src/main/scala/render/binding.scala +++ b/modules/bindgen/src/main/scala/render/binding.scala @@ -97,7 +97,6 @@ def renderBinding( def create(name: String)(subPackage: String = name) = val lb = LineBuilder() lb.appendLine(s"package $packageName") - if multiFileMode then lb.appendLine(s"package $subPackage") lb.emptyLine lb.append(""" |import _root_.scala.scalanative.unsafe.* @@ -233,14 +232,8 @@ def renderBinding( } end if - if multiFileMode then - val byType: Map[String, List[(String, String)]] = - exports.result().groupBy(_._1) - - byType.toList.sortBy(_._1).foreach { (exportType, results) => - renderExports(stream(s"all.$exportType", "all"), results, renderMode) - } - else renderExports(simpleStream(s"all"), exports.result(), renderMode) + if !multiFileMode then + renderExports(simpleStream(s"all"), exports.result(), renderMode) if multiFileMode then RenderedOutput.Multi(multi.toMap) else if summon[Context].lang == Lang.C then RenderedOutput.Single(cOutput) @@ -261,12 +254,10 @@ private def renderAliases( typeImports: TypeImports )(using Config, AliasResolver, Context) = val exported = List.newBuilder[Exported] - if mode == RenderMode.Files then typeImports.render(out, multiFile = true) if mode == RenderMode.Objects then out.appendLine("object aliases:") nestIf(mode == RenderMode.Objects) { - if mode == RenderMode.Objects then - typeImports.render(out, multiFile = false) + if mode == RenderMode.Objects then typeImports.render(out) exported ++= renderAll(aliases, out, alias) } exported.result() @@ -301,11 +292,11 @@ private def renderUnions( typeImports: TypeImports )(using Config, AliasResolver, Context) = val exported = List.newBuilder[Exported] - if mode == RenderMode.Files then typeImports.render(out, multiFile = true) + if mode == RenderMode.Objects then out.appendLine("object unions:") + nestIf(mode == RenderMode.Objects) { - if mode == RenderMode.Objects then - typeImports.render(out, multiFile = false) + if mode == RenderMode.Objects then typeImports.render(out) exported ++= renderAll(unions, out, union) } exported.result() @@ -318,12 +309,10 @@ private def renderStructs( typeImports: TypeImports )(using Config, AliasResolver, Context) = val exported = List.newBuilder[Exported] - if mode == RenderMode.Files then typeImports.render(out, multiFile = true) if mode == RenderMode.Objects then out.appendLine("object structs:") nestIf(mode == RenderMode.Objects) { - if mode == RenderMode.Objects then - typeImports.render(out, multiFile = false) + if mode == RenderMode.Objects then typeImports.render(out) exported ++= renderAll(structs, out, struct) } exported.result() @@ -452,29 +441,33 @@ private def renderScalaFunctions( f } + val safePackageName = packageName.split('.').last + val hasExternFunctions = scalaExternFunctions.nonEmpty val hasRegularFunctions = scalaRegularFunctions.nonEmpty if functions.nonEmpty then if exportMode == ExportMode.No then - if renderMode == RenderMode.Files then - typeImports.render(out, multiFile = true) if hasExternFunctions then - summon[Config].linkName.foreach { l => - out.append(s"""@link("$l")""") - } - val safePackageName = packageName.split('.').last - out.appendLine( - s"\n@extern\nprivate[$safePackageName] object extern_functions:" - ) - nest { + if renderMode == RenderMode.Objects then + summon[Config].linkName.foreach { l => + out.append(s"""@link("$l")""") + } + end if + + nestIf(renderMode == RenderMode.Objects) { if renderMode == RenderMode.Objects then - typeImports.render(out, multiFile = false) + out.appendLine( + s"\n@extern\nprivate[$safePackageName] object extern_functions:" + ) + typeImports.render(out) + else out.appendLine("\n") + exported ++= renderAll( scalaExternFunctions.toList.sortBy(_.name), out, - renderFunction + renderFunction(_, _, renderMode) ) } end if @@ -483,10 +476,9 @@ private def renderScalaFunctions( if renderMode == RenderMode.Objects then out.appendLine(s"\nobject functions:") nestIf(renderMode == RenderMode.Objects) { - if renderMode == RenderMode.Objects then - typeImports.render(out, multiFile = false) + if renderMode == RenderMode.Objects then typeImports.render(out) - if hasExternFunctions then + if hasExternFunctions && renderMode == RenderMode.Objects then to(out)("import extern_functions.*") to(out)("export extern_functions.*") out.emptyLine @@ -494,7 +486,7 @@ private def renderScalaFunctions( exported ++= renderAll( scalaRegularFunctions.toList.sortBy(_.name), out, - renderFunction + renderFunction(_, _, renderMode) ) } end if @@ -519,22 +511,24 @@ private def renderScalaFunctions( out.appendLine("trait ExportedFunctions:") nest { - if renderMode == RenderMode.Objects then - typeImports.render(out, multiFile = false) - renderAll(modified(ExportLocation.Trait), out, renderFunction) + if renderMode == RenderMode.Objects then typeImports.render(out) + renderAll( + modified(ExportLocation.Trait), + out, + renderFunction(_, _, renderMode) + ) } if renderMode == RenderMode.Objects then out.appendLine(s"\nobject functions extends ExportedFunctions:") nestIf(renderMode == RenderMode.Objects) { - if renderMode == RenderMode.Objects then - typeImports.render(out, multiFile = false) + if renderMode == RenderMode.Objects then typeImports.render(out) renderAll( modified( ExportLocation.Body(summon[Context].packageName.map(_ + ".impl")) ), out, - renderFunction + renderFunction(_, _, renderMode) ) } diff --git a/modules/bindgen/src/main/scala/render/function.scala b/modules/bindgen/src/main/scala/render/function.scala index 16adf99f..6c0e34e0 100644 --- a/modules/bindgen/src/main/scala/render/function.scala +++ b/modules/bindgen/src/main/scala/render/function.scala @@ -4,7 +4,11 @@ package rendering import bindgen.* import scala.collection.mutable.ListBuffer -def renderFunction(f: GeneratedFunction.ScalaFunction, line: Appender)(using +def renderFunction( + f: GeneratedFunction.ScalaFunction, + line: Appender, + mode: RenderMode +)(using Config, AliasResolver, Context @@ -45,10 +49,24 @@ def renderFunction(f: GeneratedFunction.ScalaFunction, line: Appender)(using else Exported.No case ScalaFunctionBody.Extern => + val linkAnnotation = Option + .when(f.public && mode == RenderMode.Files)( + summon[Config].linkName + .map { l => + s"""@link("$l") """ + } + ) + .flatten + .getOrElse("") + + val externAnnotation = + Option.when(mode == RenderMode.Files)("@extern ").getOrElse("") + line( - s"${access}def ${f.name}$arglist: ${scalaType(f.returnType)} = extern" + s"$externAnnotation$linkAnnotation${access}def ${f.name}$arglist: ${scalaType(f.returnType)} = extern" ) if f.public then Exported.Yes(f.name.value) else Exported.No + case ScalaFunctionBody.Delegate( to, a @ Allocations(indices, returnAsWell) diff --git a/modules/sbt-plugin/src/sbt-test/basic/multi-file/src/main/scala/run.scala b/modules/sbt-plugin/src/sbt-test/basic/multi-file/src/main/scala/run.scala index d782e516..41b63e65 100644 --- a/modules/sbt-plugin/src/sbt-test/basic/multi-file/src/main/scala/run.scala +++ b/modules/sbt-plugin/src/sbt-test/basic/multi-file/src/main/scala/run.scala @@ -1,4 +1,4 @@ -import bindings.functions.* +import bindings.* @main def test = assert(hello(25, 0.5f) == true) diff --git a/modules/sbt-plugin/src/sbt-test/basic/multi-file/src/test/scala/MyTest.scala b/modules/sbt-plugin/src/sbt-test/basic/multi-file/src/test/scala/MyTest.scala index 51319ff9..34a6f824 100644 --- a/modules/sbt-plugin/src/sbt-test/basic/multi-file/src/test/scala/MyTest.scala +++ b/modules/sbt-plugin/src/sbt-test/basic/multi-file/src/test/scala/MyTest.scala @@ -1,7 +1,7 @@ import org.junit.Assert.* import org.junit.Test -import testbindings.functions.* +import testbindings.* class MyTest { @Test def superComplicatedTest(): Unit = { diff --git a/modules/tests/src/test/resources/scala-native/multi_file.c b/modules/tests/src/test/resources/scala-native/multi_file.c index 66c38bd0..0b50b8c3 100644 --- a/modules/tests/src/test/resources/scala-native/multi_file.c +++ b/modules/tests/src/test/resources/scala-native/multi_file.c @@ -3,3 +3,7 @@ unsigned run(int i, float h, HelloAlias test, union Test verify) { return 42; }; void naughty(Hello st){}; void nice(char st){}; + +int test_varargs(int i, ...) { + return i; +} diff --git a/modules/tests/src/test/resources/scala-native/multi_file.h b/modules/tests/src/test/resources/scala-native/multi_file.h index 29885a77..96a26072 100644 --- a/modules/tests/src/test/resources/scala-native/multi_file.h +++ b/modules/tests/src/test/resources/scala-native/multi_file.h @@ -17,3 +17,5 @@ unsigned run(int i, float h, HelloAlias test, union Test verify); void naughty(Hello st); void nice(char st); enum { Constant1, Constant2 }; + +int test_varargs(int i, ...); diff --git a/modules/tests/src/test/scalajvm/TestInterface.scala b/modules/tests/src/test/scalajvm/TestInterface.scala index 916a8b84..dc53438f 100644 --- a/modules/tests/src/test/scalajvm/TestInterface.scala +++ b/modules/tests/src/test/scalajvm/TestInterface.scala @@ -164,11 +164,6 @@ class TestInterface { probe.scalaFiles / "lib_check" / "aliases.scala", probe.scalaFiles / "lib_check" / "structs.scala", probe.scalaFiles / "lib_check" / "functions.scala", - probe.scalaFiles / "lib_check" / "all.unions.scala", - probe.scalaFiles / "lib_check" / "all.structs.scala", - probe.scalaFiles / "lib_check" / "all.functions.scala", - probe.scalaFiles / "lib_check" / "all.aliases.scala", - probe.scalaFiles / "lib_check" / "all.enumerations.scala", probe.scalaFiles / "lib_check" / "unions.scala" ), allFilesMultiScala.toSet