Skip to content

Commit

Permalink
fix inputs for transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed Apr 16, 2024
1 parent 1cf9d68 commit 421aa6b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions experiments/util/estimate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Estimate possible speed-up when randomizing the weight VJP of convolutions.
We take a CNN and answer the following question:
We take a CNN and answer the following questions:
Q1) What is the relative run time consumed by the weight VJP for convolutions?
Expand Down Expand Up @@ -48,7 +48,7 @@ def parse_case(case: Optional[List[str]]) -> Dict[str, bool]:
case (Optional[List[str]]): List of all cases
Returns:
Dict[str, bool]: dictionary with keys as allowed_cases present in the input (which dont start with 'no_')
Dict[str, bool]: dictionary with keys as allowed_cases present in the input (which dont start with `no_`)
"""
kw = {}
if case is None:
Expand Down Expand Up @@ -307,6 +307,12 @@ def estimate_mem_savings(
)
loss_fn_orig = loss_fn
loss_fn = lambda: models.SegmentationLossWrapper(loss_fn_orig) # noqa: E731
elif args.model in models.transformers_models:
config = models.get_transformers_config(args.model)
model_fn_orig = model_fn
model_fn = lambda: models.TransformersModelWrapper(model_fn_orig)
x = randint(config.vocab_size, (batch_size, args.input_hw), device=dev)
y = randint(config.vocab_size, (batch_size, args.input_hw), device=dev)

# warm-up
# with redirect_stdout(open(devnull, "w")):
Expand Down

0 comments on commit 421aa6b

Please sign in to comment.