-
Notifications
You must be signed in to change notification settings - Fork 15
/
train_neat.py
85 lines (64 loc) · 2.23 KB
/
train_neat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import multiprocessing as mp
import pickle
import os
import neat
from gomoku import *
# Number of independent games played per genome during evaluation.
runs_per_net = 5

# Board dimensions: width and height, in cells.
w = 10
h = 10

# Number of generations passed to the NEAT population run.
max_generation = 50
def eval_genome(genome, config):
    """Score one genome by letting its network play Gomoku for both sides.

    A feed-forward network is built from ``genome`` and asked for a move on
    every turn, whichever player is to move, over ``runs_per_net`` games on
    a ``w`` x ``h`` board. Each game is opened with one ``game.next()`` call
    without coordinates before the network takes over.

    Per-game scoring:
      * +2 when the cell picked by the network is empty and gets played;
      * -1 when the picked cell is occupied (``game.next()`` is then called
        without coordinates instead);
      * +50 when ``game.check_won()`` returns a positive value, which also
        ends the game.

    Args:
        genome: NEAT genome to evaluate.
        config: NEAT configuration used to instantiate the network.

    Returns:
        The lowest per-game score, so a genome is only rated as well as its
        worst game.
    """
    net = neat.nn.FeedForwardNetwork.create(genome, config)
    fitnesses = []

    for _run in range(runs_per_net):
        fitness = 0
        board = Board(w=w, h=h)
        game = Gomoku(board=board)
        game.next()  # opening move, placed without coordinates

        # One cell is already taken by the opening move.
        for _move in range(w * h - 1):
            player = game.current_player

            # Encode the board from the mover's point of view:
            #   1 = own stone, -1 = opponent stone, 0 = empty.
            # The encoding is computed in one pass from the untouched board;
            # relabelling in place in two steps would turn opponent stones
            # into 1 whenever a player id happens to be -1.
            cells = game.board.board
            inputs = np.where(cells == 0, 0, np.where(cells == player, 1, -1))
            inputs = inputs.astype(np.float32).flatten()

            # One network output per cell; the strongest output is the move.
            action = np.array(net.activate(inputs), np.float32).reshape(
                (board.h, board.w)
            )
            # NOTE(review): unravel_index on an (h, w) array yields
            # (row, col), so ``x`` holds the row and ``y`` the column here.
            # The ordering is kept as-is because swapping it changes which
            # output neuron maps to which cell; it can index out of bounds
            # if w != h -- confirm against Board/Gomoku before changing.
            x, y = np.unravel_index(np.argmax(action), action.shape)

            if game.board.board[y][x] == 0:
                game.next(x=x, y=y)
                fitness += 2
            else:
                # Occupied cell: penalise and let the game pick the move.
                game.next()
                fitness -= 1

            won_player = game.check_won()
            if won_player > 0:
                fitness += 50
                break

        fitnesses.append(fitness)

    return min(fitnesses)
def eval_genomes(genomes, config):
    """Evaluate every genome serially and store the result on the genome.

    Args:
        genomes: Iterable of ``(genome_id, genome)`` pairs.
        config: NEAT configuration forwarded to ``eval_genome``.
    """
    for _, candidate in genomes:
        score = eval_genome(candidate, config)
        candidate.fitness = score
# Entry-point guard: ParallelEvaluator runs eval_genome in worker processes.
# With the "spawn" start method each worker re-imports this module, so the
# training code must not run at import time or the workers would try to
# start training themselves.
if __name__ == '__main__':
    # Load the NEAT settings from the file named 'config' in the working
    # directory.
    config = neat.Config(
        neat.DefaultGenome,
        neat.DefaultReproduction,
        neat.DefaultSpeciesSet,
        neat.DefaultStagnation,
        'config',
    )

    pop = neat.Population(config)
    stats = neat.StatisticsReporter()
    pop.add_reporter(stats)
    pop.add_reporter(neat.StdOutReporter(True))

    # Evaluate genomes in parallel, one worker per CPU. The serial
    # alternative is ``pop.run(eval_genomes)``.
    pe = neat.ParallelEvaluator(mp.cpu_count(), eval_genome)
    winner = pop.run(pe.evaluate, n=max_generation)

    # Persist the best genome so it can be reloaded later.
    os.makedirs('result', exist_ok=True)
    with open('result/winner', 'wb') as f:
        pickle.dump(winner, f)

    print(winner)