Skip to content

Commit

Permalink
format + prep for merge
Browse files Browse the repository at this point in the history
(cherry picked from commit c444edd)
  • Loading branch information
plutonium-239 committed Aug 22, 2024
1 parent 7cde54b commit 51537c6
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions experiments/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import itertools
import math
from functools import partial
from typing import Any, Dict, List, Optional, Tuple
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torchvision.models as tvm
Expand Down Expand Up @@ -111,7 +111,18 @@ def get_transformers_config(model_name: str) -> AutoConfig:
return AutoConfig.from_pretrained(props.hf_name, **props.extra_kwargs)


def get_arch_models(arch: str):
def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]:
"""Get the dict of all defined functions for an architecture
Args:
arch (str): The architecture
Returns:
Tuple[Dict[str, Callable], Any]: Dict of all defined functions
Raises:
ValueError: Invalid architecture
"""
if arch == "conv":
return conv_model_fns, conv_input_shape
if arch == "transformer":
Expand Down

0 comments on commit 51537c6

Please sign in to comment.