-
Notifications
You must be signed in to change notification settings - Fork 5
/
util.py
119 lines (96 loc) · 3.62 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from pathlib import Path
import json
import glob
import re
import random
from collections import OrderedDict
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pandas as pd
import numpy as np
import torch
from dataset import MaskBaseDataset
def read_json(fname):
    """Load a JSON file, keeping object key order.

    Args:
        fname (str or pathlib.Path): path to the JSON file.

    Returns:
        The decoded JSON value; every JSON object is returned as an
        ``OrderedDict`` (via ``object_hook``).
    """
    # JSON text is UTF-8; pin the encoding instead of relying on the
    # platform's locale default (e.g. cp949 on Korean Windows).
    with Path(fname).open('rt', encoding='utf-8') as handle:
        return json.load(handle, object_hook=OrderedDict)
def update_argument(args, configs):
    """Overwrite attributes of ``args`` with the values from ``configs``.

    Keys are applied in the iteration order of ``configs``. Every key must
    already be present in ``args`` (tested with ``in``, as supported by
    ``argparse.Namespace``). Keys processed before an unknown one have
    already been applied when the error is raised.

    Args:
        args: namespace-like object supporting ``in`` and ``setattr``.
        configs: mapping from argument name to its new value.

    Returns:
        The same ``args`` object, modified in place.

    Raises:
        ValueError: if ``configs`` names an argument that ``args`` lacks.
    """
    for name in configs:
        if name not in args:
            raise ValueError(f"no argument {name}")
        setattr(args, name, configs[name])
    return args
def ages_subdiv_to_origin(sdage):
    """Collapse a sequence of (presumably subdivided) age labels into 3 classes.

    Each value ``v`` is mapped as follows: ``v < 2`` becomes 0,
    ``2 <= v < 5`` becomes 1, and anything else becomes 2.

    Args:
        sdage: iterable of comparable values (presumably fine-grained
            age-class indices; confirm with the caller).

    Returns:
        list[int]: the coarse class for each input, in input order.
    """
    return [0 if age < 2 else (1 if age < 5 else 2) for age in sdage]
def draw_confusion_matrix(true, pred, dir, num_classes):
    """Save a row-normalized confusion-matrix heatmap to ``{dir}/confusion_matrix.png``.

    Each row (true class) is divided by its sample count, so a cell shows
    the fraction of that true class predicted as each column's class.
    Rows for classes with no samples are all zeros.

    Args:
        true: ground-truth class indices.
        pred: predicted class indices, same length as ``true``.
        dir: existing directory the PNG is written into.
        num_classes: number of classes; the plotted matrix is always
            ``num_classes x num_classes``.
    """
    labels = list(range(num_classes))
    # Pass ``labels`` so classes missing from both ``true`` and ``pred`` still
    # get a row/column; otherwise the matrix is smaller than ``num_classes``
    # and the DataFrame construction below fails on the shape mismatch.
    cm = confusion_matrix(true, pred, labels=labels)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Empty rows stay 0 instead of becoming NaN (and emitting a warning).
    normalized = np.divide(
        cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0
    )
    df = pd.DataFrame(normalized, index=labels, columns=labels)

    fig = plt.figure(figsize=(16, 16))
    plt.suptitle('Confusion Matrix')
    sns.heatmap(df, annot=True, cmap=sns.color_palette("Blues"))
    plt.xlabel("Predicted Label")
    plt.ylabel("True label")
    # tight_layout only affects axes that already exist, so call it after drawing.
    plt.tight_layout()
    plt.savefig(f"{dir}/confusion_matrix.png")
    # Close only this figure; figures the caller still holds stay open.
    plt.close(fig)
def seed_everything(seed):
    """Seed all random number generators used during training.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU
    and CUDA). It also puts cuDNN into deterministic mode with autotuning
    disabled.

    Args:
        seed (int): the seed value.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    # cuDNN benchmark mode may choose different algorithms between runs;
    # disable it and request deterministic kernels for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_lr(optimizer):
    """Return the ``'lr'`` value of the optimizer's first parameter group.

    Args:
        optimizer: object exposing a ``param_groups`` iterable of dicts
            (e.g. a ``torch.optim.Optimizer``).

    Returns:
        The first group's learning rate, or ``None`` if there are no groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)
def increment_path(path, exist_ok=False):
    """Return a path that does not clash with an existing run directory.

    If ``path`` (e.g. ``runs/exp``) does not exist, it is returned as-is.
    Otherwise a numeric suffix is appended: ``runs/exp2``, ``runs/exp3``,
    and so on, one past the largest suffix already used by siblings named
    exactly ``<name><digits>``.

    Args:
        path (str or pathlib.Path): f"{model_dir}/{args.name}".
        exist_ok (bool): if True, return ``path`` unchanged even when it
            already exists; if False, increment as described above.

    Returns:
        str: the chosen path.
    """
    path = Path(path)
    if exist_ok or not path.exists():
        return str(path)
    # Compare against each candidate's basename with an escaped, fully
    # anchored pattern. This avoids matching digits in parent directories,
    # treating regex metacharacters in the name as patterns, and (as the
    # old stem-based regex did) missing names like "model.v12" for
    # "model.v1".
    suffix_re = re.compile(re.escape(path.name) + r"(\d+)")
    used = []
    for candidate in glob.glob(glob.escape(str(path)) + "*"):
        match = suffix_re.fullmatch(Path(candidate).name)
        if match:
            used.append(int(match.group(1)))
    n = max(used) + 1 if used else 2
    return f"{path}{n}"
def grid_image(np_images, gts, preds, n=16, shuffle=False):
    """Plot ``n`` images from a batch in a square grid titled with decoded labels.

    Each subplot title lists, per task (mask, gender, age), the ground-truth
    and predicted labels as decoded by ``MaskBaseDataset.decode_multi_class``.

    Args:
        np_images: batch of images; ``np_images.shape[0]`` is the batch size
            and ``np_images[i]`` is passed to ``plt.imshow``.
        gts: ground-truth labels; ``gts[i].item()`` must yield the encoded
            multi-class label (presumably a tensor/ndarray; confirm with caller).
        preds: predicted labels, same format as ``gts``.
        n: number of images to plot; must not exceed the batch size.
        shuffle: if True, plot ``n`` distinct randomly chosen images;
            otherwise plot the first ``n``.

    Returns:
        matplotlib.figure.Figure: the figure with the grid. It is left open
        for the caller to log, save, or close.
    """
    batch_size = np_images.shape[0]
    assert n <= batch_size, f"n ({n}) must not exceed the batch size ({batch_size})"
    # random.sample draws without replacement, so no image appears twice
    # (random.choices could pick the same index repeatedly).
    choices = random.sample(range(batch_size), k=n) if shuffle else list(range(n))
    # cautions: hardcoded; figsize and top may need adjusting for other image sizes.
    figure = plt.figure(figsize=(12, 18 + 2))
    plt.subplots_adjust(top=0.8)
    n_grid = int(np.ceil(n ** 0.5))
    tasks = ["mask", "gender", "age"]
    for idx, choice in enumerate(choices):
        gt = gts[choice].item()
        pred = preds[choice].item()
        image = np_images[choice]
        gt_decoded_labels = MaskBaseDataset.decode_multi_class(gt)
        pred_decoded_labels = MaskBaseDataset.decode_multi_class(pred)
        title = "\n".join(
            f"{task} - gt: {gt_label}, pred: {pred_label}"
            for gt_label, pred_label, task
            in zip(gt_decoded_labels, pred_decoded_labels, tasks)
        )
        plt.subplot(n_grid, n_grid, idx + 1, title=title)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(image, cmap=plt.cm.binary)
    return figure