diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..86613f8 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,14 @@ +name: Ruff +on: [push, pull_request] +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: ruff + env: + RUFF_OUTPUT_FORMAT: github + run: | + pip install ruff + ruff format --check + ruff check diff --git a/memsave_torch/__init__.py b/memsave_torch/__init__.py index f941a84..267cc7e 100644 --- a/memsave_torch/__init__.py +++ b/memsave_torch/__init__.py @@ -1,3 +1,4 @@ """memsave_torch package""" + import memsave_torch.nn as nn # noqa: F401 import memsave_torch.util as util # noqa: F401 diff --git a/memsave_torch/get_best_results.py b/memsave_torch/get_best_results.py index c5c88b6..184997e 100644 --- a/memsave_torch/get_best_results.py +++ b/memsave_torch/get_best_results.py @@ -1,5 +1,7 @@ """Simple script that goes over the raw results and finds the best results.""" +import argparse +import os from glob import glob from itertools import product @@ -7,34 +9,58 @@ from memsave_torch.util.collect_results import case_mapping -for device, arch in product(["cuda", "cpu"], ["linear", "conv"]): - # usage stats - df = None - idx_col = ["model", "case"] - for fname in glob(f"results/usage_stats-{arch}-{device}-*.csv"): - with open(fname) as f: - f.readline() - temp_df = pd.read_csv(f, index_col=idx_col) - df = temp_df if df is None else pd.concat([df, temp_df]) - if df is not None: - df = df.rename(index=case_mapping, level=1) - df["Memory Usage (GB)"] = df["Memory Usage (MB)"] / 1024 - df = df.drop(columns=["Memory Usage (MB)"]) - best_results = df.groupby(idx_col).min() - # scale - maxes = best_results.groupby(["model"]).max() - best_results[["Scaled T", "Scaled M"]] = best_results / maxes - best_results.to_csv(f"results/best_results-{arch}-{device}-usage_stats.csv") - - # savings - df = None - idx_col = ["model", "input_vjps"] - for fname in glob(f"results/savings-{arch}-{device}*.csv"): - with open(fname) as f: - f.readline() - temp_df = pd.read_csv(f, index_col=idx_col) - df = temp_df if df is None else pd.concat([df, temp_df]) - - if df is not None: - best_results = df.groupby(idx_col).max() - best_results.to_csv(f"results/best_results-{arch}-{device}-savings.csv") + +def main(base_dir: str): + """Gets the best results from all previous runs + + Args: + base_dir (str): The base results dir + """ + for device, arch in product(["cuda", "cpu"], ["linear", "conv"]): + # usage stats + df = None + idx_col = ["model", "case"] + for fname in glob(os.path.join(base_dir, f"usage_stats-{arch}-{device}-*.csv")): + with open(fname) as f: + f.readline() + temp_df = pd.read_csv(f, index_col=idx_col) + df = temp_df if df is None else pd.concat([df, temp_df]) + if df is not None: + df = df.rename(index=case_mapping, level=1) + df["Memory Usage (GB)"] = df["Memory Usage (MB)"] / 1024 + df = df.drop(columns=["Memory Usage (MB)"]) + best_results = df.groupby(idx_col).min() + # scale + maxes = best_results.groupby(["model"]).max() + best_results[["Scaled T", "Scaled M"]] = best_results / maxes + best_results.to_csv( + os.path.join(base_dir, f"best_results-{arch}-{device}-usage_stats.csv") + ) + + # savings + df = None + idx_col = ["model", "input_vjps"] + for fname in glob(os.path.join(base_dir, f"savings-{arch}-{device}*.csv")): + with open(fname) as f: + f.readline() + temp_df = pd.read_csv(f, index_col=idx_col) + df = temp_df if df is None else pd.concat([df, temp_df]) + + if df is not None: + best_results = df.groupby(idx_col).max() + best_results.to_csv( + os.path.join(base_dir, f"best_results-{arch}-{device}-savings.csv") + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--results_dir", type=str, default="results/", help="The base results dir" + ) + args = parser.parse_args() + + base_dir = args.results_dir + os.path.exists(base_dir) + + main(base_dir) diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py index 8c026e0..1733210 100644 --- a/memsave_torch/nn/__init__.py +++ b/memsave_torch/nn/__init__.py @@ -107,7 +107,7 @@ def recursive_setattr(obj: nn.Module, attr: str, value: nn.Module, clone_params: clone_params (bool): Whether to make a copy of the parameters or reuse them """ attr_split = attr.split(".", 1) - if len(attr) == 1: + if len(attr_split) == 1: setattr(obj, attr_split[0], value) if clone_params: value.load_state_dict(value.state_dict()) # makes a copy diff --git a/memsave_torch/paper_demo.py b/memsave_torch/paper_demo.py index 3d82a23..9ecd3ce 100644 --- a/memsave_torch/paper_demo.py +++ b/memsave_torch/paper_demo.py @@ -89,7 +89,7 @@ ], ] -pbar = tqdm(total=len(models) * len(estimators) * 3, leave=False) +pbar = tqdm(total=len(models) * len(estimators) * len(cases), leave=False) collector = collect_results.ResultsCollector( batch_size, input_channels, @@ -111,7 +111,7 @@ pbar.set_description(f"{model} {estimate} case {case}") case_str = f"--case {' '.join(case)}" if case is not None else "" cmd = ( - f"python estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} " + f"python memsave_torch/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} " + f"--device {device} -B {batch_size} -C_in {input_channels} -HW {input_HW} -n_class {num_classes}" ) proc = subprocess.run(shlex.split(cmd), capture_output=True) diff --git a/memsave_torch/util/collect_results.py b/memsave_torch/util/collect_results.py index 52a2f3e..9314d8f 100644 --- a/memsave_torch/util/collect_results.py +++ b/memsave_torch/util/collect_results.py @@ -62,8 +62,6 @@ def hyperparam_str(args: SimpleNamespace) -> str: class ResultsCollector: """This class collects results by reading from the results/ directory""" - # TODO: Maybe change to results/temp - def __init__( self, batch_size: int, @@ -74,6 +72,8 @@ def __init__( architecture: str, vjp_improvements: List[float], cases: List[Union[None, List[str]]], + results_dir: str, + print: Callable = tqdm.write, ) -> None: """Initialize the collector before all runs. @@ -86,6 +86,8 @@ def __init__( architecture (str): conv or linear vjp_improvements (List[float]): vjp_improvements cases (List[Union[None, List[str]]]): list of cases + results_dir (str): The base results dir + print (Callable, optional): Which function to use for printing (i.e. `print()` causes problems in a tqdm context) """ # TODO: architecture is pointless since there is no arch-specific code anymore self.batch_size = batch_size @@ -97,8 +99,8 @@ def __init__( self.vjp_improvements = vjp_improvements self.cases = cases # assert len(cases) == 3, f"len(cases) > 3:\n{cases}" Not anymore - self.base_location = f"results/{architecture}-" - os.makedirs("results/", exist_ok=True) + self.base_location = os.path.join(results_dir, "raw") + os.makedirs(self.base_location, exist_ok=True) self.savings = pd.DataFrame( columns=["model", "input_vjps", strings["time"][4], strings["memory"][4]] ) @@ -107,6 +109,7 @@ def __init__( ) self.savings.set_index(["model", "input_vjps"], inplace=True) self.usage_stats.set_index(["model", "case"], inplace=True) + self.print = print def collect_from_file(self, estimate: str, model: str): """To be called after all cases of a model have finished. @@ -118,25 +121,25 @@ def collect_from_file(self, estimate: str, model: str): Raises: e: Description """ - with open(f"results/{estimate}-{self.architecture}.txt") as f: + with open(f"{self.base_location}/{estimate}-{self.architecture}.txt") as f: lines = f.readlines() try: assert ( len(lines) == len(self.cases) - ), f"More than {len(self.cases)} lines found in results/{estimate}-{self.architecture}.txt:\n{lines}" + ), f"More than {len(self.cases)} lines found in {self.base_location}/{estimate}-{self.architecture}.txt:\n{lines}" outputs = [float(line.strip()) for line in lines] for case, out in zip(self.cases, outputs): self.usage_stats.loc[ (model, make_case_str(case)), strings[estimate][5] ] = out - self._display_run(outputs, estimate, model) + self._display_run(outputs, estimate, model, self.print) except AssertionError as e: raise e except ValueError as e: print( - f'File results/{estimate}-{self.architecture}.txt has unallowed text. Contents: \n{"".join(lines)}' + f'File {self.base_location}/{estimate}-{self.architecture}.txt has unallowed text. Contents: \n{"".join(lines)}' ) raise e finally: @@ -148,7 +151,7 @@ def clear_file(self, estimate: str): Args: estimate (str): time or memory """ - with open(f"results/{estimate}-{self.architecture}.txt", "w") as f: + with open(f"{self.base_location}/{estimate}-{self.architecture}.txt", "w") as f: f.write("") def _display_run( @@ -156,7 +159,7 @@ def _display_run( outputs: List[float], estimate: str, model: str, - print: Callable = tqdm.write, + print: Callable, ): """Function to display the data collected over all cases for a model. @@ -164,7 +167,7 @@ def _display_run( outputs (List[float]): The collected outputs estimate (str): time or memory model (str): The name of the model - print (Callable, optional): Which function to use for printing (i.e. `print()` causes problems in a tqdm context) + print (Callable): Which function to use for printing (i.e. `print()` causes problems in a tqdm context) """ # print(f"{model} input ({input_channels},{input_HW},{input_HW}) {device}") # print('='*78) @@ -176,39 +179,40 @@ def _display_run( f"{strings[estimate][1]} ({case_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}" ) - q_conv_weight = outputs[1] - outputs[2] - ratio = q_conv_weight / outputs[0] - if estimate == "time": - print( - f"{self.architecture.capitalize()} weight VJPs use {100 * ratio:.1f}% of time" - ) - else: - print( - f"Information for {self.architecture} weight VJPs uses {100 * ratio:.1f}% of memory" - ) - # self.models.loc[model, ''] - - tot_improvements = [ - 1 - (1 - improvement) * ratio for improvement in self.vjp_improvements - ] - for vjp, tot in zip(self.vjp_improvements, tot_improvements): - print( - f"Weight VJP {strings[estimate][2]} of {vjp:.2f}x ({1 / vjp:.1f}x {strings[estimate][3]})" - + f" would lead to total {strings[estimate][2]} of {tot:.2f}x ({1 / tot:.1f}x {strings[estimate][3]})" - ) - self.savings.loc[(model, vjp), strings[estimate][4]] = f"{1 / tot:.1f}x" + # CODE ONLY APPLIES WITH OLD RUNDEMO.PY + # q_conv_weight = outputs[1] - outputs[2] + # ratio = q_conv_weight / outputs[0] + # if estimate == "time": + # print( + # f"{self.architecture.capitalize()} weight VJPs use {100 * ratio:.1f}% of time" + # ) + # else: + # print( + # f"Information for {self.architecture} weight VJPs uses {100 * ratio:.1f}% of memory" + # ) + # # self.models.loc[model, ''] + + # tot_improvements = [ + # 1 - (1 - improvement) * ratio for improvement in self.vjp_improvements + # ] + # for vjp, tot in zip(self.vjp_improvements, tot_improvements): + # print( + # f"Weight VJP {strings[estimate][2]} of {vjp:.2f}x ({1 / vjp:.1f}x {strings[estimate][3]})" + # + f" would lead to total {strings[estimate][2]} of {tot:.2f}x ({1 / tot:.1f}x {strings[estimate][3]})" + # ) + # self.savings.loc[(model, vjp), strings[estimate][4]] = f"{1 / tot:.1f}x" print("") def finish(self): """To be called after ALL cases on all models have been run, saves dataframes to csv files.""" time = datetime.now().strftime("%d.%m.%y %H.%M") s = f"input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}" - savings_path = f"results/savings-{self.architecture}-{self.device}-{time}.csv" - with open(savings_path, "w") as f: - f.write(s + "\n") - self.savings.to_csv(savings_path, mode="a") + # savings_path = f"{self.base_location}/../savings-{self.architecture}-{self.device}-{time}.csv" + # with open(savings_path, "w") as f: + # f.write(s + "\n") + # self.savings.to_csv(savings_path, mode="a") - usage_path = f"results/usage_stats-{self.architecture}-{self.device}-{time}.csv" + usage_path = f"{self.base_location}/../usage_stats-{self.architecture}-{self.device}-{time}.csv" with open(usage_path, "w") as f: f.write(s + "\n") self.usage_stats.to_csv(usage_path, mode="a") diff --git a/memsave_torch/util/estimate.py b/memsave_torch/util/estimate.py index 41746f5..79a6238 100644 --- a/memsave_torch/util/estimate.py +++ b/memsave_torch/util/estimate.py @@ -11,6 +11,7 @@ """ import argparse +import os from typing import Callable, Dict, List, Optional from torch import Tensor, device, manual_seed, rand, randint @@ -79,7 +80,12 @@ def skip_case_check(args: argparse.Namespace) -> bool: if c not in args.case and args.model in models.models_without_norm: invalid = True if invalid: - with open(f"results/{args.estimate}-conv.txt", "a") as f: + with open( + os.path.join( + args.results_dir, f"raw/{args.estimate}-{args.architecture}.txt" + ), + "a", + ) as f: f.write("-1\n") return invalid @@ -90,8 +96,10 @@ def estimate_speedup( x: Tensor, y: Tensor, targets: Optional[List[Dict[str, Tensor]]], + architecture: str, dev: device, case: List[str], + results_dir: str, return_val: bool = False, ): """Save an estimate of total training speed-up caused by a weight VJP speed-up. @@ -102,8 +110,10 @@ def estimate_speedup( x: Input to the model. y: Labels of the input. targets: Targets in case of detection model + architecture: linear or conv dev: Device to run the computation on. case: str indicating which grads to take + results_dir: See args.results_dir return_val: Whether to return the value or save it (Default: Save) Returns: @@ -131,7 +141,7 @@ def estimate_speedup( if return_val: return result - with open("results/time-conv.txt", "a") as f: + with open(os.path.join(results_dir, f"raw/time-{architecture}.txt"), "a") as f: # f.write(f"{args.model},{loss_fn.__name__},{dev},{case},{result},{x.shape},{y.shape}\n") f.write(f"{result}\n") @@ -142,8 +152,10 @@ def estimate_mem_savings( x: Tensor, y: Tensor, targets: Optional[List[Dict[str, Tensor]]], + architecture: str, dev: device, case: List[str], + results_dir: str, return_val: bool = False, ): """Print an estimate of the memory savings caused by weight VJP memory savings. @@ -154,8 +166,10 @@ def estimate_mem_savings( x: Input to the model. y: Labels of the input. targets: Targets in case of detection model + architecture: linear or conv dev: Device to run the computation on. case: str indicating which grads to take + results_dir: See args.results_dir return_val: Whether to return the value or save it (Default: Save) Returns: @@ -181,7 +195,7 @@ def estimate_mem_savings( if return_val: return result - with open("results/memory-conv.txt", "a") as f: + with open(os.path.join(results_dir, f"raw/memory-{architecture}.txt"), "a") as f: # f.write(f"{args.model},{loss_fn.__name__},{dev},{case},{result},{x.shape},{y.shape}\n") f.write(f"{result}\n") @@ -232,9 +246,13 @@ def estimate_mem_savings( default=False, help="Print result to stdout instead of writing to file", ) + parser.add_argument( + "--results_dir", type=str, default="results/", help="the base results dir" + ) args = parser.parse_args() + assert os.path.exists(args.results_dir) if not skip_case_check(args): dev = device(args.device) @@ -297,11 +315,29 @@ def estimate_mem_savings( if args.estimate == "time": res = estimate_speedup( - model_fn, loss_fn, x, y, targets, dev, args.case, args.print + model_fn, + loss_fn, + x, + y, + targets, + args.architecture, + dev, + args.case, + args.results_dir, + args.print, ) elif args.estimate == "memory": res = estimate_mem_savings( - model_fn, loss_fn, x, y, targets, dev, args.case, args.print + model_fn, + loss_fn, + x, + y, + targets, + args.architecture, + dev, + args.case, + args.results_dir, + args.print, ) if args.print: print(res)