Skip to content

Commit

Permalink
Type every node field and mark on-error-only types explicitly
Browse files Browse the repository at this point in the history
* For Loader.java, do not deserialize the AST if there are errors, so then Java nodes only have non-error types for fields.
  • Loading branch information
eregon committed Sep 3, 2024
1 parent 163a265 commit 27ca419
Show file tree
Hide file tree
Showing 10 changed files with 329 additions and 57 deletions.
216 changes: 204 additions & 12 deletions config.yml

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions rakelib/typecheck.rake
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ namespace :typecheck do
--ignore=test/
--ignore=rakelib/
--ignore=Rakefile
--ignore=top-100-gems/
# Treat all files as "typed: true" by default
--typed=true
# Use the typed-override file to revert some files to "typed: false"
Expand Down
61 changes: 45 additions & 16 deletions rust/ruby-prism/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,27 @@ enum NodeFieldType {
Double,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OnErrorType {
#[serde(rename = "on error")]
kind: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum UnionKind {
OnSuccess(String),
OnError(OnErrorType),
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum NodeFieldKind {
Concrete(String),
Union(Vec<String>),
Union(Vec<UnionKind>),
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -126,6 +142,13 @@ fn struct_name(name: &str) -> String {
result
}

fn kind_to_type(kind: &String) -> String {
match kind.as_str() {
"non-void expression" | "pattern expression" | "Node" => String::new(),
_ => kind.to_string(),
}
}

/// Returns the name of the C type from the given node name.
fn type_name(name: &str) -> String {
let mut result = String::with_capacity(8 + name.len());
Expand Down Expand Up @@ -263,30 +286,34 @@ fn write_node(file: &mut File, flags: &[Flags], node: &Node) -> Result<(), Box<d
writeln!(file, " #[must_use]")?;

match field.field_type {
NodeFieldType::Node => {
if let Some(NodeFieldKind::Concrete(kind)) = &field.kind {
NodeFieldType::Node => match &field.kind {
Some(NodeFieldKind::Concrete(raw_kind)) if !kind_to_type(raw_kind).is_empty() => {
let kind = kind_to_type(raw_kind);
writeln!(file, " pub fn {}(&self) -> {}<'pr> {{", field.name, kind)?;
writeln!(file, " let node: *mut pm{}_t = unsafe {{ (*self.pointer).{} }};", struct_name(kind), field.name)?;
writeln!(file, " let node: *mut pm{}_t = unsafe {{ (*self.pointer).{} }};", struct_name(&kind), field.name)?;
writeln!(file, " {} {{ parser: self.parser, pointer: node, marker: PhantomData }}", kind)?;
writeln!(file, " }}")?;
} else {
},
_ => {
writeln!(file, " pub fn {}(&self) -> Node<'pr> {{", field.name)?;
writeln!(file, " let node: *mut pm_node_t = unsafe {{ (*self.pointer).{} }};", field.name)?;
writeln!(file, " Node::new(self.parser, node)")?;
writeln!(file, " }}")?;
}
},
},
NodeFieldType::OptionalNode => {
if let Some(NodeFieldKind::Concrete(kind)) = &field.kind {
NodeFieldType::OptionalNode => match &field.kind {
Some(NodeFieldKind::Concrete(raw_kind)) if !kind_to_type(raw_kind).is_empty() => {
let kind = kind_to_type(raw_kind);
writeln!(file, " pub fn {}(&self) -> Option<{}<'pr>> {{", field.name, kind)?;
writeln!(file, " let node: *mut pm{}_t = unsafe {{ (*self.pointer).{} }};", struct_name(kind), field.name)?;
writeln!(file, " let node: *mut pm{}_t = unsafe {{ (*self.pointer).{} }};", struct_name(&kind), field.name)?;
writeln!(file, " if node.is_null() {{")?;
writeln!(file, " None")?;
writeln!(file, " }} else {{")?;
writeln!(file, " Some({} {{ parser: self.parser, pointer: node, marker: PhantomData }})", kind)?;
writeln!(file, " }}")?;
writeln!(file, " }}")?;
} else {
},
_ => {
writeln!(file, " pub fn {}(&self) -> Option<Node<'pr>> {{", field.name)?;
writeln!(file, " let node: *mut pm_node_t = unsafe {{ (*self.pointer).{} }};", field.name)?;
writeln!(file, " if node.is_null() {{")?;
Expand All @@ -295,7 +322,7 @@ fn write_node(file: &mut File, flags: &[Flags], node: &Node) -> Result<(), Box<d
writeln!(file, " Some(Node::new(self.parser, node))")?;
writeln!(file, " }}")?;
writeln!(file, " }}")?;
}
},
},
NodeFieldType::NodeList => {
writeln!(file, " pub fn {}(&self) -> NodeList<'pr> {{", field.name)?;
Expand Down Expand Up @@ -473,16 +500,18 @@ fn write_visit(file: &mut File, config: &Config) -> Result<(), Box<dyn std::erro
for field in &node.fields {
match field.field_type {
NodeFieldType::Node => {
if let Some(NodeFieldKind::Concrete(kind)) = &field.kind {
writeln!(file, " visitor.visit{}(&node.{}());", struct_name(kind), field.name)?;
if let Some(NodeFieldKind::Concrete(raw_kind)) = &field.kind {
let kind = kind_to_type(raw_kind);
writeln!(file, " visitor.visit{}(&node.{}());", struct_name(&kind), field.name)?;
} else {
writeln!(file, " visitor.visit(&node.{}());", field.name)?;
}
},
NodeFieldType::OptionalNode => {
if let Some(NodeFieldKind::Concrete(kind)) = &field.kind {
if let Some(NodeFieldKind::Concrete(raw_kind)) = &field.kind {
let kind = kind_to_type(raw_kind);
writeln!(file, " if let Some(node) = node.{}() {{", field.name)?;
writeln!(file, " visitor.visit{}(&node);", struct_name(kind))?;
writeln!(file, " visitor.visit{}(&node);", struct_name(&kind))?;
writeln!(file, " }}")?;
} else {
writeln!(file, " if let Some(node) = node.{}() {{", field.name)?;
Expand Down Expand Up @@ -762,7 +791,7 @@ impl TryInto<i32> for Integer<'_> {{
let length = unsafe {{ (*self.pointer).length }};
if length == 0 {{
i32::try_from(unsafe {{ (*self.pointer).value }}).map_or(Err(()), |value|
i32::try_from(unsafe {{ (*self.pointer).value }}).map_or(Err(()), |value|
if negative {{
Ok(-value)
}} else {{
Expand Down
33 changes: 24 additions & 9 deletions src/prism.c
Original file line number Diff line number Diff line change
Expand Up @@ -2985,6 +2985,7 @@ pm_index_and_write_node_create(pm_parser_t *parser, pm_call_node_t *target, cons

pm_index_arguments_check(parser, target->arguments, target->block);

assert(!target->block || PM_NODE_TYPE_P(target->block, PM_BLOCK_ARGUMENT_NODE));
*node = (pm_index_and_write_node_t) {
{
.type = PM_INDEX_AND_WRITE_NODE,
Expand All @@ -3000,7 +3001,7 @@ pm_index_and_write_node_create(pm_parser_t *parser, pm_call_node_t *target, cons
.opening_loc = target->opening_loc,
.arguments = target->arguments,
.closing_loc = target->closing_loc,
.block = target->block,
.block = (pm_block_argument_node_t *) target->block,
.operator_loc = PM_LOCATION_TOKEN_VALUE(operator),
.value = value
};
Expand Down Expand Up @@ -3060,6 +3061,7 @@ pm_index_operator_write_node_create(pm_parser_t *parser, pm_call_node_t *target,

pm_index_arguments_check(parser, target->arguments, target->block);

assert(!target->block || PM_NODE_TYPE_P(target->block, PM_BLOCK_ARGUMENT_NODE));
*node = (pm_index_operator_write_node_t) {
{
.type = PM_INDEX_OPERATOR_WRITE_NODE,
Expand All @@ -3075,7 +3077,7 @@ pm_index_operator_write_node_create(pm_parser_t *parser, pm_call_node_t *target,
.opening_loc = target->opening_loc,
.arguments = target->arguments,
.closing_loc = target->closing_loc,
.block = target->block,
.block = (pm_block_argument_node_t *) target->block,
.binary_operator = pm_parser_constant_id_location(parser, operator->start, operator->end - 1),
.binary_operator_loc = PM_LOCATION_TOKEN_VALUE(operator),
.value = value
Expand Down Expand Up @@ -3137,6 +3139,7 @@ pm_index_or_write_node_create(pm_parser_t *parser, pm_call_node_t *target, const

pm_index_arguments_check(parser, target->arguments, target->block);

assert(!target->block || PM_NODE_TYPE_P(target->block, PM_BLOCK_ARGUMENT_NODE));
*node = (pm_index_or_write_node_t) {
{
.type = PM_INDEX_OR_WRITE_NODE,
Expand All @@ -3152,7 +3155,7 @@ pm_index_or_write_node_create(pm_parser_t *parser, pm_call_node_t *target, const
.opening_loc = target->opening_loc,
.arguments = target->arguments,
.closing_loc = target->closing_loc,
.block = target->block,
.block = (pm_block_argument_node_t *) target->block,
.operator_loc = PM_LOCATION_TOKEN_VALUE(operator),
.value = value
};
Expand Down Expand Up @@ -3205,6 +3208,7 @@ pm_index_target_node_create(pm_parser_t *parser, pm_call_node_t *target) {

pm_index_arguments_check(parser, target->arguments, target->block);

assert(!target->block || PM_NODE_TYPE_P(target->block, PM_BLOCK_ARGUMENT_NODE));
*node = (pm_index_target_node_t) {
{
.type = PM_INDEX_TARGET_NODE,
Expand All @@ -3216,7 +3220,7 @@ pm_index_target_node_create(pm_parser_t *parser, pm_call_node_t *target) {
.opening_loc = target->opening_loc,
.arguments = target->arguments,
.closing_loc = target->closing_loc,
.block = target->block
.block = (pm_block_argument_node_t *) target->block,
};

// Here we're going to free the target, since it is no longer necessary.
Expand All @@ -3231,7 +3235,7 @@ pm_index_target_node_create(pm_parser_t *parser, pm_call_node_t *target) {
* Allocate and initialize a new CapturePatternNode node.
*/
static pm_capture_pattern_node_t *
pm_capture_pattern_node_create(pm_parser_t *parser, pm_node_t *value, pm_node_t *target, const pm_token_t *operator) {
pm_capture_pattern_node_create(pm_parser_t *parser, pm_node_t *value, pm_local_variable_target_node_t *target, const pm_token_t *operator) {
pm_capture_pattern_node_t *node = PM_NODE_ALLOC(parser, pm_capture_pattern_node_t);

*node = (pm_capture_pattern_node_t) {
Expand All @@ -3240,7 +3244,7 @@ pm_capture_pattern_node_create(pm_parser_t *parser, pm_node_t *value, pm_node_t
.node_id = PM_NODE_IDENTIFY(parser),
.location = {
.start = value->location.start,
.end = target->location.end
.end = target->base.location.end
},
},
.value = value,
Expand Down Expand Up @@ -4032,14 +4036,25 @@ pm_find_pattern_node_create(pm_parser_t *parser, pm_node_list_t *nodes) {
pm_find_pattern_node_t *node = PM_NODE_ALLOC(parser, pm_find_pattern_node_t);

pm_node_t *left = nodes->nodes[0];
assert(PM_NODE_TYPE_P(left, PM_SPLAT_NODE));
pm_splat_node_t *left_splat_node = (pm_splat_node_t *) left;

pm_node_t *right;

if (nodes->size == 1) {
right = (pm_node_t *) pm_missing_node_create(parser, left->location.end, left->location.end);
} else {
right = nodes->nodes[nodes->size - 1];
assert(PM_NODE_TYPE_P(right, PM_SPLAT_NODE));
}

#if PRISM_SERIALIZE_ONLY_SEMANTICS_FIELDS
// FindPatternNode#right is typed as SplatNode in this case, so replace the potential MissingNode with a SplatNode.
// The resulting AST will anyway be ignored, but this file still needs to compile.
pm_splat_node_t *right_splat_node = PM_NODE_TYPE_P(right, PM_SPLAT_NODE) ? (pm_splat_node_t *) right : left_splat_node;
#else
pm_node_t *right_splat_node = right;
#endif
*node = (pm_find_pattern_node_t) {
{
.type = PM_FIND_PATTERN_NODE,
Expand All @@ -4050,8 +4065,8 @@ pm_find_pattern_node_create(pm_parser_t *parser, pm_node_list_t *nodes) {
},
},
.constant = NULL,
.left = left,
.right = right,
.left = left_splat_node,
.right = right_splat_node,
.requireds = { 0 },
.opening_loc = PM_OPTIONAL_LOCATION_NOT_PROVIDED_VALUE,
.closing_loc = PM_OPTIONAL_LOCATION_NOT_PROVIDED_VALUE
Expand Down Expand Up @@ -17401,7 +17416,7 @@ parse_pattern_primitives(pm_parser_t *parser, pm_constant_id_list_t *captures, p
}

parse_pattern_capture(parser, captures, constant_id, &PM_LOCATION_TOKEN_VALUE(&parser->previous));
pm_node_t *target = (pm_node_t *) pm_local_variable_target_node_create(
pm_local_variable_target_node_t *target = pm_local_variable_target_node_create(
parser,
&PM_LOCATION_TOKEN_VALUE(&parser->previous),
constant_id,
Expand Down
2 changes: 1 addition & 1 deletion templates/include/prism/ast.h.erb
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,6 @@ typedef enum pm_<%= flag.human %> {
* to specify that through the environment. It will never be true except for in
* those build systems.
*/
#define PRISM_SERIALIZE_ONLY_SEMANTICS_FIELDS <%= Prism::Template::SERIALIZE_ONLY_SEMANTICS_FIELDS %>
#define PRISM_SERIALIZE_ONLY_SEMANTICS_FIELDS <%= Prism::Template::SERIALIZE_ONLY_SEMANTICS_FIELDS ? 1 : 0 %>

#endif
21 changes: 13 additions & 8 deletions templates/java/org/prism/Loader.java.erb
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,21 @@ public class Loader {
int constantPoolLength = loadVarUInt();
this.constantPool = new ConstantPool(this, source.bytes, constantPoolBufferOffset, constantPoolLength);

Nodes.Node node = loadNode();
Nodes.Node node;
if (errors.length == 0) {
node = loadNode();

int left = constantPoolBufferOffset - buffer.position();
if (left != 0) {
throw new Error("Expected to consume all bytes while deserializing but there were " + left + " bytes left");
}
int left = constantPoolBufferOffset - buffer.position();
if (left != 0) {
throw new Error("Expected to consume all bytes while deserializing but there were " + left + " bytes left");
}

boolean[] newlineMarked = new boolean[1 + source.getLineCount()];
MarkNewlinesVisitor visitor = new MarkNewlinesVisitor(source, newlineMarked);
node.accept(visitor);
boolean[] newlineMarked = new boolean[1 + source.getLineCount()];
MarkNewlinesVisitor visitor = new MarkNewlinesVisitor(source, newlineMarked);
node.accept(visitor);
} else {
node = null;
}

return new ParseResult(node, magicComments, dataLocation, errors, warnings, source);
}
Expand Down
4 changes: 2 additions & 2 deletions templates/lib/prism/dsl.rb.erb
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ module Prism
def <%= node.human %>(<%= ["source: default_source", "node_id: 0", "location: default_location", "flags: 0", *node.fields.map { |field|
case field
when Prism::Template::NodeField
if !field.kind?
kind = field.specific_kind || field.union_kind&.first
if kind.nil?
"#{field.name}: default_node(source, location)"
else
kind = field.specific_kind || field.union_kind.first
"#{field.name}: #{kind.gsub(/(?<=.)[A-Z]/, "_\\0").downcase}(source: source)"
end
when Prism::Template::ConstantField
Expand Down
2 changes: 1 addition & 1 deletion templates/lib/prism/node.rb.erb
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ module Prism
@flags = flags
<%- node.fields.each do |field| -%>
<%- if Prism::Template::CHECK_FIELD_KIND && field.respond_to?(:check_field_kind) -%>
raise <%= field.name %>.inspect unless <%= field.check_field_kind %>
raise "<%= node.name %>#<%= field.name %> was of unexpected type:\n#{<%= field.name %>.inspect}" unless <%= field.check_field_kind %>
<%- end -%>
@<%= field.name %> = <%= field.name %>
<%- end -%>
Expand Down
4 changes: 2 additions & 2 deletions templates/rbi/prism/dsl.rbi.erb
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ module Prism::DSL
].concat(node.fields.map { |field|
case field
when Prism::Template::NodeField
if !field.kind?
kind = field.specific_kind || field.union_kind&.first
if kind.nil?
[field.name, "default_node(source, location)", field.rbi_class]
else
kind = field.specific_kind || field.union_kind.first
[field.name, %Q{#{kind.gsub(/(?<=.)[A-Z]/, "_\\0").downcase}(source: source)}, field.rbi_class]
end
when Prism::Template::OptionalNodeField
Expand Down
Loading

0 comments on commit 27ca419

Please sign in to comment.