diff --git a/gutils/BUILD.bazel b/gutils/BUILD.bazel index abf01c7..04532bf 100644 --- a/gutils/BUILD.bazel +++ b/gutils/BUILD.bazel @@ -36,7 +36,12 @@ cc_library( "proto.h", ], visibility = ["//visibility:public"], - deps = ["@com_google_protobuf//:protobuf"], + deps = [ + ":status", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//src/google/protobuf/io", + ], ) cc_library( diff --git a/gutils/proto.cc b/gutils/proto.cc index 9b3a114..8e3644d 100644 --- a/gutils/proto.cc +++ b/gutils/proto.cc @@ -1,7 +1,16 @@ #include "gutils/proto.h" +#include + +#include +#include + +#include "absl/status/status.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" #include "google/protobuf/util/message_differencer.h" +#include "gutils/status_builder.h" namespace gutils { @@ -22,4 +31,28 @@ bool ProtoEqual(const google::protobuf::Message &message1, return ProtoEqual(message1, message2, differ); } +absl::Status ReadProtoFromFile(std::string_view filename, + google::protobuf::Message *message) { + // Verifies that the version of the library that we linked against is + // compatible with the version of the headers we compiled against. + /* copybara:insert(not needed nor possible in google3, as it is a mono repo) + GOOGLE_PROTOBUF_VERIFY_VERSION; + */ + + int fd = open(std::string(filename).c_str(), O_RDONLY); + if (fd < 0) { + return InvalidArgumentErrorBuilder() + << "Error opening the file " << filename; + } + + google::protobuf::io::FileInputStream file_stream(fd); + file_stream.SetCloseOnDelete(true); + + if (!google::protobuf::TextFormat::Parse(&file_stream, message)) { + return InvalidArgumentErrorBuilder() << "Failed to parse file " << filename; + } + + return absl::OkStatus(); +} + } // namespace gutils diff --git a/gutils/proto.h b/gutils/proto.h index 9d58418..bbda3eb 100644 --- a/gutils/proto.h +++ b/gutils/proto.h @@ -1,6 +1,9 @@ #ifndef THIRD_PARTY_P4LANG_P4_CONSTRAINTS_GUTILS_PROTO_H_ #define THIRD_PARTY_P4LANG_P4_CONSTRAINTS_GUTILS_PROTO_H_ +#include + +#include "absl/status/status.h" #include "google/protobuf/message.h" #include "google/protobuf/util/message_differencer.h" @@ -14,6 +17,10 @@ bool ProtoEqual(const google::protobuf::Message &message1, bool ProtoEqual(const google::protobuf::Message &message1, const google::protobuf::Message &message2); +// Read the contents of the file into a protobuf. +absl::Status ReadProtoFromFile(std::string_view filename, + google::protobuf::Message *message); + } // namespace gutils #endif // THIRD_PARTY_P4LANG_P4_CONSTRAINTS_GUTILS_PROTO_H_ diff --git a/p4_constraints/ast.proto b/p4_constraints/ast.proto index 7d034ab..3020fca 100644 --- a/p4_constraints/ast.proto +++ b/p4_constraints/ast.proto @@ -184,5 +184,10 @@ message SourceLocation { // If present, `line` and `column` are relative to an @entry_restriction // annotation attached to a table of the given name. string table_name = 4; + + // P4 action name. Prefer `file_path` whenever possible. + // If present, `line` and `column` are relative to an @action_restriction + // annotation attached to an action of the given name. + string action_name = 5; } } diff --git a/p4_constraints/backend/BUILD.bazel b/p4_constraints/backend/BUILD.bazel index 302cac7..43491da 100644 --- a/p4_constraints/backend/BUILD.bazel +++ b/p4_constraints/backend/BUILD.bazel @@ -1,4 +1,5 @@ load("//e2e_tests:p4check.bzl", "cmd_diff_test") +load("@com_github_p4lang_p4c//:bazel/p4_library.bzl", "p4_library") package( default_visibility = ["//visibility:public"], @@ -83,6 +84,7 @@ cc_library( "//p4_constraints/frontend:constraint_kind", "//p4_constraints/frontend:parser", "@com_github_p4lang_p4runtime//:p4info_cc_proto", + "@com_github_p4lang_p4runtime//:p4types_cc_proto", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -183,6 +185,20 @@ cc_library( ], ) +p4_library( + name = "p4_programs/action_restrictions_valid", + src = "p4_programs/action_restrictions_valid.p4", + p4info_out = "p4_programs/action_restrictions_valid.p4info.pb.txt", + visibility = ["//visibility:private"], +) + +p4_library( + name = "p4_programs/action_restrictions_invalid", + src = "p4_programs/action_restrictions_invalid.p4", + p4info_out = "p4_programs/action_restrictions_invalid.p4info.pb.txt", + visibility = ["//visibility:private"], +) + cc_test( name = "symbolic_interpreter_test", srcs = ["symbolic_interpreter_test.cc"], @@ -208,3 +224,22 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "constraint_info_test", + srcs = ["constraint_info_test.cc"], + data = [ + "p4_programs/action_restrictions_invalid.p4info.pb.txt", + "p4_programs/action_restrictions_valid.p4info.pb.txt", + ], + deps = [ + ":constraint_info", + "//gutils:proto", + "//gutils:status_matchers", + "@com_github_p4lang_p4runtime//:p4info_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/p4_constraints/backend/constraint_info.cc b/p4_constraints/backend/constraint_info.cc index 20bbed2..63b9e95 100644 --- a/p4_constraints/backend/constraint_info.cc +++ b/p4_constraints/backend/constraint_info.cc @@ -21,6 +21,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -30,6 +31,7 @@ #include "absl/types/optional.h" #include "gutils/status_macros.h" #include "p4/config/v1/p4info.pb.h" +#include "p4/config/v1/p4types.pb.h" #include "p4_constraints/ast.pb.h" #include "p4_constraints/backend/type_checker.h" #include "p4_constraints/constraint_source.h" @@ -41,12 +43,43 @@ namespace p4_constraints { namespace { +using p4::config::v1::Action; +using p4::config::v1::Action_Param; using p4::config::v1::MatchField; +using p4::config::v1::Preamble; using p4::config::v1::Table; +using p4_constraints::ConstraintKind; + +RE2 GetConstraintAnnotation(ConstraintKind constraint_kind) { + switch (constraint_kind) { + case ConstraintKind::kTableConstraint: + return {R"RE(@entry_restriction)RE"}; + case ConstraintKind::kActionConstraint: + return {R"RE(@action_restriction)RE"}; + } + LOG(DFATAL) + << "ConstraintKind is neither TableConstraint nor ActionConstraint"; + return RE2(""); +} + +void SetConstraintLocationName(ConstraintKind constraint_kind, + absl::string_view name, + ast::SourceLocation& source_location) { + switch (constraint_kind) { + case ConstraintKind::kTableConstraint: + source_location.set_table_name(name); + return; + case ConstraintKind::kActionConstraint: + source_location.set_action_name(name); + return; + } + LOG(DFATAL) + << "ConstraintKind is neither TableConstraint nor ActionConstraint"; +} absl::StatusOr> ExtractConstraint( - const Table& table) { - // We expect .p4 files to have the following format: + ConstraintKind constraint_kind, const Preamble& preamble) { + // We expect .p4 files to have the following format for tables: // ```p4 // @file(__FILE__) // optional // @line(__LINE__) // optional @@ -55,16 +88,26 @@ absl::StatusOr> ExtractConstraint( // ") // table foo { ... } // ``` - // The @file/@line annotations are optional and intended for debugging/testing - // only; they allows us to give error messages that quote the source code. + // We expect .p4 files to have the following format for actions: + // ```p4 + // @file(__FILE__) // optional + // @line(__LINE__) // optional + // @action_restriction(" + // + // ") + // action bar { ... } + // ``` + // The @file/@line annotations are optional and intended for + // debugging/testing only; they allows us to give error messages that quote + // the source code. const RE2 file_annotation = {R"RE(@file[(]"([^"]*)"[)])RE"}; const RE2 line_annotation = {R"RE(@line[(](\d+)[)])RE"}; - const RE2 constraint_annotation = {R"RE(@entry_restriction)RE"}; + const RE2 constraint_annotation = GetConstraintAnnotation(constraint_kind); absl::string_view constraint_string = ""; ast::SourceLocation constraint_location; int line = 0; - for (absl::string_view annotation : table.preamble().annotations()) { + for (absl::string_view annotation : preamble.annotations()) { if (RE2::Consume(&annotation, file_annotation, constraint_location.mutable_file_path())) continue; @@ -79,15 +122,18 @@ absl::StatusOr> ExtractConstraint( constraint_location.set_line(line); if (constraint_location.file_path().empty()) { - constraint_location.set_table_name(table.preamble().name()); + SetConstraintLocationName(constraint_kind, preamble.name(), + constraint_location); } if (!absl::ConsumePrefix(&constraint_string, "(\"") || !absl::ConsumeSuffix(&constraint_string, "\")")) { + bool is_table = constraint_kind == ConstraintKind::kTableConstraint; return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) - << "In table " << table.preamble().name() << ":\n" - << "Syntax error: @entry_restriction must be enclosed in " - "'(\"' and '\")'"; + << "In " << (is_table ? "table " : "action ") << preamble.name() + << ":\n" + << "Syntax error: @" << (is_table ? "entry" : "action") + << "_restriction must be enclosed in '(\"' and '\")'"; } return ConstraintSource{ .constraint_string = std::string(constraint_string), @@ -129,6 +175,19 @@ absl::StatusOr ParseKeyType(const MatchField& key) { } } +absl::StatusOr ParseParamType(const Action_Param& param) { + ast::Type type; + // P4NamedType is unset if the param does not use a user-defined type. + // Currently we do not support user-defined types. + if (!param.type_name().name().empty()) { + type.mutable_unsupported()->set_name(param.type_name().name()); + return type; + } + + type.mutable_fixed_unsigned()->set_bitwidth(param.bitwidth()); + return type; +} + absl::StatusOr ParseTableInfo(const Table& table) { absl::flat_hash_map keys_by_id; absl::flat_hash_map keys_by_name; @@ -148,8 +207,9 @@ absl::StatusOr ParseTableInfo(const Table& table) { } } - ASSIGN_OR_RETURN(absl::optional constraint_source, - ExtractConstraint(table)); + ASSIGN_OR_RETURN( + absl::optional constraint_source, + ExtractConstraint(ConstraintKind::kTableConstraint, table.preamble())); absl::optional constraint = absl::nullopt; if (constraint_source.has_value()) { @@ -175,6 +235,52 @@ absl::StatusOr ParseTableInfo(const Table& table) { return table_info; } +absl::StatusOr ParseActionInfo(const Action& action) { + absl::flat_hash_map params_by_id; + absl::flat_hash_map params_by_name; + + for (const Action_Param& param : action.params()) { + ASSIGN_OR_RETURN(const ast::Type type, ParseParamType(param)); + const ParamInfo param_info{ + .id = param.id(), + .name = param.name(), + .type = type, + }; + if (!params_by_id.insert({param_info.id, param_info}).second) { + return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) + << "action " << action.preamble().name() + << " has duplicate param: " << param.DebugString(); + } + if (!params_by_name.insert({param_info.name, param_info}).second) { + return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) + << "action " << action.preamble().name() + << " has duplicate param: " << param.DebugString(); + } + } + ASSIGN_OR_RETURN( + absl::optional constraint_source, + ExtractConstraint(ConstraintKind::kActionConstraint, action.preamble())); + absl::optional constraint; + if (constraint_source.has_value()) { + ASSIGN_OR_RETURN( + constraint, + ParseConstraint(ConstraintKind::kActionConstraint, *constraint_source)); + } + ActionInfo action_info{ + .id = action.preamble().id(), + .name = action.preamble().name(), + .constraint = constraint, + .constraint_source = constraint_source.value_or(ConstraintSource()), + .params_by_id = params_by_id, + .params_by_name = params_by_name, + }; + // Type check constraint. + if (action_info.constraint.has_value()) { + RETURN_IF_ERROR(InferAndCheckTypes(&*action_info.constraint, action_info)); + } + return action_info; +} + } // namespace std::optional GetAttributeInfo( @@ -201,7 +307,6 @@ const TableInfo* GetTableInfoOrNull(const ConstraintInfo& constraint_info, absl::StatusOr P4ToConstraintInfo( const p4::config::v1::P4Info& p4info) { // Allocate output. - // TODO: b/293655979 - Populate the action_info_by_id map. absl::flat_hash_map action_info_by_id; absl::flat_hash_map table_info_by_id; @@ -217,6 +322,18 @@ absl::StatusOr P4ToConstraintInfo( << "duplicate table: " << table.DebugString()); } } + + for (const Action& action : p4info.actions()) { + absl::StatusOr action_info = ParseActionInfo(action); + if (!action_info.ok()) { + errors.push_back(action_info.status()); + } else if (!action_info_by_id.insert({action.preamble().id(), *action_info}) + .second) { + errors.push_back(gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) + << "duplicate action: " << action.DebugString()); + } + } + if (errors.empty()) { ConstraintInfo info{ .action_info_by_id = std::move(action_info_by_id), diff --git a/p4_constraints/backend/constraint_info.h b/p4_constraints/backend/constraint_info.h index 24b48b4..093cba8 100644 --- a/p4_constraints/backend/constraint_info.h +++ b/p4_constraints/backend/constraint_info.h @@ -49,6 +49,15 @@ struct KeyInfo { ast::Type type; }; +struct ParamInfo { + uint32_t id; // Same as Action.Param.id in p4info.proto. + std::string name; // Same as Action.Param.name in p4info.proto. + + // Param type specified by a combination of Action.Param.bitwidth and + // Action.Param.P4NamedType in p4info.proto. + ast::Type type; +}; + template void AbslStringify(Sink& sink, const KeyInfo& info) { absl::Format(&sink, "KeyInfo{ id: %d; name: \"%s\"; type: { %s }; }", info.id, @@ -80,6 +89,11 @@ struct ActionInfo { // If member `constraint` is present, this captures its source. Arbitrary // otherwise. ConstraintSource constraint_source; + + // Maps from param IDs to ParamInfo. + absl::flat_hash_map params_by_id; + // Maps from param names to ParamInfo. + absl::flat_hash_map params_by_name; }; // Contains all information required for constraint checking. @@ -98,11 +112,16 @@ struct ConstraintInfo { absl::StatusOr P4ToConstraintInfo( const p4::config::v1::P4Info& p4info); -// Returns a unique pointer to the TableInfo associated with a given table_id +// Returns a pointer to the TableInfo associated with a given table_id // or std::nullptr if the table_id cannot be found. const TableInfo* GetTableInfoOrNull(const ConstraintInfo& constraint_info, uint32_t table_id); +// Returns a pointer to the ActionInfo associated with a given action_id +// or std::nullptr if the action_id cannot be found. +const ActionInfo* GetActionInfoOrNull(const ConstraintInfo& constraint_info, + uint32_t action_id); + // Table entry attribute accessible in the constraint language, e.g. priority. struct AttributeInfo { std::string name; diff --git a/p4_constraints/backend/constraint_info_test.cc b/p4_constraints/backend/constraint_info_test.cc new file mode 100644 index 0000000..31619a6 --- /dev/null +++ b/p4_constraints/backend/constraint_info_test.cc @@ -0,0 +1,54 @@ +#include "p4_constraints/backend/constraint_info.h" + +#include + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "gutils/proto.h" +#include "gutils/status_matchers.h" +#include "p4/config/v1/p4info.pb.h" + +using p4::config::v1::P4Info; + +namespace p4_constraints { + +constexpr char kActionRestrictionsValidP4InfoFile[] = + "third_party/p4lang_p4_constraints/p4_constraints/backend/" + "p4_programs/action_restrictions_valid.p4info.pb.txt"; + +constexpr char kActionRestrictionsInvalidP4InfoFile[] = + "third_party/p4lang_p4_constraints/p4_constraints/backend/" + "p4_programs/action_restrictions_invalid.p4info.pb.txt"; + +class ActionRestrictionsTest : public ::testing::Test {}; + +TEST_F(ActionRestrictionsTest, ExtractActionConstraint) { + P4Info p4_info; + ASSERT_OK( + gutils::ReadProtoFromFile(kActionRestrictionsValidP4InfoFile, &p4_info)); + absl::StatusOr constraints = + p4_constraints::P4ToConstraintInfo(p4_info); + + ASSERT_OK(constraints); + + absl::flat_hash_map action_info_by_id = + constraints.value().action_info_by_id; + + EXPECT_EQ(action_info_by_id[16777339].constraint_source.constraint_string, + "multicast_group_id != 0"); +} + +TEST_F(ActionRestrictionsTest, ParseParamType) { + P4Info p4_info; + ASSERT_OK(gutils::ReadProtoFromFile(kActionRestrictionsInvalidP4InfoFile, + &p4_info)); + absl::StatusOr constraints = + p4_constraints::P4ToConstraintInfo(p4_info); + + EXPECT_TRUE(!constraints.ok()); +} + +} // namespace p4_constraints diff --git a/p4_constraints/backend/p4_programs/action_restrictions_invalid.p4 b/p4_constraints/backend/p4_programs/action_restrictions_invalid.p4 new file mode 100644 index 0000000..3ae7e4c --- /dev/null +++ b/p4_constraints/backend/p4_programs/action_restrictions_invalid.p4 @@ -0,0 +1,77 @@ +#include + +type bit<16> custom_type_t; + +struct headers {}; +struct metadata { + bit<16> foo; + custom_type_t bar; +}; + +#define MULTICAST_GROUP_ID_BITWIDTH 16 +typedef bit multicast_group_id_t; + +parser MyParser(packet_in packet, + out headers hdr, + inout metadata meta, + inout standard_metadata_t standard_metadata) { + state start { + transition accept; + } +} + +control MyVerifyChecksum(inout headers hdr, inout metadata meta) { + apply { } +} + +control MyIngress(inout headers hdr, + inout metadata meta, + inout standard_metadata_t standard_metadata) { + + @id(123) + @action_restriction("multicast_group_id != 0") + action act_1(multicast_group_id_t multicast_group_id) { + meta.foo = multicast_group_id; + } + + @id(1234) + @action_restriction("custom_type_param != 0") + action act_2(custom_type_t custom_type_param) { + meta.bar = custom_type_param; + } + + table tbl { + key = { } + actions = { + act_1; + act_2; + } + } + + apply { + tbl.apply(); + } +} + +control MyEgress(inout headers hdr, + inout metadata meta, + inout standard_metadata_t standard_metadata) { + apply { } +} + +control MyComputeChecksum(inout headers hdr, inout metadata meta) { + apply { } +} + +control MyDeparser(packet_out packet, in headers hdr) { + apply { } +} + +V1Switch( +MyParser(), +MyVerifyChecksum(), +MyIngress(), +MyEgress(), +MyComputeChecksum(), +MyDeparser() +) main; diff --git a/p4_constraints/backend/p4_programs/action_restrictions_valid.p4 b/p4_constraints/backend/p4_programs/action_restrictions_valid.p4 new file mode 100644 index 0000000..33b5c21 --- /dev/null +++ b/p4_constraints/backend/p4_programs/action_restrictions_valid.p4 @@ -0,0 +1,67 @@ +#include + +struct headers {}; +struct metadata { + bit<16> foo; +}; + +#define MULTICAST_GROUP_ID_BITWIDTH 16 +typedef bit multicast_group_id_t; + +parser MyParser(packet_in packet, + out headers hdr, + inout metadata meta, + inout standard_metadata_t standard_metadata) { + state start { + transition accept; + } +} + +control MyVerifyChecksum(inout headers hdr, inout metadata meta) { + apply { } +} + +control MyIngress(inout headers hdr, + inout metadata meta, + inout standard_metadata_t standard_metadata) { + + @id(123) + @action_restriction("multicast_group_id != 0") + action act_1(multicast_group_id_t multicast_group_id) { + meta.foo = multicast_group_id; + } + + table tbl { + key = { } + actions = { + act_1; + } + } + + apply { + tbl.apply(); + } +} + +control MyEgress(inout headers hdr, + inout metadata meta, + inout standard_metadata_t standard_metadata) { + apply { } +} + +control MyComputeChecksum(inout headers hdr, inout metadata meta) { + apply { } +} + +control MyDeparser(packet_out packet, in headers hdr) { + apply { } +} + +V1Switch( +MyParser(), +MyVerifyChecksum(), +MyIngress(), +MyEgress(), +MyComputeChecksum(), +MyDeparser() +) main; diff --git a/p4_constraints/backend/type_checker.cc b/p4_constraints/backend/type_checker.cc index e47f0ff..e7dcbb1 100644 --- a/p4_constraints/backend/type_checker.cc +++ b/p4_constraints/backend/type_checker.cc @@ -15,6 +15,7 @@ #include "p4_constraints/backend/type_checker.h" #include +#include #include #include @@ -171,11 +172,11 @@ absl::Status CastTransitivelyTo(Expression* expr, Type target_type) { // and mutates the expressions by wrapping them with type casts to the least // upper bound. Otherwise, Unify returns an InvalidArgument Status. absl::StatusOr Unify(Expression* left, Expression* right, - const TableInfo& table_info) { + const ConstraintSource& constraint_source) { const absl::optional least_upper_bound = LeastUpperBound(left->type(), right->type()); if (!least_upper_bound.has_value()) { - return StaticTypeError(table_info.constraint_source, left->start_location(), + return StaticTypeError(constraint_source, left->start_location(), right->end_location()) << "cannot unify types " << left->type() << " and " << right->type(); } @@ -221,27 +222,53 @@ absl::optional FieldTypeOfCompositeType(const Type& composite_type, // -- Type checking ------------------------------------------------------------ -absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { +const ConstraintSource& GetConstraintSource(const ActionInfo* action_info, + const TableInfo* table_info) { + if (action_info == nullptr) return table_info->constraint_source; + return action_info->constraint_source; +} + +absl::Status InferAndCheckTypes(Expression* expr, const ActionInfo* action_info, + const TableInfo* table_info) { + const ConstraintSource& constraint_source = + GetConstraintSource(action_info, table_info); + + // We expect exactly one of {action_info, table_info} to be set. + if (action_info != nullptr && table_info != nullptr) { + return gutils::InternalErrorBuilder() + << "Both action_info and table_info are nullptr."; + } + if (action_info == nullptr && table_info == nullptr) { + return gutils::InternalErrorBuilder() + << "Both action_info and table_info are not nullptr."; + } + switch (expr->expression_case()) { - case Expression::kBooleanConstant: + case ast::Expression::kBooleanConstant: expr->mutable_type()->mutable_boolean(); return absl::OkStatus(); - case Expression::kIntegerConstant: + case ast::Expression::kIntegerConstant: expr->mutable_type()->mutable_arbitrary_int(); return absl::OkStatus(); - case Expression::kKey: { - const std::string& key = expr->key(); - const auto& key_info = table_info.keys_by_name.find(key); - if (key_info == table_info.keys_by_name.end()) - return StaticTypeError(table_info.constraint_source, + case ast::Expression::kKey: { + // This case only applies to TableInfo. + if (table_info == nullptr) { + return StaticTypeError(constraint_source, expr->start_location(), + expr->end_location()) + << "unexpected key in action constraint"; + } + const std::string_view key = expr->key(); + const auto& key_info = table_info->keys_by_name.find(key); + if (key_info == table_info->keys_by_name.end()) + return StaticTypeError(table_info->constraint_source, expr->start_location(), expr->end_location()) << "unknown key " << key; *expr->mutable_type() = key_info->second.type; if (expr->type().type_case() == Type::kUnknown || expr->type().type_case() == Type::kUnsupported) { - return StaticTypeError(table_info.constraint_source, + return StaticTypeError(table_info->constraint_source, expr->start_location(), expr->end_location()) << "key " << key << " has illegal type " << TypeName(expr->type()); @@ -249,40 +276,35 @@ absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { return absl::OkStatus(); } - case Expression::kActionParameter: { - return absl::UnimplementedError( - "TODO: b/293656077 - Support action constraints"); - } - - case Expression::kAttributeAccess: { - const std::string& attribute_name = - expr->attribute_access().attribute_name(); - const auto attribute_info = GetAttributeInfo(attribute_name); - if (attribute_info == std::nullopt) { - return StaticTypeError(table_info.constraint_source, - expr->start_location(), expr->end_location()) - << "unknown attribute '" << attribute_name << "'"; + case ast::Expression::kActionParameter: { + // This case only applies to ActionInfo. + if (action_info == nullptr) { + return StaticTypeError(constraint_source, expr->start_location(), + expr->end_location()) + << "unexpected action parameter in table constraint"; } - Type& expr_type = *expr->mutable_type(); - expr_type = attribute_info->type; - if (expr_type.type_case() == Type::kUnknown || - expr_type.type_case() == Type::kUnsupported) { - // Since we hardcode the type of attribute in the source code, this line - // should never be reached. - return InternalError(table_info.constraint_source, - expr->start_location(), expr->end_location()) - << "attribute '" << attribute_name << "' has illegal type " - << TypeName(expr_type); + const std::string_view param = expr->action_parameter(); + const auto& param_info = action_info->params_by_name.find(param); + if (param_info == action_info->params_by_name.end()) + return StaticTypeError(action_info->constraint_source, + expr->start_location(), expr->end_location()) + << "unknown action parameter " << param; + *expr->mutable_type() = param_info->second.type; + if (expr->type().type_case() == Type::kUnknown || + expr->type().type_case() == Type::kUnsupported) { + return StaticTypeError(action_info->constraint_source, + expr->start_location(), expr->end_location()) + << "action parameter " << param << " has illegal type " + << TypeName(expr->type()); } return absl::OkStatus(); } - case Expression::kBooleanNegation: { + case ast::Expression::kBooleanNegation: { Expression* sub_expr = expr->mutable_boolean_negation(); - RETURN_IF_ERROR(InferAndCheckTypes(sub_expr, table_info)); + RETURN_IF_ERROR(InferAndCheckTypes(sub_expr, action_info, table_info)); if (!sub_expr->type().has_boolean()) { - return StaticTypeError(table_info.constraint_source, - sub_expr->start_location(), + return StaticTypeError(constraint_source, sub_expr->start_location(), sub_expr->end_location()) << "expected type bool, got " << TypeName(sub_expr->type()); } @@ -290,12 +312,11 @@ absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { return absl::OkStatus(); } - case Expression::kArithmeticNegation: { + case ast::Expression::kArithmeticNegation: { Expression* sub_expr = expr->mutable_arithmetic_negation(); - RETURN_IF_ERROR(InferAndCheckTypes(sub_expr, table_info)); + RETURN_IF_ERROR(InferAndCheckTypes(sub_expr, action_info, table_info)); if (!sub_expr->type().has_arbitrary_int()) { - return StaticTypeError(table_info.constraint_source, - sub_expr->start_location(), + return StaticTypeError(constraint_source, sub_expr->start_location(), sub_expr->end_location()) << "expected type int, got " << TypeName(sub_expr->type()); } @@ -303,19 +324,24 @@ absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { return absl::OkStatus(); } + case ast::Expression::kTypeCast: + return StaticTypeError(constraint_source, expr->start_location(), + expr->end_location()) + << "type casts should only be inserted by the type checker"; + case Expression::kBinaryExpression: { BinaryExpression* bin_expr = expr->mutable_binary_expression(); Expression* left = bin_expr->mutable_left(); Expression* right = bin_expr->mutable_right(); - RETURN_IF_ERROR(InferAndCheckTypes(left, table_info)); - RETURN_IF_ERROR(InferAndCheckTypes(right, table_info)); + RETURN_IF_ERROR(InferAndCheckTypes(left, action_info, table_info)); + RETURN_IF_ERROR(InferAndCheckTypes(right, action_info, table_info)); switch (bin_expr->binop()) { case ast::BinaryOperator::AND: case ast::BinaryOperator::OR: case ast::BinaryOperator::IMPLIES: { for (auto subexpr : {left, right}) { if (!subexpr->type().has_boolean()) { - return StaticTypeError(table_info.constraint_source, + return StaticTypeError(constraint_source, subexpr->start_location(), subexpr->end_location()) << "expected type bool, got " << TypeName(subexpr->type()); @@ -330,13 +356,13 @@ absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { case ast::BinaryOperator::LE: case ast::BinaryOperator::EQ: case ast::BinaryOperator::NE: { - ASSIGN_OR_RETURN(Type type, Unify(left, right, table_info)); + ASSIGN_OR_RETURN(Type type, Unify(left, right, constraint_source)); // Unordered types only support == and !=. if (bin_expr->binop() != ast::BinaryOperator::EQ && bin_expr->binop() != ast::BinaryOperator::NE && !TypeHasOrdering(type)) { - return StaticTypeError(table_info.constraint_source, - expr->start_location(), expr->end_location()) + return StaticTypeError(constraint_source, expr->start_location(), + expr->end_location()) << "operand type " << type << " does not support ordered comparison"; } @@ -350,20 +376,16 @@ absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { } } - case Expression::kTypeCast: - return StaticTypeError(table_info.constraint_source, - expr->start_location(), expr->end_location()) - << "type casts should only be inserted by the type checker"; - - case Expression::kFieldAccess: { + case ast::Expression::kFieldAccess: { Expression* composite_expr = expr->mutable_field_access()->mutable_expr(); const std::string& field = expr->mutable_field_access()->field(); - RETURN_IF_ERROR(InferAndCheckTypes(composite_expr, table_info)); + RETURN_IF_ERROR( + InferAndCheckTypes(composite_expr, action_info, table_info)); absl::optional field_type = FieldTypeOfCompositeType(composite_expr->type(), field); if (!field_type.has_value()) { - return StaticTypeError(table_info.constraint_source, - expr->start_location(), expr->end_location()) + return StaticTypeError(constraint_source, expr->start_location(), + expr->end_location()) << "expression of type " << composite_expr->type() << " has no field '" << field << "'"; } @@ -371,12 +393,43 @@ absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { return absl::OkStatus(); } - case Expression::EXPRESSION_NOT_SET: + case ast::Expression::kAttributeAccess: { + const std::string& attribute_name = + expr->attribute_access().attribute_name(); + const auto attribute_info = GetAttributeInfo(attribute_name); + if (attribute_info == std::nullopt) { + return StaticTypeError(constraint_source, expr->start_location(), + expr->end_location()) + << "unknown attribute '" << attribute_name << "'"; + } + Type& expr_type = *expr->mutable_type(); + expr_type = attribute_info->type; + if (expr_type.type_case() == Type::kUnknown || + expr_type.type_case() == Type::kUnsupported) { + // Since we hardcode the type of attribute in the source code, this line + // should never be reached. + return InternalError(constraint_source, expr->start_location(), + expr->end_location()) + << "attribute '" << attribute_name << "' has illegal type " + << TypeName(expr_type); + } + return absl::OkStatus(); + } + + case ast::Expression::EXPRESSION_NOT_SET: break; } - return StaticTypeError(table_info.constraint_source, expr->start_location(), + return StaticTypeError(constraint_source, expr->start_location(), expr->end_location()) << "unexpected expression: " << expr->DebugString(); +} // namespace + +absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { + return InferAndCheckTypes(expr, /*action_info=*/nullptr, &table_info); } +absl::Status InferAndCheckTypes(Expression* expr, + const ActionInfo& action_info) { + return InferAndCheckTypes(expr, &action_info, /*table_info=*/nullptr); +} } // namespace p4_constraints diff --git a/p4_constraints/backend/type_checker.h b/p4_constraints/backend/type_checker.h index 186801c..749794d 100644 --- a/p4_constraints/backend/type_checker.h +++ b/p4_constraints/backend/type_checker.h @@ -34,14 +34,16 @@ namespace p4_constraints { // the correct types. // - It may insert type-casts, making implicit casts explicit. // -// Upon successful completion of this function, the given expression is +// Upon successful completion of these functions, the given expression is // guaranteed to contain no ast::Type::Unknown/Unsupported types. // -// This function should not be called on an `expr` that has already been type +// These functions should not be called on an `expr` that has already been type // checked (more specifically, on an `expr` that already contains type casts). // Doing so will result in an InvalidInput Error. absl::Status InferAndCheckTypes(ast::Expression* expr, const TableInfo& table_info); +absl::Status InferAndCheckTypes(ast::Expression* expr, + const ActionInfo& action_info); } // namespace p4_constraints diff --git a/p4_constraints/backend/type_checker_test.cc b/p4_constraints/backend/type_checker_test.cc index cd26353..ea547df 100644 --- a/p4_constraints/backend/type_checker_test.cc +++ b/p4_constraints/backend/type_checker_test.cc @@ -83,6 +83,20 @@ class InferAndCheckTypesTest : public ::testing::Test { {"range32", {0, "range32", kRange32}}, }}; + const ActionInfo kActionInfo{.id = 0, + .name = "action", + .constraint = {}, + .constraint_source = + ConstraintSource{ + .constraint_string = " ", + .constraint_location = kMockLocation, + }, + .params_by_id = {}, + .params_by_name = { + {"bit16", {0, "bit16", kFixedUnsigned16}}, + {"bit32", {0, "bit32", kFixedUnsigned32}}, + }}; + // Required by negative tests to avoid internal quoting errors. void AddMockSourceLocations(Expression& expr) { *expr.mutable_start_location() = kMockLocation; @@ -173,13 +187,32 @@ TEST_F(InferAndCheckTypesTest, KnownVariablesTypeCheck) { } } -TEST_F(InferAndCheckTypesTest, UnknownVariablesDontTypeCheck) { - std::string keys[] = {"unknown", "unsupported", "not even a key"}; - for (auto& key : keys) { +TEST_F(InferAndCheckTypesTest, KnownVariablesTypeForActionsCheck) { + std::pair param_type_pairs[] = { + {"bit16", kFixedUnsigned16}, + {"bit32", kFixedUnsigned32}, + }; + for (auto& [param_name, param_type] : param_type_pairs) { Expression expr = ParseTextProtoOrDie( - absl::Substitute(R"( key: "$0" )", key)); - AddMockSourceLocations(expr); - EXPECT_THAT(InferAndCheckTypes(&expr, kTableInfo), + absl::Substitute(R"( action_parameter: "$0" )", param_name)); + ASSERT_THAT(InferAndCheckTypes(&expr, kActionInfo), IsOk()); + EXPECT_TRUE(expr.type() == param_type); + } +} + +TEST_F(InferAndCheckTypesTest, UnknownVariablesDontTypeCheck) { + std::string variables[] = {"unknown", "unsupported", "not even valid"}; + for (auto& variable : variables) { + Expression table_expr = ParseTextProtoOrDie( + absl::Substitute(R"( key: "$0" )", variable)); + AddMockSourceLocations(table_expr); + EXPECT_THAT(InferAndCheckTypes(&table_expr, kTableInfo), + StatusIs(StatusCode::kInvalidArgument)); + + Expression action_expr = ParseTextProtoOrDie( + absl::Substitute(R"( action_parameter: "$0" )", variable)); + AddMockSourceLocations(action_expr); + EXPECT_THAT(InferAndCheckTypes(&action_expr, kActionInfo), StatusIs(StatusCode::kInvalidArgument)); } } @@ -227,6 +260,14 @@ TEST_F(InferAndCheckTypesTest, BooleanNegationOfNonBooleansDoesNotTypeCheck) { StatusIs(StatusCode::kInvalidArgument)) << "cannot negate key " << key; } + for (std::string param : {"bit16", "bit32"}) { + expr = ParseTextProtoOrDie(absl::Substitute( + R"(boolean_negation { action_parameter: "$0" })", param)); + AddMockSourceLocations(expr); + ASSERT_THAT(InferAndCheckTypes(&expr, kActionInfo), + StatusIs(StatusCode::kInvalidArgument)) + << "cannot negate action parameter"; + } } TEST_F(InferAndCheckTypesTest, ArithmeticNegationOfIntTypeChecks) { @@ -272,6 +313,14 @@ TEST_F(InferAndCheckTypesTest, ArithmeticNegationOfNonIntDoesNotTypeChecks) { StatusIs(StatusCode::kInvalidArgument)) << "cannot negate key " << key; } + for (std::string param : {"bit32", "bit16"}) { + expr = ParseTextProtoOrDie(absl::Substitute( + R"(arithmetic_negation { action_parameter: "$0" })", param)); + AddMockSourceLocations(expr); + ASSERT_THAT(InferAndCheckTypes(&expr, kActionInfo), + StatusIs(StatusCode::kInvalidArgument)) + << "cannot negate " << param; + } } TEST_F(InferAndCheckTypesTest, TypeCastNeverTypeChecks) { diff --git a/p4_constraints/constraint_source.h b/p4_constraints/constraint_source.h index e046cd2..fbd9828 100644 --- a/p4_constraints/constraint_source.h +++ b/p4_constraints/constraint_source.h @@ -21,7 +21,7 @@ namespace p4_constraints { -// Convienent struct of source information for quoting. +// Convenient struct of source information for quoting. struct ConstraintSource { std::string constraint_string; ast::SourceLocation constraint_location; diff --git a/p4_constraints/frontend/ast_constructors.cc b/p4_constraints/frontend/ast_constructors.cc index 452a99d..0697055 100644 --- a/p4_constraints/frontend/ast_constructors.cc +++ b/p4_constraints/frontend/ast_constructors.cc @@ -146,8 +146,16 @@ absl::StatusOr MakeVariable(absl::Span tokens, for (int i = 0; i < tokens.size(); i++) { const Token& id = tokens[i]; RET_CHECK_EQ(id.kind, Token::ID); - if (constraint_kind == ConstraintKind::kTableConstraint) - key_or_param << (i == 0 ? "" : ".") << id.text; + switch (constraint_kind) { + case ConstraintKind::kTableConstraint: { + key_or_param << (i == 0 ? "" : ".") << id.text; + break; + } + case ConstraintKind::kActionConstraint: { + key_or_param << id.text; + break; + } + } } switch (constraint_kind) { case ConstraintKind::kTableConstraint: {