Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fromPtr/toPtr helper methods for C function pointer aliases #240

Merged
merged 1 commit into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions modules/bindgen/src/main/scala/render/alias.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def alias(model: Def.Alias, line: Appender)(using
case _: Reference | _: Function | Void => false
case _ => true

val isFunctionPointer = underlyingType match
case Pointer(f: Function) => true
case _ => false

val modifier = if isOpaque then "opaque " else ""
renderComment(line, model.meta)
line(s"${modifier}type ${model.name} = ${scalaType(underlyingType)}")
Expand All @@ -31,14 +35,25 @@ def alias(model: Def.Alias, line: Appender)(using
line(s"val _tag: Tag[${model.name}] = summon[Tag[${name.full}]]")
case _ =>
line(s"given _tag: Tag[${model.name}] = ${scalaTag(underlyingType)}")

if isFunctionPointer then
line(
s"inline def fromPtr(ptr: Ptr[Byte]): ${model.name} = CFuncPtr.fromPtr(ptr)"
)
end if

if enableConstructor then
line(
s"inline def apply(inline o: ${scalaType(underlyingType)}): ${model.name} = o"
)
line(s"extension (v: ${model.name})")
nest {
line(s"inline def value: ${scalaType(underlyingType)} = v")
if isFunctionPointer then
line(s"inline def toPtr: Ptr[Byte] = CFuncPtr.toPtr(v)")
end if
}
end if
}

Exported.Yes(model.name)
Expand Down
12 changes: 12 additions & 0 deletions modules/tests/src/test/resources/scala-native/aliases.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include "aliases.h"

int higher_order_function(int n, TestFunctionPointer handler,
struct TestStruct *container) {
int sum = 0;

for (int i = 0; i < n; i++) {
sum += handler(container);
}

return sum;
}
10 changes: 9 additions & 1 deletion modules/tests/src/test/resources/scala-native/aliases.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ typedef struct {
AliasesRef *field2;
} TestAliases;


typedef int hello_alias;
typedef hello_alias alias_of_an_alias;

struct TestStruct {
int ne;
};

typedef int (*TestFunctionPointer)(struct TestStruct *container);

int higher_order_function(int n, TestFunctionPointer handler,
struct TestStruct *container);
29 changes: 27 additions & 2 deletions modules/tests/src/test/scalanative/TestAliases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import scala.scalanative.unsigned.*

class TestAliases:
import lib_test_aliases.types.*
import lib_test_aliases.functions.*

@Test def test_basics() =
zone {
zone:
assertEquals(42069, hello_alias(42069).value)

val x: Int = hello_alias(42069).value
Expand All @@ -25,5 +26,29 @@ class TestAliases:

assertEquals(t.value, hello_alias(25).value)
assertEquals(t.value, 25)
}

@Test def test_function_pointers() =
zone:
val test = TestStruct(10)

val square = TestFunctionPointer: (struct: Ptr[TestStruct]) =>
(!struct)._ne * (!struct)._ne

// sanity check first
assertEquals(500, higher_order_function(5, square, test))

val ptr: Ptr[Byte] = CFuncPtr.toPtr(square.value)
val squareReconstructed =
TestFunctionPointer.fromPtr(ptr)

// now check reconstructed function is still valid
assertEquals(500, higher_order_function(5, squareReconstructed, test))

// test toPtr/fromPtr pair
val roundTrip =
TestFunctionPointer.fromPtr(square.toPtr)

// now check reconstructed function is still valid
assertEquals(500, higher_order_function(5, roundTrip, test))

end TestAliases
Loading