Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL][CAPI] Add more functions for discriminating and querying type #7960

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 84 additions & 3 deletions include/circt-c/Dialect/FIRRTL.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,74 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FIRRTL, firrtl);
// Type API.
//===----------------------------------------------------------------------===//

/// Returns `true` if this is a const type whose value is guaranteed to be
/// unchanging at circuit execution time.
MLIR_CAPI_EXPORTED bool firrtlTypeIsConst(MlirType type);

/// Returns a const or non-const version of this type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetConstType(MlirType type, bool isConst);

/// Gets the bit width for this type, returns -1 if unknown.
///
/// It recursively computes the bit width of aggregate types. For bundle and
/// vectors, recursively get the width of each field element and return the
/// total bit width of the aggregate type. This returns -1, if any of the bundle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns -1, if any of the bundle fields is a flip type

why

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the design of

// Get the bit width for this type, return None if unknown. Unlike
// getBitWidthOrSentinel(), this can recursively compute the bitwidth of
// aggregate types. For bundle and vectors, recursively get the width of each
// field element and return the total bit width of the aggregate type. This
// returns None, if any of the bundle fields is a flip type, or ground type with
// unknown bit width.
std::optional<int64_t> getBitWidth(FIRRTLBaseType type,
bool ignoreFlip = false);

But I forgot to present the ignoreFlip parameter in C-API, I should add it as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated~

/// fields is a flip type, or ground type with unknown bit width.
MLIR_CAPI_EXPORTED int64_t firrtlTypeGetBitWidth(MlirType type,
bool ignoreFlip);

/// Checks if this type is a unsigned integer type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAUInt(MlirType type);

/// Creates a unsigned integer type with the specified width.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetUInt(MlirContext ctx, int32_t width);

/// Checks if this type is a signed integer type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsASInt(MlirType type);

/// Creates a signed integer type with the specified width.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetSInt(MlirContext ctx, int32_t width);

/// Checks if this type is a clock type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAClock(MlirType type);

/// Creates a clock type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetClock(MlirContext ctx);

/// Checks if this type is a reset type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAReset(MlirType type);

/// Creates a reset type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetReset(MlirContext ctx);

/// Creates a async reset type.
/// Checks if this type is an async reset type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAAsyncReset(MlirType type);

/// Creates an async reset type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetAsyncReset(MlirContext ctx);

/// Creates a analog type with the specified width.
/// Checks if this type is an analog type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAAnalog(MlirType type);

/// Creates an analog type with the specified width.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetAnalog(MlirContext ctx, int32_t width);

/// Checks if this type is a vector type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAVector(MlirType type);

/// Creates a vector type with the specified element type and count.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetVector(MlirContext ctx,
MlirType element, size_t count);

/// Returns the element type of a vector type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetVectorElement(MlirType vec);

/// Returns the number of elements in a vector type.
MLIR_CAPI_EXPORTED size_t firrtlTypeGetVectorNumElements(MlirType vec);

/// Returns true if the specified type is a bundle type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsABundle(MlirType type);

/// Returns true if the specified type is an open bundle type.
///
/// An open bundle type means that it contains non FIRRTL base types.
Expand All @@ -139,35 +185,70 @@ MLIR_CAPI_EXPORTED bool firrtlTypeIsAOpenBundle(MlirType type);
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetBundle(
MlirContext ctx, size_t count, const FIRRTLBundleField *fields);

/// Returns the number of fields in the bundle type.
MLIR_CAPI_EXPORTED size_t firrtlTypeGetBundleNumFields(MlirType bundle);

/// Returns the field at the specified index in the bundle type.
MLIR_CAPI_EXPORTED bool
firrtlTypeGetBundleFieldByIndex(MlirType type, size_t index,
FIRRTLBundleField *field);

/// Returns the index of the field with the specified name in the bundle type.
MLIR_CAPI_EXPORTED unsigned
firrtlTypeGetBundleFieldIndex(MlirType type, MlirStringRef fieldName);

/// Checks if this type is a ref type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsARef(MlirType type);

/// Creates a ref type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetRef(MlirType target, bool forceable);

/// Checks if this type is an anyref type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAAnyRef(MlirType type);

/// Creates an anyref type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetAnyRef(MlirContext ctx);

