-
Notifications
You must be signed in to change notification settings - Fork 358
/
base_workspace.py
145 lines (127 loc) · 4.84 KB
/
base_workspace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Optional
import os
import pathlib
import hydra
import copy
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
import dill
import torch
import threading
class BaseWorkspace:
    """Base class for hydra-configured training/eval workspaces.

    Provides checkpointing (selective, via ``state_dict``/``dill``) and
    whole-object snapshotting. Subclasses override ``run`` and may extend
    ``include_keys``/``exclude_keys`` to control what gets checkpointed.
    """

    # Attribute names to force-pickle into checkpoints even though they
    # have no state_dict.
    include_keys = tuple()
    # Attribute names with a state_dict that should NOT be checkpointed.
    exclude_keys = tuple()

    def __init__(self, cfg: OmegaConf, output_dir: Optional[str] = None):
        """Store the config; resolve output_dir lazily if not given."""
        self.cfg = cfg
        self._output_dir = output_dir
        self._saving_thread = None

    @property
    def output_dir(self):
        """Explicit output dir if set, else hydra's runtime output dir."""
        output_dir = self._output_dir
        if output_dir is None:
            output_dir = HydraConfig.get().runtime.output_dir
        return output_dir

    def run(self):
        """
        Create any resource shouldn't be serialized as local variables
        """
        pass

    def save_checkpoint(self, path=None, tag='latest',
            exclude_keys=None,
            include_keys=None,
            use_thread=True):
        """Save a checkpoint of all state_dict-bearing attributes plus pickles.

        Args:
            path: target file; defaults to ``<output_dir>/checkpoints/<tag>.ckpt``.
            tag: checkpoint name used when ``path`` is None.
            exclude_keys: attribute names to skip; defaults to ``self.exclude_keys``.
            include_keys: attribute names to force-pickle; defaults to
                ``self.include_keys`` plus ``_output_dir``.
            use_thread: when True, state dicts are copied to CPU and the write
                happens on a background thread so training can continue.

        Returns:
            Absolute path of the written checkpoint as a string.
        """
        if path is None:
            path = pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt')
        else:
            path = pathlib.Path(path)
        if exclude_keys is None:
            exclude_keys = tuple(self.exclude_keys)
        if include_keys is None:
            include_keys = tuple(self.include_keys) + ('_output_dir',)

        # parents=True: create the whole directory chain; parents=False would
        # raise if e.g. the output dir itself does not exist yet.
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            'cfg': self.cfg,
            'state_dicts': dict(),
            'pickles': dict()
        }

        for key, value in self.__dict__.items():
            if hasattr(value, 'state_dict') and hasattr(value, 'load_state_dict'):
                # modules, optimizers and samplers etc
                if key not in exclude_keys:
                    if use_thread:
                        # Snapshot to CPU now so the background write does not
                        # race with training updating the live parameters.
                        payload['state_dicts'][key] = _copy_to_cpu(value.state_dict())
                    else:
                        payload['state_dicts'][key] = value.state_dict()
            elif key in include_keys:
                payload['pickles'][key] = dill.dumps(value)

        def _write():
            # Context manager guarantees the file handle is closed even if
            # torch.save raises (the old code leaked the handle).
            with path.open('wb') as f:
                torch.save(payload, f, pickle_module=dill)

        if use_thread:
            # Join any previous in-flight save first: two threads writing the
            # same '<tag>.ckpt' concurrently would corrupt the file.
            if self._saving_thread is not None:
                self._saving_thread.join()
            self._saving_thread = threading.Thread(target=_write)
            self._saving_thread.start()
        else:
            _write()
        return str(path.absolute())

    def get_checkpoint_path(self, tag='latest'):
        """Return the default checkpoint path for ``tag`` (may not exist)."""
        return pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt')

    def load_payload(self, payload, exclude_keys=None, include_keys=None, **kwargs):
        """Restore state dicts and pickled attributes from a payload dict.

        Args:
            payload: dict as produced by ``save_checkpoint``.
            exclude_keys: state-dict keys to skip restoring.
            include_keys: pickle keys to restore; defaults to all pickles.
            **kwargs: forwarded to each ``load_state_dict`` call.
        """
        if exclude_keys is None:
            exclude_keys = tuple()
        if include_keys is None:
            include_keys = payload['pickles'].keys()

        for key, value in payload['state_dicts'].items():
            if key not in exclude_keys:
                self.__dict__[key].load_state_dict(value, **kwargs)
        for key in include_keys:
            if key in payload['pickles']:
                self.__dict__[key] = dill.loads(payload['pickles'][key])

    def load_checkpoint(self, path=None, tag='latest',
            exclude_keys=None,
            include_keys=None,
            **kwargs):
        """Load a checkpoint into this workspace and return the raw payload.

        NOTE(review): checkpoints are deserialized with dill/pickle — only
        load files from trusted sources.
        """
        if path is None:
            path = self.get_checkpoint_path(tag=tag)
        else:
            path = pathlib.Path(path)
        # Close the file deterministically instead of leaking the handle.
        with path.open('rb') as f:
            payload = torch.load(f, pickle_module=dill, **kwargs)
        self.load_payload(payload,
            exclude_keys=exclude_keys,
            include_keys=include_keys)
        return payload

    @classmethod
    def create_from_checkpoint(cls, path,
            exclude_keys=None,
            include_keys=None,
            **kwargs):
        """Construct a workspace from the checkpoint's saved cfg, then load it."""
        with open(path, 'rb') as f:
            payload = torch.load(f, pickle_module=dill)
        instance = cls(payload['cfg'])
        instance.load_payload(
            payload=payload,
            exclude_keys=exclude_keys,
            include_keys=include_keys,
            **kwargs)
        return instance

    def save_snapshot(self, tag='latest'):
        """
        Quick loading and saving for research, saves full state of the workspace.
        However, loading a snapshot assumes the code stays exactly the same.
        Use save_checkpoint for long-term storage.
        """
        path = pathlib.Path(self.output_dir).joinpath('snapshots', f'{tag}.pkl')
        # parents=True so a missing 'snapshots' chain is created, not an error.
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open('wb') as f:
            torch.save(self, f, pickle_module=dill)
        return str(path.absolute())

    @classmethod
    def create_from_snapshot(cls, path):
        """Unpickle an entire workspace saved by ``save_snapshot``.

        NOTE(review): dill deserialization executes arbitrary code — only
        load snapshots from trusted sources.
        """
        with open(path, 'rb') as f:
            return torch.load(f, pickle_module=dill)
def _copy_to_cpu(x):
if isinstance(x, torch.Tensor):
return x.detach().to('cpu')
elif isinstance(x, dict):
result = dict()
for k, v in x.items():
result[k] = _copy_to_cpu(v)
return result
elif isinstance(x, list):
return [_copy_to_cpu(k) for k in x]
else:
return copy.deepcopy(x)