-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor Target Platform Capabilities - Phase 4 #1301
Conversation
Replace schema classes from dataclass to pydantic 'BaseModel'. Fix tests to support pydantic schema classes. Add test for exporting tp model to json
'quantization_configurations': updated_configs | ||
}) | ||
|
||
def clone_and_edit_weight_attribute(self, attrs: Optional[List[str]] = None, **kwargs) -> 'QuantizationConfigOptions': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't remember if we discussed this, but this function and clone_and_map_weights_attr_keys
don't seem like something that should be part of the schema, but rather an external functionality (although it is probably not possible because of the immutability? is that the problem?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it stays, in any case it makes more sense that OpQuantizationConfig updates its attributes. This class can call configs methods, but it shouldn't go inside and update them itself.
@@ -39,7 +39,7 @@ def generate_test_tp_model(edit_params_dict, name=""): | |||
base_config, op_cfg_list, default_config = get_op_quantization_configs() | |||
|
|||
# separate weights attribute parameters from the requested param to edit | |||
weights_params_names = [name for name in schema.AttributeQuantizationConfig.__init__.__code__.co_varnames if | |||
weights_params_names = [name for name in schema.AttributeQuantizationConfig.model_fields.keys() if |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we maybe create a property within the class to return this? because .model_fields.keys()
is kind of cryptic
@@ -165,8 +165,8 @@ def get_tpc(self): | |||
quantization_configurations.extend([ | |||
tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=4), | |||
tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=2)]) | |||
tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config, | |||
quantization_configurations=tuple(quantization_configurations)) | |||
tpc.layer2qco[tf.multiply] = tpc.layer2qco[tf.multiply].model_copy( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's model_copy
? is it a pydantic method? is it replacing the clone_and_edit ins some way?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is a method to copy the instance with the new attributes (since the original instance is immutable)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only looked at v1.py. And sorry :)
if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int): | ||
Logger.critical("lut_values_bitwidth must be an integer or None.") # pragma: no cover | ||
@field_validator("weights_n_bits") | ||
def validate_weights_n_bits(cls, value): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you intentionally not use PositiveInt in field definition?
return value | ||
|
||
@field_validator("lut_values_bitwidth", mode="before") | ||
def validate_lut_values_bitwidth(cls, value): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't it validated automatically?
elif not isinstance(self.supported_input_activation_n_bits, tuple): | ||
|
||
# When loading from JSON, lists are returned. If the value is a list, convert it to a tuple. | ||
if isinstance(v, list): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the following flow is would be clearer (at least switch int and list order)
if isinstance(v, int):
return (v,) # or v = (v,)
if isinstance(v, list):
v = tuple(v)
validate v
return v
# Pydantic v2 configuration for immutability | ||
model_config = ConfigDict(frozen=True) | ||
|
||
@model_validator(mode='before') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you consider using "after" or model_post_init? This way pydantic will do the initial validations for you and you can also access the fields via self.
if not isinstance(self.quantization_configurations, tuple): | ||
quantization_configurations = values.get('quantization_configurations', ()) | ||
num_configs = len(quantization_configurations) | ||
base_config = values.get('base_config') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as Ofir's comment, shouldn't base_config exist? If it's None that's fine, but it should be its value, not get's default
) # pragma: no cover | ||
|
||
# Validate that there are at least two operator groups | ||
if len(fusing.operator_groups) < 2: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Field(min_length=2)
# Validate the operator_groups | ||
if not isinstance(self.operator_groups, tuple): | ||
# Validate operator_groups is a tuple | ||
if not isinstance(fusing.operator_groups, tuple): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isnt it validated automatically?
operator_set (Tuple[OperatorsSetBase]): Tuple of operator sets within the model. | ||
fusing_patterns (Tuple[Fusing]): Tuple of fusing patterns for the model. | ||
operator_set (Tuple[OperatorsSetBase, ...]): Tuple of operator sets within the model. | ||
fusing_patterns (Tuple[Fusing, ...]): Tuple of fusing patterns for the model. | ||
is_simd_padding (bool): Indicates if SIMD padding is applied. | ||
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model. | ||
""" | ||
default_qco: QuantizationConfigOptions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please remove default values, unless there is a good reason
""" | ||
# Validate `default_qco` | ||
if not isinstance(self.default_qco, QuantizationConfigOptions): | ||
default_qco = tp_model.default_qco | ||
if not isinstance(default_qco, QuantizationConfigOptions): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isnt is validated automatically?
Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover | ||
if len(self.default_qco.quantization_configurations) != 1: | ||
|
||
if len(default_qco.quantization_configurations) != 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldnt it be checked at QuantizationConfigOptions level?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is the default qco so it should have only one option. a regular qco can have multiple options
opsets_names = [ | ||
op.name.value if isinstance(op, OperatorSetNames) else op.name | ||
for op in operator_set | ||
] if operator_set else [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you dont need if else []. if will already be an empty list
Replace MaxTensor with MaxCut for activation mixed precision (Experimental).
… tensorflow <=2.13) Fix PR comments
… tensorflow <=2.13) Fix PR comments
Replace schema classes from dataclass to pydantic 'BaseModel'.
Fix tests to support pydantic schema classes.
Add test for exporting and loading tp model to json.
Pull Request Description:
Checklist before requesting a review: