diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d1b7b4a0c..364535c5e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -211,7 +211,7 @@ jobs: if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'flax') || contains(matrix.task.extras, 'all')) run: | . .venv/bin/activate - pip install flax==0.5.0 jax==0.3.13 jaxlib==0.3.10 tensorflow-cpu==2.9.1 optax==0.1.3 + pip install flax==0.6.1 jax==0.4.1 jaxlib==0.4.1 tensorflow-cpu==2.9.1 optax==0.1.3 - name: Install editable (no cache hit) if: steps.virtualenv-cache.outputs.cache-hit != 'true' diff --git a/CHANGELOG.md b/CHANGELOG.md index 0939be634..075775320 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed unnecessary code coverage dev requirements. - Fixed issue where new version of torch caused no LR schedulers to be registered. +- Updated pinned versions of jax, jaxlib, and flax. ## [v1.2.1](https://github.com/allenai/tango/releases/tag/v1.2.1) - 2023-04-06 diff --git a/pyproject.toml b/pyproject.toml index 5548dfd91..31a1807ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,8 +89,9 @@ fairscale = [ ] flax = [ "datasets>=1.12,<3.0", - "jax>=0.3.13", - "flax>=0.5.0", + "jax>=0.4.1,<=0.4.13", + "jaxlib>=0.4.1,<=0.4.13", + "flax>=0.6.1,<=0.7.0", "optax>=0.1.2", "tensorflow-cpu>=2.9.1" ] diff --git a/tango/common/file_lock.py b/tango/common/file_lock.py index b76c2858d..45874a26a 100644 --- a/tango/common/file_lock.py +++ b/tango/common/file_lock.py @@ -39,9 +39,9 @@ def acquire( # type: ignore[override] if err.errno not in (1, 13, 30): raise - if os.path.isfile(self._lock_file) and self._read_only_ok: + if os.path.isfile(self._lock_file) and self._read_only_ok: # type: ignore warnings.warn( - f"Lacking permissions required to obtain lock '{self._lock_file}'. " + f"Lacking permissions required to obtain lock '{self._lock_file}'. " # type: ignore "Race conditions are possible if other processes are writing to the same resource.", UserWarning, ) @@ -62,7 +62,7 @@ def acquire_with_updates(self, desc: Optional[str] = None) -> AcquireReturnProxy from .tqdm import Tqdm if desc is None: - desc = f"acquiring lock at {self._lock_file}" + desc = f"acquiring lock at {self._lock_file}" # type: ignore progress = Tqdm.tqdm(desc=desc, bar_format="{desc} [{elapsed}]") while True: diff --git a/tango/common/from_params.py b/tango/common/from_params.py index 9575c8874..fb7bf63be 100644 --- a/tango/common/from_params.py +++ b/tango/common/from_params.py @@ -514,7 +514,7 @@ def construct_arg( ) elif annotation == str: # Strings are special because we allow casting from Path to str. - if type(popped_params) == str or isinstance(popped_params, Path): + if isinstance(popped_params, str) or isinstance(popped_params, Path): return str(popped_params) # type: ignore else: raise TypeError( diff --git a/tango/integrations/torch/optim.py b/tango/integrations/torch/optim.py index a3147a4c3..d532be246 100644 --- a/tango/integrations/torch/optim.py +++ b/tango/integrations/torch/optim.py @@ -1,3 +1,5 @@ +from typing import Type + import torch from tango.common.registrable import Registrable @@ -73,11 +75,14 @@ class LRScheduler(torch.optim.lr_scheduler._LRScheduler, Registrable): ): Optimizer.register("torch::" + name)(cls) +# Note: This is a hack. Remove after we upgrade the torch version. +base_class: Type +try: + base_class = torch.optim.lr_scheduler.LRScheduler +except AttributeError: + base_class = torch.optim.lr_scheduler._LRScheduler + # Register all learning rate schedulers. for name, cls in torch.optim.lr_scheduler.__dict__.items(): - if ( - isinstance(cls, type) - and issubclass(cls, torch.optim.lr_scheduler.LRScheduler) - and not cls == torch.optim.lr_scheduler.LRScheduler - ): + if isinstance(cls, type) and issubclass(cls, base_class) and not cls == base_class: LRScheduler.register("torch::" + name)(cls) diff --git a/tango/integrations/transformers/__init__.py b/tango/integrations/transformers/__init__.py index 927a126a2..31f6a7049 100644 --- a/tango/integrations/transformers/__init__.py +++ b/tango/integrations/transformers/__init__.py @@ -44,76 +44,10 @@ from tango.integrations.torch import Model from tango.integrations.transformers import * + available_models = [] for name in sorted(Model.list_available()): if name.startswith("transformers::AutoModel"): - print(name) - - .. testoutput:: - - transformers::AutoModel::from_config - transformers::AutoModel::from_pretrained - transformers::AutoModelForAudioClassification::from_config - transformers::AutoModelForAudioClassification::from_pretrained - transformers::AutoModelForAudioFrameClassification::from_config - transformers::AutoModelForAudioFrameClassification::from_pretrained - transformers::AutoModelForAudioXVector::from_config - transformers::AutoModelForAudioXVector::from_pretrained - transformers::AutoModelForCTC::from_config - transformers::AutoModelForCTC::from_pretrained - transformers::AutoModelForCausalLM::from_config - transformers::AutoModelForCausalLM::from_pretrained - transformers::AutoModelForDepthEstimation::from_config - transformers::AutoModelForDepthEstimation::from_pretrained - transformers::AutoModelForDocumentQuestionAnswering::from_config - transformers::AutoModelForDocumentQuestionAnswering::from_pretrained - transformers::AutoModelForImageClassification::from_config - transformers::AutoModelForImageClassification::from_pretrained - transformers::AutoModelForImageSegmentation::from_config - transformers::AutoModelForImageSegmentation::from_pretrained - transformers::AutoModelForInstanceSegmentation::from_config - transformers::AutoModelForInstanceSegmentation::from_pretrained - transformers::AutoModelForMaskGeneration::from_config - transformers::AutoModelForMaskGeneration::from_pretrained - transformers::AutoModelForMaskedImageModeling::from_config - transformers::AutoModelForMaskedImageModeling::from_pretrained - transformers::AutoModelForMaskedLM::from_config - transformers::AutoModelForMaskedLM::from_pretrained - transformers::AutoModelForMultipleChoice::from_config - transformers::AutoModelForMultipleChoice::from_pretrained - transformers::AutoModelForNextSentencePrediction::from_config - transformers::AutoModelForNextSentencePrediction::from_pretrained - transformers::AutoModelForObjectDetection::from_config - transformers::AutoModelForObjectDetection::from_pretrained - transformers::AutoModelForPreTraining::from_config - transformers::AutoModelForPreTraining::from_pretrained - transformers::AutoModelForQuestionAnswering::from_config - transformers::AutoModelForQuestionAnswering::from_pretrained - transformers::AutoModelForSemanticSegmentation::from_config - transformers::AutoModelForSemanticSegmentation::from_pretrained - transformers::AutoModelForSeq2SeqLM::from_config - transformers::AutoModelForSeq2SeqLM::from_pretrained - transformers::AutoModelForSequenceClassification::from_config - transformers::AutoModelForSequenceClassification::from_pretrained - transformers::AutoModelForSpeechSeq2Seq::from_config - transformers::AutoModelForSpeechSeq2Seq::from_pretrained - transformers::AutoModelForTableQuestionAnswering::from_config - transformers::AutoModelForTableQuestionAnswering::from_pretrained - transformers::AutoModelForTokenClassification::from_config - transformers::AutoModelForTokenClassification::from_pretrained - transformers::AutoModelForUniversalSegmentation::from_config - transformers::AutoModelForUniversalSegmentation::from_pretrained - transformers::AutoModelForVideoClassification::from_config - transformers::AutoModelForVideoClassification::from_pretrained - transformers::AutoModelForVision2Seq::from_config - transformers::AutoModelForVision2Seq::from_pretrained - transformers::AutoModelForVisualQuestionAnswering::from_config - transformers::AutoModelForVisualQuestionAnswering::from_pretrained - transformers::AutoModelForZeroShotImageClassification::from_config - transformers::AutoModelForZeroShotImageClassification::from_pretrained - transformers::AutoModelForZeroShotObjectDetection::from_config - transformers::AutoModelForZeroShotObjectDetection::from_pretrained - transformers::AutoModelWithLMHead::from_config - transformers::AutoModelWithLMHead::from_pretrained + available_models.append(name) - :class:`~tango.integrations.torch.Optimizer`: All optimizers from transformers are registered according to their class names (e.g. "transformers::AdaFactor"). diff --git a/tests/common/from_params_test.py b/tests/common/from_params_test.py index 4c0bc4af8..40c983d2f 100644 --- a/tests/common/from_params_test.py +++ b/tests/common/from_params_test.py @@ -470,7 +470,7 @@ def __init__(self, a: str, x: int = 42, **kwargs): assert instance.x == 42 assert instance.a == -1 assert len(instance.rest) == 1 # type: ignore - assert type(instance.rest["raw_a"]) == str # type: ignore + assert isinstance(instance.rest["raw_a"], str) # type: ignore assert instance.rest["raw_a"] == "123" # type: ignore def test_kwargs_are_passed_to_deeper_superclasses(self):