From 8328c541dfeffc9585657d86bde02eafdbfb4901 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 2 Jun 2024 20:53:59 -0400 Subject: [PATCH] compiler: Support record spread Thsi feels even messier than in the interpreter but I am not really sure how else to do this. Maybe with some better static analysis that computes record layouts at compile-time. --- compiler.py | 12 +++++++++++- compiler_tests.py | 9 +++++++-- runtime.c | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/compiler.py b/compiler.py index b2116e5d..c1eec992 100755 --- a/compiler.py +++ b/compiler.py @@ -220,9 +220,19 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En if isinstance(pattern, Record): self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}") updates = {} + seen_key_indices: list[int] = [] for key, pattern_value in pattern.data.items(): - assert not isinstance(pattern_value, Spread), "record spread not yet supported" + if isinstance(pattern_value, Spread): + use_spread = True + if pattern_value.name: + num_seen_keys = len(seen_key_indices) + self._emit( + f"size_t seen_keys[{num_seen_keys}] = {{ {', '.join(map(str, seen_key_indices))} }};" + ) + updates[pattern_value.name] = self._mktemp(f"record_rest({arg}, seen_keys, {num_seen_keys})") + break key_idx = self.record_key(key) + seen_key_indices.append(key_idx) record_value = self._mktemp(f"record_get({arg}, {key_idx})") self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}") updates.update(self.try_match(env, record_value, pattern_value, fallthrough)) diff --git a/compiler_tests.py b/compiler_tests.py index d4f7823f..7900e9fc 100644 --- a/compiler_tests.py +++ b/compiler_tests.py @@ -52,12 +52,17 @@ def test_match_list(self) -> None: def test_match_list_spread(self) -> None: self.assertEqual(self._run("f [4, 5] . f = | [_, ...xs] -> xs"), "[5]\n") + def test_match_list_spread_empty(self) -> None: + self.assertEqual(self._run("f [4] . f = | [_, ...xs] -> xs"), "[]\n") + def test_match_record(self) -> None: self.assertEqual(self._run("f {a = 4, b = 5} . f = | {a = 1, b = 2} -> 3 | {a = 4, b = 5} -> 6"), "6\n") - @unittest.skip("TODO") def test_match_record_spread(self) -> None: - self.assertEqual(self._run("f {a=1, b=2, c=3} . f = | {a=1, ...rest} -> rest"), "[5]\n") + self.assertEqual(self._run("f {a=1, b=2, c=3} . f = | {a=1, ...rest} -> rest"), "{b = 2, c = 3}\n") + + def test_match_record_spread_empty(self) -> None: + self.assertEqual(self._run("f {a=1} . f = | {a=1, ...rest} -> rest"), "{}\n") def test_match_hole(self) -> None: self.assertEqual(self._run("f () . f = | 1 -> 3 | () -> 4"), "4\n") diff --git a/runtime.c b/runtime.c index d09a25fe..b0748bad 100644 --- a/runtime.c +++ b/runtime.c @@ -471,6 +471,10 @@ struct object* record_get(struct object* record, size_t key) { return NULL; } +size_t record_num_fields(struct object* record) { + return as_record(record)->size; +} + bool is_string(struct object* obj) { if (is_small_string(obj)) { return true; @@ -603,6 +607,34 @@ struct object* list_cons(struct object* item, struct object* list) { return result; } +bool array_contains(size_t* haystack, size_t size, size_t needle) { + for (size_t i = 0; i < size; i++) { + if (haystack[i] == needle) { + return true; + } + } + return false; +} + +struct object* record_rest(struct object* record, size_t* exclude, + size_t num_excluded) { + // NB: This is used in a match expression so it is assumed that all of the + // key indices in the exclude array are present in the record and that there + // are no duplicates in either the record or the exclude array. + HANDLES(); + GC_PROTECT(record); + size_t num_keys = record_num_fields(record); + size_t num_result_keys = num_keys - num_excluded; + struct object* result = mkrecord(heap, num_result_keys); + for (size_t src = 0, dst = 0; dst < num_result_keys; src++) { + struct record_field field = as_record(record)->fields[src]; + if (!array_contains(exclude, num_excluded, field.key)) { + record_set(result, dst++, field); + } + } + return result; +} + struct object* heap_string_concat(struct object* a, struct object* b) { uword a_size = string_length(a); uword b_size = string_length(b);