From 451818af00717ccce6e16ba639fc9117b80e0684 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 21:23:33 +0000 Subject: [PATCH] [pre-commit.ci] Add auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cyclops/data/features/medical_image.py | 13 ++++++++----- cyclops/evaluate/utils.py | 8 +++++--- cyclops/models/wrappers/pt_model.py | 21 ++++++++++++--------- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/cyclops/data/features/medical_image.py b/cyclops/data/features/medical_image.py index 7b7d9cae9..3dcfd9583 100644 --- a/cyclops/data/features/medical_image.py +++ b/cyclops/data/features/medical_image.py @@ -209,11 +209,14 @@ def decode_example( use_auth_token = token_per_repo_id.get(repo_id) except ValueError: use_auth_token = None - with xopen( - path, - "rb", - use_auth_token=use_auth_token, - ) as file_obj, BytesIO(file_obj.read()) as buffer: + with ( + xopen( + path, + "rb", + use_auth_token=use_auth_token, + ) as file_obj, + BytesIO(file_obj.read()) as buffer, + ): image, metadata = self._read_file_from_bytes(buffer) metadata["filename_or_obj"] = path diff --git a/cyclops/evaluate/utils.py b/cyclops/evaluate/utils.py index 835de6263..1bbe99d1e 100644 --- a/cyclops/evaluate/utils.py +++ b/cyclops/evaluate/utils.py @@ -159,9 +159,11 @@ def get_columns_as_array( if isinstance(columns, str): columns = [columns] - with dataset.formatted_as("arrow", columns=columns, output_all_columns=True) if ( - isinstance(dataset, Dataset) and dataset.format != "arrow" - ) else nullcontext(): + with ( + dataset.formatted_as("arrow", columns=columns, output_all_columns=True) + if (isinstance(dataset, Dataset) and dataset.format != "arrow") + else nullcontext() + ): out_arr = squeeze_all( xp.stack( [xp.asarray(dataset[col].to_pylist()) for col in columns], axis=-1 diff --git a/cyclops/models/wrappers/pt_model.py b/cyclops/models/wrappers/pt_model.py index 50b3a4302..37499dec2 100644 --- a/cyclops/models/wrappers/pt_model.py +++ b/cyclops/models/wrappers/pt_model.py @@ -968,14 +968,17 @@ def fit( splits_mapping["validation"] = val_split format_kwargs = {} if transforms is None else {"transform": transforms} - with X[train_split].formatted_as( - "custom" if transforms is not None else "torch", - columns=feature_columns + target_columns, - **format_kwargs, - ), X[val_split].formatted_as( - "custom" if transforms is not None else "torch", - columns=feature_columns + target_columns, - **format_kwargs, + with ( + X[train_split].formatted_as( + "custom" if transforms is not None else "torch", + columns=feature_columns + target_columns, + **format_kwargs, + ), + X[val_split].formatted_as( + "custom" if transforms is not None else "torch", + columns=feature_columns + target_columns, + **format_kwargs, + ), ): self.partial_fit( X, @@ -1309,7 +1312,7 @@ def save_model(self, filepath: str, overwrite: bool = True, **kwargs): if include_lr_scheduler: state_dict["lr_scheduler"] = self.lr_scheduler_.state_dict() # type: ignore[attr-defined] - epoch = kwargs.get("epoch", None) + epoch = kwargs.get("epoch") if epoch is not None: filename, extension = os.path.basename(filepath).split(".") filepath = join(