Skip to content

Commit

Permalink
Don't render all package and exports in multi-file mode (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
keynmol authored Jan 10, 2024
1 parent e56d391 commit 9554921
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 59 deletions.
19 changes: 9 additions & 10 deletions modules/bindgen/src/main/scala/render/TypeImports.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@ case class TypeImports(
structs: Boolean,
unions: Boolean
):
def render(out: LineBuilder, multiFile: Boolean)(using Config, Context) =
var any = false
val imp = (s: String) =>
any = true
to(out)(s"import _root_.$packageName.$s.*")
def render(out: LineBuilder)(using Config, Context) =

val addImport = (s: String) => to(out)(s"import _root_.$packageName.$s.*")

if enums then
imp("enumerations")
if !multiFile then imp("predef")
if aliases then imp("aliases")
if structs then imp("structs")
if unions then imp("unions")
addImport("enumerations")
addImport("predef")
if aliases then addImport("aliases")
if structs then addImport("structs")
if unions then addImport("unions")

end render
end TypeImports
74 changes: 34 additions & 40 deletions modules/bindgen/src/main/scala/render/binding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def renderBinding(
def create(name: String)(subPackage: String = name) =
val lb = LineBuilder()
lb.appendLine(s"package $packageName")
if multiFileMode then lb.appendLine(s"package $subPackage")
lb.emptyLine
lb.append("""
|import _root_.scala.scalanative.unsafe.*
Expand Down Expand Up @@ -233,14 +232,8 @@ def renderBinding(
}
end if

if multiFileMode then
val byType: Map[String, List[(String, String)]] =
exports.result().groupBy(_._1)

byType.toList.sortBy(_._1).foreach { (exportType, results) =>
renderExports(stream(s"all.$exportType", "all"), results, renderMode)
}
else renderExports(simpleStream(s"all"), exports.result(), renderMode)
if !multiFileMode then
renderExports(simpleStream(s"all"), exports.result(), renderMode)

if multiFileMode then RenderedOutput.Multi(multi.toMap)
else if summon[Context].lang == Lang.C then RenderedOutput.Single(cOutput)
Expand All @@ -261,12 +254,10 @@ private def renderAliases(
typeImports: TypeImports
)(using Config, AliasResolver, Context) =
val exported = List.newBuilder[Exported]
if mode == RenderMode.Files then typeImports.render(out, multiFile = true)
if mode == RenderMode.Objects then out.appendLine("object aliases:")

nestIf(mode == RenderMode.Objects) {
if mode == RenderMode.Objects then
typeImports.render(out, multiFile = false)
if mode == RenderMode.Objects then typeImports.render(out)
exported ++= renderAll(aliases, out, alias)
}
exported.result()
Expand Down Expand Up @@ -301,11 +292,11 @@ private def renderUnions(
typeImports: TypeImports
)(using Config, AliasResolver, Context) =
val exported = List.newBuilder[Exported]
if mode == RenderMode.Files then typeImports.render(out, multiFile = true)

if mode == RenderMode.Objects then out.appendLine("object unions:")

nestIf(mode == RenderMode.Objects) {
if mode == RenderMode.Objects then
typeImports.render(out, multiFile = false)
if mode == RenderMode.Objects then typeImports.render(out)
exported ++= renderAll(unions, out, union)
}
exported.result()
Expand All @@ -318,12 +309,10 @@ private def renderStructs(
typeImports: TypeImports
)(using Config, AliasResolver, Context) =
val exported = List.newBuilder[Exported]
if mode == RenderMode.Files then typeImports.render(out, multiFile = true)
if mode == RenderMode.Objects then out.appendLine("object structs:")

nestIf(mode == RenderMode.Objects) {
if mode == RenderMode.Objects then
typeImports.render(out, multiFile = false)
if mode == RenderMode.Objects then typeImports.render(out)
exported ++= renderAll(structs, out, struct)
}
exported.result()
Expand Down Expand Up @@ -452,29 +441,33 @@ private def renderScalaFunctions(
f
}

val safePackageName = packageName.split('.').last

val hasExternFunctions = scalaExternFunctions.nonEmpty
val hasRegularFunctions = scalaRegularFunctions.nonEmpty

if functions.nonEmpty then
if exportMode == ExportMode.No then
if renderMode == RenderMode.Files then
typeImports.render(out, multiFile = true)

if hasExternFunctions then
summon[Config].linkName.foreach { l =>
out.append(s"""@link("$l")""")
}
val safePackageName = packageName.split('.').last
out.appendLine(
s"\n@extern\nprivate[$safePackageName] object extern_functions:"
)
nest {
if renderMode == RenderMode.Objects then
summon[Config].linkName.foreach { l =>
out.append(s"""@link("$l")""")
}
end if

nestIf(renderMode == RenderMode.Objects) {
if renderMode == RenderMode.Objects then
typeImports.render(out, multiFile = false)
out.appendLine(
s"\n@extern\nprivate[$safePackageName] object extern_functions:"
)
typeImports.render(out)
else out.appendLine("\n")

exported ++= renderAll(
scalaExternFunctions.toList.sortBy(_.name),
out,
renderFunction
renderFunction(_, _, renderMode)
)
}
end if
Expand All @@ -483,18 +476,17 @@ private def renderScalaFunctions(
if renderMode == RenderMode.Objects then
out.appendLine(s"\nobject functions:")
nestIf(renderMode == RenderMode.Objects) {
if renderMode == RenderMode.Objects then
typeImports.render(out, multiFile = false)
if renderMode == RenderMode.Objects then typeImports.render(out)

if hasExternFunctions then
if hasExternFunctions && renderMode == RenderMode.Objects then
to(out)("import extern_functions.*")
to(out)("export extern_functions.*")
out.emptyLine

exported ++= renderAll(
scalaRegularFunctions.toList.sortBy(_.name),
out,
renderFunction
renderFunction(_, _, renderMode)
)
}
end if
Expand All @@ -519,22 +511,24 @@ private def renderScalaFunctions(

out.appendLine("trait ExportedFunctions:")
nest {
if renderMode == RenderMode.Objects then
typeImports.render(out, multiFile = false)
renderAll(modified(ExportLocation.Trait), out, renderFunction)
if renderMode == RenderMode.Objects then typeImports.render(out)
renderAll(
modified(ExportLocation.Trait),
out,
renderFunction(_, _, renderMode)
)
}

if renderMode == RenderMode.Objects then
out.appendLine(s"\nobject functions extends ExportedFunctions:")
nestIf(renderMode == RenderMode.Objects) {
if renderMode == RenderMode.Objects then
typeImports.render(out, multiFile = false)
if renderMode == RenderMode.Objects then typeImports.render(out)
renderAll(
modified(
ExportLocation.Body(summon[Context].packageName.map(_ + ".impl"))
),
out,
renderFunction
renderFunction(_, _, renderMode)
)
}

Expand Down
22 changes: 20 additions & 2 deletions modules/bindgen/src/main/scala/render/function.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ package rendering
import bindgen.*
import scala.collection.mutable.ListBuffer

def renderFunction(f: GeneratedFunction.ScalaFunction, line: Appender)(using
def renderFunction(
f: GeneratedFunction.ScalaFunction,
line: Appender,
mode: RenderMode
)(using
Config,
AliasResolver,
Context
Expand Down Expand Up @@ -45,10 +49,24 @@ def renderFunction(f: GeneratedFunction.ScalaFunction, line: Appender)(using
else Exported.No

case ScalaFunctionBody.Extern =>
val linkAnnotation = Option
.when(f.public && mode == RenderMode.Files)(
summon[Config].linkName
.map { l =>
s"""@link("$l") """
}
)
.flatten
.getOrElse("")

val externAnnotation =
Option.when(mode == RenderMode.Files)("@extern ").getOrElse("")

line(
s"${access}def ${f.name}$arglist: ${scalaType(f.returnType)} = extern"
s"$externAnnotation$linkAnnotation${access}def ${f.name}$arglist: ${scalaType(f.returnType)} = extern"
)
if f.public then Exported.Yes(f.name.value) else Exported.No

case ScalaFunctionBody.Delegate(
to,
a @ Allocations(indices, returnAsWell)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import bindings.functions.*
import bindings.*

@main def test =
assert(hello(25, 0.5f) == true)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import org.junit.Assert.*
import org.junit.Test

import testbindings.functions.*
import testbindings.*

class MyTest {
@Test def superComplicatedTest(): Unit = {
Expand Down
4 changes: 4 additions & 0 deletions modules/tests/src/test/resources/scala-native/multi_file.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
unsigned run(int i, float h, HelloAlias test, union Test verify) { return 42; };
void naughty(Hello st){};
void nice(char st){};

int test_varargs(int i, ...) {
return i;
}
2 changes: 2 additions & 0 deletions modules/tests/src/test/resources/scala-native/multi_file.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ unsigned run(int i, float h, HelloAlias test, union Test verify);
void naughty(Hello st);
void nice(char st);
enum { Constant1, Constant2 };

int test_varargs(int i, ...);
5 changes: 0 additions & 5 deletions modules/tests/src/test/scalajvm/TestInterface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,6 @@ class TestInterface {
probe.scalaFiles / "lib_check" / "aliases.scala",
probe.scalaFiles / "lib_check" / "structs.scala",
probe.scalaFiles / "lib_check" / "functions.scala",
probe.scalaFiles / "lib_check" / "all.unions.scala",
probe.scalaFiles / "lib_check" / "all.structs.scala",
probe.scalaFiles / "lib_check" / "all.functions.scala",
probe.scalaFiles / "lib_check" / "all.aliases.scala",
probe.scalaFiles / "lib_check" / "all.enumerations.scala",
probe.scalaFiles / "lib_check" / "unions.scala"
),
allFilesMultiScala.toSet
Expand Down

0 comments on commit 9554921

Please sign in to comment.