diff --git a/include/circt-c/Dialect/FIRRTL.h b/include/circt-c/Dialect/FIRRTL.h index 6743abcf1fd8..6dfcbe16cb19 100644 --- a/include/circt-c/Dialect/FIRRTL.h +++ b/include/circt-c/Dialect/FIRRTL.h @@ -105,28 +105,74 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FIRRTL, firrtl); // Type API. //===----------------------------------------------------------------------===// +/// Returns `true` if this is a const type whose value is guaranteed to be +/// unchanging at circuit execution time. +MLIR_CAPI_EXPORTED bool firrtlTypeIsConst(MlirType type); + +/// Returns a const or non-const version of this type. +MLIR_CAPI_EXPORTED MlirType firrtlTypeGetConstType(MlirType type, bool isConst); + +/// Gets the bit width for this type, returns -1 if unknown. +/// +/// It recursively computes the bit width of aggregate types. For bundle and +/// vectors, recursively get the width of each field element and return the +/// total bit width of the aggregate type. This returns -1, if any of the bundle +/// fields is a flip type, or ground type with unknown bit width. +MLIR_CAPI_EXPORTED int64_t firrtlTypeGetBitWidth(MlirType type, + bool ignoreFlip); + +/// Checks if this type is a unsigned integer type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAUInt(MlirType type); + /// Creates a unsigned integer type with the specified width. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetUInt(MlirContext ctx, int32_t width); +/// Checks if this type is a signed integer type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsASInt(MlirType type); + /// Creates a signed integer type with the specified width. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetSInt(MlirContext ctx, int32_t width); +/// Checks if this type is a clock type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAClock(MlirType type); + /// Creates a clock type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetClock(MlirContext ctx); +/// Checks if this type is a reset type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAReset(MlirType type); + /// Creates a reset type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetReset(MlirContext ctx); -/// Creates a async reset type. +/// Checks if this type is an async reset type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAAsyncReset(MlirType type); + +/// Creates an async reset type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetAsyncReset(MlirContext ctx); -/// Creates a analog type with the specified width. +/// Checks if this type is an analog type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAAnalog(MlirType type); + +/// Creates an analog type with the specified width. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetAnalog(MlirContext ctx, int32_t width); +/// Checks if this type is a vector type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAVector(MlirType type); + /// Creates a vector type with the specified element type and count. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetVector(MlirContext ctx, MlirType element, size_t count); +/// Returns the element type of a vector type. +MLIR_CAPI_EXPORTED MlirType firrtlTypeGetVectorElement(MlirType vec); + +/// Returns the number of elements in a vector type. +MLIR_CAPI_EXPORTED size_t firrtlTypeGetVectorNumElements(MlirType vec); + +/// Returns true if the specified type is a bundle type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsABundle(MlirType type); + /// Returns true if the specified type is an open bundle type. /// /// An open bundle type means that it contains non FIRRTL base types. @@ -139,35 +185,70 @@ MLIR_CAPI_EXPORTED bool firrtlTypeIsAOpenBundle(MlirType type); MLIR_CAPI_EXPORTED MlirType firrtlTypeGetBundle( MlirContext ctx, size_t count, const FIRRTLBundleField *fields); +/// Returns the number of fields in the bundle type. +MLIR_CAPI_EXPORTED size_t firrtlTypeGetBundleNumFields(MlirType bundle); + +/// Returns the field at the specified index in the bundle type. +MLIR_CAPI_EXPORTED bool +firrtlTypeGetBundleFieldByIndex(MlirType type, size_t index, + FIRRTLBundleField *field); + /// Returns the index of the field with the specified name in the bundle type. MLIR_CAPI_EXPORTED unsigned firrtlTypeGetBundleFieldIndex(MlirType type, MlirStringRef fieldName); +/// Checks if this type is a ref type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsARef(MlirType type); + /// Creates a ref type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetRef(MlirType target, bool forceable); +/// Checks if this type is an anyref type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAAnyRef(MlirType type); + /// Creates an anyref type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetAnyRef(MlirContext ctx); +/// Checks if this type is a property integer type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAInteger(MlirType type); + /// Creates a property integer type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetInteger(MlirContext ctx); +/// Checks if this type is a property double type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsADouble(MlirType type); + /// Creates a property double type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetDouble(MlirContext ctx); +/// Checks if this type is a property string type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAString(MlirType type); + /// Creates a property string type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetString(MlirContext ctx); +/// Checks if this type is a property boolean type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsABoolean(MlirType type); + /// Creates a property boolean type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetBoolean(MlirContext ctx); +/// Checks if this type is a property path type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAPath(MlirType type); + /// Creates a property path type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetPath(MlirContext ctx); -/// Creates a property path type with the specified element type. +/// Checks if this type is a property list type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAList(MlirType type); + +/// Creates a property list type with the specified element type. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetList(MlirContext ctx, MlirType elementType); +/// Checks if this type is a class type. +MLIR_CAPI_EXPORTED bool firrtlTypeIsAClass(MlirType type); + /// Creates a class type with the specified name and elements. MLIR_CAPI_EXPORTED MlirType firrtlTypeGetClass(MlirContext ctx, MlirAttribute name, size_t numberOfElements, diff --git a/lib/CAPI/Dialect/FIRRTL.cpp b/lib/CAPI/Dialect/FIRRTL.cpp index b76a5e703c65..c3dea4d70dba 100644 --- a/lib/CAPI/Dialect/FIRRTL.cpp +++ b/lib/CAPI/Dialect/FIRRTL.cpp @@ -37,30 +37,61 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FIRRTL, firrtl, // Type API. //===----------------------------------------------------------------------===// +bool firrtlTypeIsConst(MlirType type) { return isConst(unwrap(type)); } + +MlirType firrtlTypeGetConstType(MlirType type, bool isConst) { + return wrap(cast(unwrap(type)).getConstType(isConst)); +} + +int64_t firrtlTypeGetBitWidth(MlirType type, bool ignoreFlip) { + return getBitWidth(cast(unwrap(type)), ignoreFlip) + .value_or(-1); +} + +bool firrtlTypeIsAUInt(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetUInt(MlirContext ctx, int32_t width) { return wrap(UIntType::get(unwrap(ctx), width)); } +bool firrtlTypeIsASInt(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetSInt(MlirContext ctx, int32_t width) { return wrap(SIntType::get(unwrap(ctx), width)); } +bool firrtlTypeIsAClock(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetClock(MlirContext ctx) { return wrap(ClockType::get(unwrap(ctx))); } +bool firrtlTypeIsAReset(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetReset(MlirContext ctx) { return wrap(ResetType::get(unwrap(ctx))); } +bool firrtlTypeIsAAsyncReset(MlirType type) { + return isa(unwrap(type)); +} + MlirType firrtlTypeGetAsyncReset(MlirContext ctx) { return wrap(AsyncResetType::get(unwrap(ctx))); } +bool firrtlTypeIsAAnalog(MlirType type) { + return isa(unwrap(type)); +} + MlirType firrtlTypeGetAnalog(MlirContext ctx, int32_t width) { return wrap(AnalogType::get(unwrap(ctx), width)); } +bool firrtlTypeIsAVector(MlirType type) { + return isa(unwrap(type)); +} + MlirType firrtlTypeGetVector(MlirContext ctx, MlirType element, size_t count) { auto baseType = cast(unwrap(element)); assert(baseType && "element must be base type"); @@ -68,6 +99,18 @@ MlirType firrtlTypeGetVector(MlirContext ctx, MlirType element, size_t count) { return wrap(FVectorType::get(baseType, count)); } +MlirType firrtlTypeGetVectorElement(MlirType vec) { + return wrap(cast(unwrap(vec)).getElementType()); +} + +size_t firrtlTypeGetVectorNumElements(MlirType vec) { + return cast(unwrap(vec)).getNumElements(); +} + +bool firrtlTypeIsABundle(MlirType type) { + return isa(unwrap(type)); +} + bool firrtlTypeIsAOpenBundle(MlirType type) { return isa(unwrap(type)); } @@ -98,6 +141,35 @@ MlirType firrtlTypeGetBundle(MlirContext ctx, size_t count, return wrap(OpenBundleType::get(unwrap(ctx), bundleFields)); } +size_t firrtlTypeGetBundleNumFields(MlirType bundle) { + if (auto bundleType = dyn_cast(unwrap(bundle))) { + return bundleType.getNumElements(); + } else if (auto bundleType = dyn_cast(unwrap(bundle))) { + return bundleType.getNumElements(); + } else { + llvm_unreachable("must be a bundle type"); + } +} + +bool firrtlTypeGetBundleFieldByIndex(MlirType type, size_t index, + FIRRTLBundleField *field) { + auto unwrapped = unwrap(type); + + auto cvt = [field](auto element) { + field->name = wrap(element.name); + field->isFlip = element.isFlip; + field->type = wrap(element.type); + }; + if (auto bundleType = dyn_cast(unwrapped)) { + cvt(bundleType.getElement(index)); + return true; + } else if (auto bundleType = dyn_cast(unwrapped)) { + cvt(bundleType.getElement(index)); + return true; + } + return false; +} + unsigned firrtlTypeGetBundleFieldIndex(MlirType type, MlirStringRef fieldName) { std::optional fieldIndex; if (auto bundleType = dyn_cast(unwrap(type))) { @@ -111,6 +183,8 @@ unsigned firrtlTypeGetBundleFieldIndex(MlirType type, MlirStringRef fieldName) { return fieldIndex.value(); } +bool firrtlTypeIsARef(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetRef(MlirType target, bool forceable) { auto baseType = dyn_cast(unwrap(target)); assert(baseType && "target must be base type"); @@ -118,30 +192,52 @@ MlirType firrtlTypeGetRef(MlirType target, bool forceable) { return wrap(RefType::get(baseType, forceable)); } +bool firrtlTypeIsAAnyRef(MlirType type) { + return isa(unwrap(type)); +} + MlirType firrtlTypeGetAnyRef(MlirContext ctx) { return wrap(AnyRefType::get(unwrap(ctx))); } +bool firrtlTypeIsAInteger(MlirType type) { + return isa(unwrap(type)); +} + MlirType firrtlTypeGetInteger(MlirContext ctx) { return wrap(FIntegerType::get(unwrap(ctx))); } +bool firrtlTypeIsADouble(MlirType type) { + return isa(unwrap(type)); +} + MlirType firrtlTypeGetDouble(MlirContext ctx) { return wrap(DoubleType::get(unwrap(ctx))); } +bool firrtlTypeIsAString(MlirType type) { + return isa(unwrap(type)); +} + MlirType firrtlTypeGetString(MlirContext ctx) { return wrap(StringType::get(unwrap(ctx))); } +bool firrtlTypeIsABoolean(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetBoolean(MlirContext ctx) { return wrap(BoolType::get(unwrap(ctx))); } +bool firrtlTypeIsAPath(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetPath(MlirContext ctx) { return wrap(PathType::get(unwrap(ctx))); } +bool firrtlTypeIsAList(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetList(MlirContext ctx, MlirType elementType) { auto type = dyn_cast(unwrap(elementType)); assert(type && "element must be property type"); @@ -149,6 +245,8 @@ MlirType firrtlTypeGetList(MlirContext ctx, MlirType elementType) { return wrap(ListType::get(unwrap(ctx), type)); } +bool firrtlTypeIsAClass(MlirType type) { return isa(unwrap(type)); } + MlirType firrtlTypeGetClass(MlirContext ctx, MlirAttribute name, size_t numberOfElements, const FIRRTLClassElement *elements) { diff --git a/test/CAPI/firrtl.c b/test/CAPI/firrtl.c index 351700dcd7af..db8487182661 100644 --- a/test/CAPI/firrtl.c +++ b/test/CAPI/firrtl.c @@ -32,6 +32,12 @@ void appendBufferCallback(MlirStringRef message, void *userData) { sprintf(buffer + strlen(buffer), "%.*s", (int)message.length, message.data); } +bool bundleFieldEqual(const FIRRTLBundleField *lhs, + const FIRRTLBundleField *rhs) { + return mlirIdentifierEqual(lhs->name, rhs->name) && + lhs->isFlip == rhs->isFlip && mlirTypeEqual(lhs->type, rhs->type); +} + void testExport(MlirContext ctx) { // clang-format off const char *testFIR = @@ -187,6 +193,131 @@ void testAttrGetIntegerFromString(MlirContext ctx) { mlirStringRefCreateFromCString("114514"), 10)); } +void testTypeDiscriminantsAndQueries(MlirContext ctx) { + FIRRTLBundleField bundleFields[] = { + { + .name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f1")), + .isFlip = false, + .type = firrtlTypeGetUInt(ctx, 32), + }, + { + .name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f2")), + .isFlip = false, + .type = firrtlTypeGetSInt(ctx, 64), + }, + }; + FIRRTLBundleField openBundleFields[] = { + { + .name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f1")), + .isFlip = false, + .type = firrtlTypeGetUInt(ctx, 32), + }, + { + .name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f2")), + .isFlip = false, + .type = firrtlTypeGetInteger(ctx), + }, + }; + MlirType bundle = + firrtlTypeGetBundle(ctx, ARRAY_SIZE(bundleFields), bundleFields); + bundleFields[1].isFlip = true; + MlirType bundleContainsFlip = + firrtlTypeGetBundle(ctx, ARRAY_SIZE(bundleFields), bundleFields); + MlirType openBundle = + firrtlTypeGetBundle(ctx, ARRAY_SIZE(openBundleFields), openBundleFields); + + FIRRTLClassElement classElements[] = { + { + .name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f1")), + .type = firrtlTypeGetUInt(ctx, 32), + .direction = FIRRTL_DIRECTION_IN, + }, + { + .name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f2")), + .type = firrtlTypeGetInteger(ctx), + .direction = FIRRTL_DIRECTION_OUT, + }, + }; + MlirType cls = firrtlTypeGetClass( + ctx, mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreateFromCString("cls")), + ARRAY_SIZE(classElements), classElements); + + assert(firrtlTypeIsConst( + firrtlTypeGetConstType(firrtlTypeGetUInt(ctx, 32), true))); + assert(firrtlTypeIsConst(firrtlTypeGetConstType( + firrtlTypeGetConstType(firrtlTypeGetUInt(ctx, 32), false), true))); + assert(firrtlTypeIsAUInt(firrtlTypeGetUInt(ctx, 32))); + assert(firrtlTypeIsASInt(firrtlTypeGetSInt(ctx, 32))); + assert(firrtlTypeIsAClock(firrtlTypeGetClock(ctx))); + assert(firrtlTypeIsAReset(firrtlTypeGetReset(ctx))); + assert(firrtlTypeIsAAsyncReset(firrtlTypeGetAsyncReset(ctx))); + assert(firrtlTypeIsAAnalog(firrtlTypeGetAnalog(ctx, 32))); + assert(firrtlTypeIsAVector( + firrtlTypeGetVector(ctx, firrtlTypeGetUInt(ctx, 32), 4))); + assert(firrtlTypeIsABundle(bundle)); + assert(firrtlTypeIsAOpenBundle(openBundle)); + assert(firrtlTypeIsARef(firrtlTypeGetRef(firrtlTypeGetClock(ctx), false))); + assert(firrtlTypeIsAAnyRef(firrtlTypeGetAnyRef(ctx))); + assert(firrtlTypeIsAInteger(firrtlTypeGetInteger(ctx))); + assert(firrtlTypeIsADouble(firrtlTypeGetDouble(ctx))); + assert(firrtlTypeIsAString(firrtlTypeGetString(ctx))); + assert(firrtlTypeIsABoolean(firrtlTypeGetBoolean(ctx))); + assert(firrtlTypeIsAPath(firrtlTypeGetPath(ctx))); + assert(firrtlTypeIsAList(firrtlTypeGetList(ctx, firrtlTypeGetInteger(ctx)))); + assert(firrtlTypeIsAClass(cls)); + assert(!firrtlTypeIsConst(firrtlTypeGetUInt(ctx, 32))); + assert(!firrtlTypeIsConst( + firrtlTypeGetConstType(firrtlTypeGetUInt(ctx, 32), false))); + assert(!firrtlTypeIsConst(firrtlTypeGetConstType( + firrtlTypeGetConstType(firrtlTypeGetUInt(ctx, 32), true), false))); + assert(!firrtlTypeIsAUInt(firrtlTypeGetSInt(ctx, 32))); + assert(!firrtlTypeIsASInt(firrtlTypeGetUInt(ctx, 32))); + assert(!firrtlTypeIsAClock(firrtlTypeGetReset(ctx))); + assert(!firrtlTypeIsAReset(firrtlTypeGetClock(ctx))); + assert(!firrtlTypeIsAAsyncReset(firrtlTypeGetReset(ctx))); + assert(!firrtlTypeIsAAnalog(firrtlTypeGetReset(ctx))); + assert(!firrtlTypeIsABundle(openBundle)); + assert(!firrtlTypeIsAOpenBundle(bundle)); + assert(!firrtlTypeIsARef(firrtlTypeGetAnyRef(ctx))); + assert( + !firrtlTypeIsAAnyRef(firrtlTypeGetRef(firrtlTypeGetClock(ctx), false))); + assert(!firrtlTypeIsAInteger(firrtlTypeGetDouble(ctx))); + assert(!firrtlTypeIsADouble(firrtlTypeGetString(ctx))); + assert(!firrtlTypeIsAString(firrtlTypeGetBoolean(ctx))); + assert(!firrtlTypeIsABoolean(firrtlTypeGetPath(ctx))); + assert(!firrtlTypeIsAPath(firrtlTypeGetBoolean(ctx))); + assert(!firrtlTypeIsAList(cls)); + assert(!firrtlTypeIsAClass(firrtlTypeGetPath(ctx))); + + assert(firrtlTypeGetBitWidth(firrtlTypeGetUInt(ctx, 32), false) == 32); + assert(firrtlTypeGetBitWidth(firrtlTypeGetSInt(ctx, 32), false) == 32); + assert(firrtlTypeGetBitWidth(firrtlTypeGetClock(ctx), false) == 1); + assert(firrtlTypeGetBitWidth(firrtlTypeGetReset(ctx), false) == 1); + assert(firrtlTypeGetBitWidth(bundle, false) == 96); + assert(firrtlTypeGetBitWidth(bundleContainsFlip, false) == -1); + assert(firrtlTypeGetBitWidth(bundleContainsFlip, true) == 96); + assert(mlirTypeEqual(firrtlTypeGetVectorElement(firrtlTypeGetVector( + ctx, firrtlTypeGetUInt(ctx, 32), 4)), + firrtlTypeGetUInt(ctx, 32))); + assert(firrtlTypeGetVectorNumElements( + firrtlTypeGetVector(ctx, firrtlTypeGetUInt(ctx, 32), 4)) == 4); + assert(firrtlTypeGetBundleNumFields(bundle) == 2); + assert(firrtlTypeGetBundleNumFields(openBundle) == 2); + FIRRTLBundleField field; + assert(firrtlTypeGetBundleFieldByIndex(bundle, 0, &field) && + bundleFieldEqual(&bundleFields[0], &field)); + assert(firrtlTypeGetBundleFieldByIndex(openBundle, 1, &field) && + bundleFieldEqual(&openBundleFields[1], &field)); + assert(firrtlTypeGetBundleFieldIndex( + bundle, mlirStringRefCreateFromCString("f1")) == 0); + assert(firrtlTypeGetBundleFieldIndex( + bundle, mlirStringRefCreateFromCString("f2")) == 1); + assert(firrtlTypeGetBundleFieldIndex( + openBundle, mlirStringRefCreateFromCString("f1")) == 0); + assert(firrtlTypeGetBundleFieldIndex( + openBundle, mlirStringRefCreateFromCString("f2")) == 1); +} + void testTypeGetMaskType(MlirContext ctx) { assert(mlirTypeEqual(firrtlTypeGetMaskType(firrtlTypeGetUInt(ctx, 32)), firrtlTypeGetUInt(ctx, 1))); @@ -240,6 +371,7 @@ int main(void) { testValueFoldFlow(ctx); testImportAnnotations(ctx); testAttrGetIntegerFromString(ctx); + testTypeDiscriminantsAndQueries(ctx); testTypeGetMaskType(ctx); return 0; }