-
Notifications
You must be signed in to change notification settings - Fork 0
/
equiv_query_random.py
105 lines (94 loc) · 4.13 KB
/
equiv_query_random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import logging
from typing import *
import WFA
import ContinuousStateMachine
import util
from time import time
import equiv_query
import os.path
import json
# Module logger: the "equiv_query_random" child of the project-wide "rnn2wfa"
# logger, so its records propagate to whatever handlers are attached there.
mylogger = logging.getLogger(
    "rnn2wfa"
).getChild("equiv_query_random")
class EquivalenceQueryParameters:
    """Settings for the sampling-based equivalence query answerer.

    Every attribute is copied verbatim from the constructor argument of the
    same name; no validation is performed.  The words excluded from sampling
    are not stored here: EquivalenceQueryAnswerer reads them from
    ``<dirname>/test.txt``.
    """
    comment: str  # free-form note describing this configuration (not read by the answerer in this module)
    max_length: int  # forwarded to util.make_words (presumably the maximum sampled word length)
    eps: float  # a sampled word counts as correct when |rnn(word) - wfa(word)| < eps
    time_limit: Optional[int]  # per-query budget in seconds (compared against time.time() deltas); None disables it
    stabilized_allowance: float  # accuracy spread (max - min) below which the recent history counts as stabilized
    stabilized_period: int  # number of most recent query accuracies inspected by the stabilization check
    shutdown_accuracy: float  # an accuracy strictly above this value declares the models equivalent
    train_size: int  # forwarded to util.make_words (presumably the number of words to sample)
    random_seed: int  # forwarded to util.make_words (presumably to make the sample reproducible)

    def __init__(self,
                 comment: str,
                 eps: float,
                 max_length: int,
                 train_size: int,
                 stabilized_allowance: float,
                 stabilized_period: int,
                 shutdown_accuracy: float,
                 random_seed: int,
                 time_limit: Optional[int] = None):
        """Store the given settings unchanged; see the class attributes for their meaning."""
        self.comment = comment
        self.eps = eps
        self.max_length = max_length
        self.train_size = train_size
        self.stabilized_allowance = stabilized_allowance
        self.stabilized_period = stabilized_period
        self.shutdown_accuracy = shutdown_accuracy
        self.time_limit = time_limit
        self.random_seed = random_seed
