diff --git a/include/ast.h b/include/ast.h index 0930d9c..47c4e45 100644 --- a/include/ast.h +++ b/include/ast.h @@ -24,6 +24,7 @@ struct AbstractSyntaxTree { ForLoop, List, ArrayType, + ArrayAccess, } nodeType{Type::Invalid}; virtual ~AbstractSyntaxTree() = default; virtual bool operator==(const AbstractSyntaxTree &other) const = 0; @@ -150,6 +151,13 @@ struct ArrayType : public AbstractSyntaxTree { bool operator==(const AbstractSyntaxTree &other) const override; }; +struct ArrayAccess : public AbstractSyntaxTree { + Node identifier; + AbstractSyntaxTree *index; + ArrayAccess(Node identifier, AbstractSyntaxTree *index); + bool operator==(const AbstractSyntaxTree &other) const override; +}; + std::vector parse(const char *input); void compile(Program &program, const char *input); Program compile(const char *input); diff --git a/src/ast.cpp b/src/ast.cpp index 55634bd..c131a0b 100644 --- a/src/ast.cpp +++ b/src/ast.cpp @@ -505,6 +505,18 @@ bool ArrayType::operator==(const AbstractSyntaxTree &other) const { return *type == *otherArrayType.type; } +ArrayAccess::ArrayAccess(Node identifier, AbstractSyntaxTree *index) + : identifier(identifier), index(index) { + nodeType = AbstractSyntaxTree::Type::ArrayAccess; + typeStr = "ArrayAccess"; +} +bool ArrayAccess::operator==(const AbstractSyntaxTree &other) const { + if (other.nodeType != nodeType) return false; + auto &otherArrayAccess = dynamic_cast(other); + return identifier == otherArrayAccess.identifier && + *index == *otherArrayAccess.index; +} + void compile(Program &program, const char *input) { auto ast = parse(input); if (!program.segments.empty() && diff --git a/src/parser.y b/src/parser.y index f6f9493..924edde 100644 --- a/src/parser.y +++ b/src/parser.y @@ -32,7 +32,7 @@ %token Number String Identifier %type Expression Expressions VarType ScopedBody TypeCast FunctionCall IfStatement WhileStatement ForLoop %type ArgumentDeclaration ArgumentDeclarationsList Arguments FunctionDeclaration UnaryExpression -%type List Elements ArrayType +%type List Elements ArrayType ArrayAccess %left Plus Minus %left Multiply Divide @@ -47,6 +47,11 @@ Statement: } ; +ArrayAccess: + Identifier LBracket Expression RBracket { + $$ = new ArrayAccess(Node({Identifier, $1}), static_cast($3)); + } + List: LBracket Elements RBracket { $$ = new List(*static_cast*>($2)); @@ -201,6 +206,7 @@ Expression: | ForLoop { $$ = $1; } | UnaryExpression { $$ = $1; } | List { $$ = $1; } + | ArrayAccess { $$ = $1; } | Return Expression { $$ = new ReturnStatement(static_cast($2)); } diff --git a/tests/parser_tests.cpp b/tests/parser_tests.cpp index 3ab2318..402a011 100644 --- a/tests/parser_tests.cpp +++ b/tests/parser_tests.cpp @@ -467,3 +467,17 @@ TEST(ParserTests, ArrayDeclaration) { for (int i = 0; i < expectedResult.size(); i++) ASSERT_EQ(*expectedResult[i], *actualResult[i]); } + +TEST(ParserTests, ArrayAccess) { + const char *input = "x[0];"; + auto expectedResult = std::vector({ + new ArrayAccess( + Node({Identifier, "x"}), + new Node({Number, "0"})), + }); + + auto actualResult = parse(input); + ASSERT_EQ(expectedResult.size(), actualResult.size()); + for (int i = 0; i < expectedResult.size(); i++) + ASSERT_EQ(*expectedResult[i], *actualResult[i]); +}