-
Notifications
You must be signed in to change notification settings - Fork 2
/
success.py
executable file
·47 lines (43 loc) · 2.11 KB
/
success.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python2
import argparse
import matplotlib.pyplot as plt
import numpy
import peeking.algorithm
import peeking.concurrent
def parse_args(argv=None):
    """Parse the command line for the success-rate comparison.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``A`` and ``B`` are the two
        distribution parameters handed to the ``peeking`` algorithms.

    Exits (via ``parser.error``) when a required option is missing or when
    ``--sample-size`` / ``--runs`` is not positive.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('A', type=float)
    parser.add_argument('B', type=float)
    parser.add_argument('--output')
    parser.add_argument('--peeking-frequency', type=int, required=True)
    parser.add_argument('--p-value', required=True, type=float)
    parser.add_argument('--min-sample-size', type=int)
    # The success rate is normalised by the sample size, so it is mandatory.
    parser.add_argument('--sample-size', type=int, required=True)
    parser.add_argument('--runs', required=True, type=int)
    args = parser.parse_args(argv)
    # Non-positive values would divide by zero or average an empty result set.
    if args.sample_size <= 0:
        parser.error('--sample-size must be positive')
    if args.runs <= 0:
        parser.error('--runs must be positive')
    return args


def scaled_runs(runs, divisor=4):
    """Return the reduced run count used for low-variance algorithms.

    Floor division keeps the result an ``int`` on both Python 2 and 3, and the
    result is clamped to at least one run so the mean is never computed over
    an empty list.
    """
    return max(1, runs // divisor)


def success_rate(algorithm, runs, sample_size):
    """Return the mean per-sample success rate of ``algorithm``.

    ``algorithm.success(sample_size)`` is executed ``runs`` times through
    ``peeking.concurrent.run``. The mean of the results is divided by
    ``sample_size``. This presumably means each run returns a success count
    over ``sample_size`` draws — TODO confirm against ``peeking.algorithm``.
    """
    with peeking.concurrent.run(algorithm.success, runs, ((sample_size,),)) as successes:
        return numpy.mean(list(successes)) / float(sample_size)


def plot_results(results, distributions, sample_size, output=None):
    """Draw a horizontal bar chart of ``(name, success_rate)`` pairs.

    Args:
        results: Sequence of ``(label, rate)`` tuples, plotted top to bottom.
        distributions: The ``(A, B)`` pair; it bounds the x axis.
        sample_size: Shown in the title.
        output: File path to save the figure to; if falsy, show it instead.
    """
    positions = range(len(results))
    plt.title('A = {:.2f}, B = {:.2f}, {} samples'.format(
        distributions[0], distributions[1], sample_size))
    plt.barh(positions, [rate for _, rate in results], align='center')
    low, high = min(distributions), max(distributions)
    # Identical limits (A == B) give matplotlib a degenerate range; autoscale then.
    if low < high:
        plt.xlim(low, high)
    plt.xlabel('Cumulative success rate')
    plt.yticks(positions, [name for name, _ in results])
    # bottom > top inverts the y axis so the first result is drawn at the top.
    plt.ylim(len(results) - 0.3, -0.7)
    plt.tight_layout()
    if output:
        plt.savefig(output)
    else:
        plt.show()


def main(argv=None):
    """Compare fixed-size tests, a peeking test and Thompson sampling, then plot."""
    args = parse_args(argv)
    distributions = (args.A, args.B)
    algorithms = [
        ('{} samples'.format(size),
         peeking.algorithm.FixedFrequencyTest(distributions, args.p_value, size))
        for size in (500, 2000, 4000)
    ]
    algorithms.append(
        ('peeking', peeking.algorithm.FrequencyTest(
            distributions, args.p_value, args.peeking_frequency, args.min_sample_size)))
    algorithms.append(
        ('thompson', peeking.algorithm.ThompsonSampling(distributions, (1, 1))))

    results = []
    for name, algorithm in algorithms:
        runs = args.runs
        # Thompson sampling results are relatively stable, so fewer runs suffice.
        if isinstance(algorithm, peeking.algorithm.ThompsonSampling):
            runs = scaled_runs(runs)
        results.append((name, success_rate(algorithm, runs, args.sample_size)))

    plot_results(results, distributions, args.sample_size, args.output)


if __name__ == '__main__':
    main()