Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Target Platform Capabilities - Phase 4 #1301

Merged
merged 12 commits into from
Dec 31, 2024
Merged

Conversation

lior-dikstein
Copy link
Collaborator

@lior-dikstein lior-dikstein commented Dec 25, 2024

Replace schema classes from dataclass to pydantic 'BaseModel'.
Fix tests to support pydantic schema classes.
Add test for exporting and loading tp model to json.

Pull Request Description:

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All function and files are well documented.
  • All function and classes have type hints.
  • There is a licenses in all file.
  • The function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unittest (if necessary).

Replace schema classes from dataclass to pydantic 'BaseModel'.
Fix tests to support pydantic schema classes.
Add test for exporting tp model to json
'quantization_configurations': updated_configs
})

def clone_and_edit_weight_attribute(self, attrs: Optional[List[str]] = None, **kwargs) -> 'QuantizationConfigOptions':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember if we discussed this, but this function and clone_and_map_weights_attr_keys don't seem like something that should be part of the schema, but rather an external functionality (although it is probably not possible because of the immutability? is that the problem?)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it stays, in any case it makes more sense that OpQuantizationConfig updates its attributes. This class can call configs methods, but it shouldn't go inside and update them itself.

@@ -39,7 +39,7 @@ def generate_test_tp_model(edit_params_dict, name=""):
base_config, op_cfg_list, default_config = get_op_quantization_configs()

# separate weights attribute parameters from the requested param to edit
weights_params_names = [name for name in schema.AttributeQuantizationConfig.__init__.__code__.co_varnames if
weights_params_names = [name for name in schema.AttributeQuantizationConfig.model_fields.keys() if
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we maybe create a property within the class to return this? because .model_fields.keys() is kind of cryptic

@@ -165,8 +165,8 @@ def get_tpc(self):
quantization_configurations.extend([
tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=4),
tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=2)])
tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config,
quantization_configurations=tuple(quantization_configurations))
tpc.layer2qco[tf.multiply] = tpc.layer2qco[tf.multiply].model_copy(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's model_copy? is it a pydantic method? is it replacing the clone_and_edit ins some way?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is a method to copy the instance with the new attributes (since the original instance is immutable)

Copy link
Collaborator

@irenaby irenaby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only looked at v1.py. And sorry :)

if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int):
Logger.critical("lut_values_bitwidth must be an integer or None.") # pragma: no cover
@field_validator("weights_n_bits")
def validate_weights_n_bits(cls, value):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you intentionally not use PositiveInt in field definition?

return value

@field_validator("lut_values_bitwidth", mode="before")
def validate_lut_values_bitwidth(cls, value):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it validated automatically?

elif not isinstance(self.supported_input_activation_n_bits, tuple):

# When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
if isinstance(v, list):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the following flow is would be clearer (at least switch int and list order)

if isinstance(v, int):
    return (v,)    # or v = (v,)
if isinstance(v, list):
    v = tuple(v)
validate v
return v

# Pydantic v2 configuration for immutability
model_config = ConfigDict(frozen=True)

@model_validator(mode='before')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you consider using "after" or model_post_init? This way pydantic will do the initial validations for you and you can also access the fields via self.

if not isinstance(self.quantization_configurations, tuple):
quantization_configurations = values.get('quantization_configurations', ())
num_configs = len(quantization_configurations)
base_config = values.get('base_config')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ofir's comment, shouldn't base_config exist? If it's None that's fine, but it should be its value, not get's default

) # pragma: no cover

# Validate that there are at least two operator groups
if len(fusing.operator_groups) < 2:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Field(min_length=2)

# Validate the operator_groups
if not isinstance(self.operator_groups, tuple):
# Validate operator_groups is a tuple
if not isinstance(fusing.operator_groups, tuple):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isnt it validated automatically?

operator_set (Tuple[OperatorsSetBase]): Tuple of operator sets within the model.
fusing_patterns (Tuple[Fusing]): Tuple of fusing patterns for the model.
operator_set (Tuple[OperatorsSetBase, ...]): Tuple of operator sets within the model.
fusing_patterns (Tuple[Fusing, ...]): Tuple of fusing patterns for the model.
is_simd_padding (bool): Indicates if SIMD padding is applied.
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
"""
default_qco: QuantizationConfigOptions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove default values, unless there is a good reason

"""
# Validate `default_qco`
if not isinstance(self.default_qco, QuantizationConfigOptions):
default_qco = tp_model.default_qco
if not isinstance(default_qco, QuantizationConfigOptions):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isnt is validated automatically?

Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover
if len(self.default_qco.quantization_configurations) != 1:

if len(default_qco.quantization_configurations) != 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldnt it be checked at QuantizationConfigOptions level?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the default qco so it should have only one option. a regular qco can have multiple options

opsets_names = [
op.name.value if isinstance(op, OperatorSetNames) else op.name
for op in operator_set
] if operator_set else []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you dont need if else []. if will already be an empty list

@lior-dikstein lior-dikstein merged commit afed6e3 into main Dec 31, 2024
42 checks passed
@lior-dikstein lior-dikstein deleted the tpc_refactor_phase4 branch December 31, 2024 15:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants