-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
56 lines (48 loc) Β· 1.56 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import transformers
import yaml
from scipy.stats import pearsonr
def compute_pearson_correlation(
pred: transformers.trainer_utils.EvalPrediction,
) -> dict:
"""
νΌμ΄μ¨ μκ΄ κ³μλ₯Ό κ³μ°ν΄μ£Όλ ν¨μ
Args:
pred (torch.Tensor): λͺ¨λΈμ μμΈ‘κ°κ³Ό λ μ΄λΈμ ν¬ν¨ν λ°μ΄ν°
Returns:
perason_correlation (dict): μ
λ ₯κ°μ ν΅ν΄ κ³μ°ν νΌμ΄μ¨ μκ΄ κ³μ
"""
preds = pred.predictions.flatten()
labels = pred.label_ids.flatten()
perason_correlation = {"pearson_correlation": pearsonr(preds, labels)[0]}
return perason_correlation
def seed_everything(seed: int) -> None:
"""
λͺ¨λΈμμ μ¬μ©νλ λͺ¨λ λλ€ μλλ₯Ό κ³ μ ν΄μ£Όλ ν¨μ
Args:
seed (int): μλ κ³ μ μ μ¬μ©ν μ μκ°
Returns:
None
"""
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
cudnn.deterministic = True
cudnn.benchmark = True
def load_yaml(path: str) -> dict:
"""
λͺ¨λΈ νλ ¨, μμΈ‘μ μ¬μ©ν yaml νμΌμ λΆλ¬μ€λ ν¨μ
Args:
path (str): λΆλ¬μ¬ yaml νμΌμ κ²½λ‘
Returns:
loaded_yaml (dict): μ§μ ν κ²½λ‘μμ λΆλ¬μ¨ yaml νμΌ λ΄μ©
"""
with open(path, "r") as f:
loaded_yaml = yaml.load(f, Loader=yaml.FullLoader)
return loaded_yaml