/// Checks if this type is a property integer type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAInteger(MlirType type);

/// Creates a property integer type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetInteger(MlirContext ctx);

/// Checks if this type is a property double type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsADouble(MlirType type);

/// Creates a property double type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetDouble(MlirContext ctx);

/// Checks if this type is a property string type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAString(MlirType type);

/// Creates a property string type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetString(MlirContext ctx);

/// Checks if this type is a property boolean type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsABoolean(MlirType type);

/// Creates a property boolean type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetBoolean(MlirContext ctx);

/// Checks if this type is a property path type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAPath(MlirType type);

/// Creates a property path type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetPath(MlirContext ctx);

/// Creates a property path type with the specified element type.
/// Checks if this type is a property list type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAList(MlirType type);

/// Creates a property list type with the specified element type.
MLIR_CAPI_EXPORTED MlirType firrtlTypeGetList(MlirContext ctx,
MlirType elementType);

/// Checks if this type is a class type.
MLIR_CAPI_EXPORTED bool firrtlTypeIsAClass(MlirType type);

/// Creates a class type with the specified name and elements.
MLIR_CAPI_EXPORTED MlirType
firrtlTypeGetClass(MlirContext ctx, MlirAttribute name, size_t numberOfElements,
Expand Down
97 changes: 97 additions & 0 deletions lib/CAPI/Dialect/FIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,37 +37,80 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FIRRTL, firrtl,
// Type API.
//===----------------------------------------------------------------------===//

bool firrtlTypeIsConst(MlirType type) { return isConst(unwrap(type)); }

MlirType firrtlTypeGetConstType(MlirType type, bool isConst) {
return wrap(cast<FIRRTLBaseType>(unwrap(type)).getConstType(isConst));
}

int64_t firrtlTypeGetBitWidth(MlirType type, bool ignoreFlip) {
return getBitWidth(cast<FIRRTLBaseType>(unwrap(type)), ignoreFlip)
.value_or(-1);
}

bool firrtlTypeIsAUInt(MlirType type) { return isa<UIntType>(unwrap(type)); }

MlirType firrtlTypeGetUInt(MlirContext ctx, int32_t width) {
return wrap(UIntType::get(unwrap(ctx), width));
}

bool firrtlTypeIsASInt(MlirType type) { return isa<SIntType>(unwrap(type)); }

MlirType firrtlTypeGetSInt(MlirContext ctx, int32_t width) {
return wrap(SIntType::get(unwrap(ctx), width));
}

bool firrtlTypeIsAClock(MlirType type) { return isa<ClockType>(unwrap(type)); }

MlirType firrtlTypeGetClock(MlirContext ctx) {
return wrap(ClockType::get(unwrap(ctx)));
}

bool firrtlTypeIsAReset(MlirType type) { return isa<ResetType>(unwrap(type)); }

MlirType firrtlTypeGetReset(MlirContext ctx) {
return wrap(ResetType::get(unwrap(ctx)));
}

bool firrtlTypeIsAAsyncReset(MlirType type) {
return isa<AsyncResetType>(unwrap(type));
}

MlirType firrtlTypeGetAsyncReset(MlirContext ctx) {
return wrap(AsyncResetType::get(unwrap(ctx)));
}

bool firrtlTypeIsAAnalog(MlirType type) {
return isa<AnalogType>(unwrap(type));
}

MlirType firrtlTypeGetAnalog(MlirContext ctx, int32_t width) {
return wrap(AnalogType::get(unwrap(ctx), width));
}

bool firrtlTypeIsAVector(MlirType type) {
return isa<FVectorType>(unwrap(type));
}

MlirType firrtlTypeGetVector(MlirContext ctx, MlirType element, size_t count) {
auto baseType = cast<FIRRTLBaseType>(unwrap(element));
assert(baseType && "element must be base type");

return wrap(FVectorType::get(baseType, count));
}

MlirType firrtlTypeGetVectorElement(MlirType vec) {
return wrap(cast<FVectorType>(unwrap(vec)).getElementType());
}

size_t firrtlTypeGetVectorNumElements(MlirType vec) {
return cast<FVectorType>(unwrap(vec)).getNumElements();
}

bool firrtlTypeIsABundle(MlirType type) {
return isa<BundleType>(unwrap(type));
}

bool firrtlTypeIsAOpenBundle(MlirType type) {
return isa<OpenBundleType>(unwrap(type));
}
Expand Down Expand Up @@ -98,6 +141,34 @@ MlirType firrtlTypeGetBundle(MlirContext ctx, size_t count,
return wrap(OpenBundleType::get(unwrap(ctx), bundleFields));
}

size_t firrtlTypeGetBundleNumFields(MlirType bundle) {
if (auto bundleType = dyn_cast<BundleType>(unwrap(bundle))) {
return bundleType.getNumElements();
} else if (auto bundleType = dyn_cast<OpenBundleType>(unwrap(bundle))) {
return bundleType.getNumElements();
} else {
llvm_unreachable("must be a bundle type");
}
}

bool firrtlTypeGetBundleFieldByIndex(MlirType type, size_t index,
FIRRTLBundleField *field) {
auto cvt = [field](auto element) {
field->name = wrap(element.name);
field->isFlip = element.isFlip;
field->type = wrap(element.type);
};

if (auto bundleType = dyn_cast<BundleType>(unwrap(type))) {
cvt(bundleType.getElement(index));
return true;
} else if (auto bundleType = dyn_cast<OpenBundleType>(unwrap(type))) {
cvt(bundleType.getElement(index));
return true;
}
return false;
}

unsigned firrtlTypeGetBundleFieldIndex(MlirType type, MlirStringRef fieldName) {
std::optional<unsigned> fieldIndex;
if (auto bundleType = dyn_cast<BundleType>(unwrap(type))) {
Expand All @@ -111,44 +182,70 @@ unsigned firrtlTypeGetBundleFieldIndex(MlirType type, MlirStringRef fieldName) {
return fieldIndex.value();
}

bool firrtlTypeIsARef(MlirType type) { return isa<RefType>(unwrap(type)); }

MlirType firrtlTypeGetRef(MlirType target, bool forceable) {
auto baseType = dyn_cast<FIRRTLBaseType>(unwrap(target));
assert(baseType && "target must be base type");

return wrap(RefType::get(baseType, forceable));
}

bool firrtlTypeIsAAnyRef(MlirType type) {
return isa<AnyRefType>(unwrap(type));
}

MlirType firrtlTypeGetAnyRef(MlirContext ctx) {
return wrap(AnyRefType::get(unwrap(ctx)));
}

bool firrtlTypeIsAInteger(MlirType type) {
return isa<FIntegerType>(unwrap(type));
}

MlirType firrtlTypeGetInteger(MlirContext ctx) {
return wrap(FIntegerType::get(unwrap(ctx)));
}

bool firrtlTypeIsADouble(MlirType type) {
return isa<DoubleType>(unwrap(type));
}

MlirType firrtlTypeGetDouble(MlirContext ctx) {
return wrap(DoubleType::get(unwrap(ctx)));
}

bool firrtlTypeIsAString(MlirType type) {
return isa<StringType>(unwrap(type));
}

MlirType firrtlTypeGetString(MlirContext ctx) {
return wrap(StringType::get(unwrap(ctx)));
}

bool firrtlTypeIsABoolean(MlirType type) { return isa<BoolType>(unwrap(type)); }

MlirType firrtlTypeGetBoolean(MlirContext ctx) {
return wrap(BoolType::get(unwrap(ctx)));
}

bool firrtlTypeIsAPath(MlirType type) { return isa<PathType>(unwrap(type)); }

MlirType firrtlTypeGetPath(MlirContext ctx) {
return wrap(PathType::get(unwrap(ctx)));
}

bool firrtlTypeIsAList(MlirType type) { return isa<ListType>(unwrap(type)); }

MlirType firrtlTypeGetList(MlirContext ctx, MlirType elementType) {
auto type = dyn_cast<PropertyType>(unwrap(elementType));
assert(type && "element must be property type");

return wrap(ListType::get(unwrap(ctx), type));
}

bool firrtlTypeIsAClass(MlirType type) { return isa<ClassType>(unwrap(type)); }

MlirType firrtlTypeGetClass(MlirContext ctx, MlirAttribute name,
size_t numberOfElements,
const FIRRTLClassElement *elements) {
Expand Down
Loading
Loading