class EquivalenceQueryAnswerer(equiv_query.EquivalenceQueryAnswererBase):
    """Answer equivalence queries by comparing the RNN and a WFA on a fixed word sample.

    A sample of words is drawn once at construction time with util.make_words,
    excluding the words listed in ``<dirname>/test.txt``, and is saved to
    ``<dirname>/eqqt_train_set.json``.  Each query scores the candidate WFA on
    that sample: a word counts as correct when
    ``|rnn.get_value(word) - wfa.get_value(word)| < params.eps``.  The WFA is
    declared equivalent when the accuracy exceeds ``params.shutdown_accuracy``
    or when the accuracy has stopped moving (spread below
    ``params.stabilized_allowance`` over the last ``params.stabilized_period``
    queries).  Otherwise the sampled word with the largest difference is
    returned as a counterexample.
    """
    start: float  # time.time() at which the current query started (set by _reset_timeout)
    rnn: ContinuousStateMachine.ContinuousStateMachine
    params: EquivalenceQueryParameters
    train_set: List[str]  # the fixed sample every query is scored on
    acc_history: List[float]  # accuracy of every scored query so far, oldest first
    dirname: str  # directory holding test.txt and receiving eqqt_train_set.json

    def __init__(self, rnn: ContinuousStateMachine.ContinuousStateMachine, params: EquivalenceQueryParameters,
                 dirname: str):
        """Draw the word sample and persist it.

        :param rnn: the model being approximated; its ``alphabet`` is used to generate words.
        :param params: sampling and stopping settings.
        :param dirname: directory that must contain ``test.txt`` (one word per line,
            surrounding whitespace stripped; these words are excluded from the sample)
            and into which ``eqqt_train_set.json`` is written.
        """
        self.rnn = rnn
        self.params = params
        self.acc_history = []
        self.dirname = dirname
        with open(os.path.join(self.dirname, "test.txt"), "r") as f:
            exclude_list = [line.strip() for line in f]
        mylogger.info(f"exclude_list: {exclude_list}")
        self.train_set = util.make_words(self.rnn.alphabet, self.params.max_length, self.params.train_size,
                                         util.sample_length_from_all_lengths, exclude_list,
                                         self.params.random_seed)
        # Persist the sample so an experiment can be inspected or reproduced later.
        with open(os.path.join(self.dirname, "eqqt_train_set.json"), "w") as f:
            json.dump(self.train_set, f)

    def _reset_timeout(self):
        """Start the clock for the internal per-query time limit."""
        self.start = time()

    def _assert_not_timeout(self):
        """Raise equiv_query.EquivalenceQueryTimedOut if the current query exceeded params.time_limit.

        Does nothing when ``params.time_limit`` is None.
        """
        if self.params.time_limit is not None:
            if time() - self.start > self.params.time_limit:
                raise equiv_query.EquivalenceQueryTimedOut()

    def answer_query(self, wfa: WFA.WFA, existing_ces: Iterable[str], assert_timeout: Callable[[], None]) -> Tuple[
        equiv_query.ResultAnswerQuery.T, Any]:
        """Decide whether ``wfa`` is close enough to the RNN on the sampled words.

        :param wfa: the candidate automaton to check.
        :param existing_ces: not used by this implementation.
        :param assert_timeout: caller-supplied callback, invoked every 100 words;
            expected to raise when the caller's own time budget is exhausted.
        :return: ``(ResultAnswerQuery.Equivalent(), None)`` when the stopping
            criteria are met (or the sample is empty), otherwise
            ``(ResultAnswerQuery.Counterexample(word), None)`` with the sampled
            word of largest absolute difference.
        :raises equiv_query.EquivalenceQueryTimedOut: when ``params.time_limit``
            is exceeded; anything raised by ``assert_timeout`` propagates too.
        """
        self._reset_timeout()
        mylogger.info("Starting answer_query")
        if not self.train_set:
            # Nothing to compare on: no counterexample can be produced, and the
            # accuracy below would divide by zero.
            mylogger.warning("The training set is empty; treating the WFA as equivalent.")
            return equiv_query.ResultAnswerQuery.Equivalent(), None
        word2diff: Dict[str, float] = {}
        correct = 0
        for i, word in enumerate(self.train_set):
            if i % 100 == 0:
                mylogger.info(f"train{i}")
                assert_timeout()
                # Elapsed time only grows, so checking the internal limit here raises
                # the same exception the final check would, without first scoring
                # the rest of the sample.
                self._assert_not_timeout()
            diff = abs(self.rnn.get_value(word) - wfa.get_value(word))
            word2diff[word] = diff
            if diff < self.params.eps:
                correct += 1
        self._assert_not_timeout()
        acc = correct / len(self.train_set)
        self.acc_history.append(acc)
        if acc > self.params.shutdown_accuracy:
            mylogger.info(f"They seems equivalent because of the sufficient accuracy: {acc}")
            return equiv_query.ResultAnswerQuery.Equivalent(), None
        if len(self.acc_history) >= self.params.stabilized_period:
            history_tail = self.acc_history[-self.params.stabilized_period:]
            if max(history_tail) - min(history_tail) < self.params.stabilized_allowance:
                mylogger.info(f"They seems Equivalent because of the stabilized accuracy: {history_tail}")
                return equiv_query.ResultAnswerQuery.Equivalent(), None
        # Not equivalent: return the sampled word with the biggest difference
        # (util.argmax_dict presumably returns the key with the largest value).
        argmax = util.argmax_dict(word2diff)
        mylogger.info(
            f"Accuracy is insufficient {acc}. The difference is {word2diff[argmax]} and the word is {argmax}.")
        return equiv_query.ResultAnswerQuery.Counterexample(argmax), None