-
Notifications
You must be signed in to change notification settings - Fork 5
/
solve.py
133 lines (109 loc) · 3.97 KB
/
solve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import config.config as config
import argparse
import os
from search.BWAS import batchedWeightedAStarSearch
from environment.getEnvironment import getEnvironment
from networks.getNetwork import getNetwork
import torch
def meanOrNone(values):
    """Return the arithmetic mean of ``values``, or ``None`` when it is empty.

    Args:
        values: Iterable of numbers.

    Returns:
        ``sum(values) / len(values)``, or ``None`` if there is nothing to
        average, so callers can report "N/A" instead of dividing by zero.
    """
    values = list(values)
    if not values:
        return None
    return sum(values) / len(values)


def formatStat(prefix, valueFormat, value):
    """Build one summary line, substituting ``N/A`` for a missing value.

    Args:
        prefix: Text printed before the value.
        valueFormat: ``%``-style format applied to ``value`` (e.g. ``"%.2f"``).
        value: Number to render, or ``None`` when no data was collected.

    Returns:
        ``prefix`` followed by the formatted value, or by ``"N/A"`` when
        ``value`` is ``None``.
    """
    if value is None:
        return prefix + "N/A"
    return prefix + valueFormat % (value,)


def main():
    """Solve a batch of scrambled puzzles and print per-solve and summary stats.

    Parses the command line, builds the environment and heuristic, runs
    ``batchedWeightedAStarSearch`` on ``--numSolve`` generated scrambles and
    prints the results of each search followed by aggregate statistics.

    Raises:
        ValueError: If the ``net`` heuristic is selected and ``--network`` is
            missing or not a file, or if the requested heuristic is not
            available for the configured puzzle.
    """
    parser = argparse.ArgumentParser()
    # The saved network is only read by the "net" heuristic, so it is optional
    # on the command line and validated below when it is actually needed.
    parser.add_argument(
        "-n",
        "--network",
        default=None,
        help="Path of Saved Network (needed for the net heuristic)",
        type=str,
    )
    parser.add_argument(
        "-c", "--config", required=True, help="Path of Config File", type=str
    )
    parser.add_argument(
        "-s", "--scrambleDepth", default=1000, help="Depth of Scramble", type=int
    )
    parser.add_argument(
        "-hf", "--heuristicFunction", default="net", help="net or manhattan", type=str
    )
    parser.add_argument(
        "-ns", "--numSolve", default=100, help="Number to Solve", type=int
    )
    args = parser.parse_args()

    conf = config.Config(args.config)
    env = getEnvironment(conf.puzzle)(conf.puzzleSize)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.heuristicFunction == "net":
        loadPath = args.network
        if not loadPath or not os.path.isfile(loadPath):
            raise ValueError("No Network Saved in this Path")
        net = getNetwork(conf.puzzle, conf.networkType)(conf.puzzleSize)
        net.to(device)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        # NOTE(review): the file is assumed to hold a bare state dict; an older
        # revision indexed ["net_state_dict"] - confirm against the trainer.
        net.load_state_dict(torch.load(loadPath, map_location=device))
        net.eval()
        heuristicFn = net
    elif args.heuristicFunction == "manhattan" and conf.puzzle == "puzzleN":
        heuristicFn = env.manhattanDistance
    else:
        raise ValueError("Invalid Heuristic Function")

    movesList = []
    numNodesGeneratedList = []
    searchItrList = []
    isSolvedList = []
    timeList = []

    numToSolve = args.numSolve

    for solveIdx in range(1, numToSolve + 1):
        scramble = env.generateScramble(args.scrambleDepth)

        (
            moves,
            numNodesGenerated,
            searchItr,
            isSolved,
            solveTime,
        ) = batchedWeightedAStarSearch(
            scramble,
            conf.depthWeight,
            conf.numParallel,
            env,
            heuristicFn,
            device,
            conf.maxSearchItr,
        )

        # Only searches that produced a move sequence count towards the
        # average move count.
        if moves:
            movesList.append(len(moves))

        numNodesGeneratedList.append(numNodesGenerated)
        searchItrList.append(searchItr)
        isSolvedList.append(isSolved)
        timeList.append(solveTime)

        if isSolved:
            print("Solved!")
            print(scramble)
            print("Moves are %s" % "".join(moves))
            print("Solve Length is %i" % len(moves))
            print("%i Nodes were generated" % numNodesGenerated)
            print("There were %i search iterations" % searchItr)
            print("Time of Solve is %.3f seconds" % solveTime)
        else:
            print(scramble)
            print("Max Search Iterations Reached")
            print("%i Nodes were generated" % numNodesGenerated)
            print("Search time was %.3f seconds" % solveTime)

        # Flushed so progress is visible when output is piped or redirected.
        print("%d out of %d" % (solveIdx, numToSolve), flush=True)

    # Restrict the timing / iteration figures to the successful searches once,
    # instead of re-filtering for every statistic.
    solvedTimes = [t for t, solved in zip(timeList, isSolvedList) if solved]
    solvedItrs = [n for n, solved in zip(searchItrList, isSolvedList) if solved]

    # Every aggregate goes through meanOrNone / formatStat so that a run with
    # no solves (or --numSolve 0) still prints a report instead of crashing.
    print(formatStat("Average Move Count ", "%.2f", meanOrNone(movesList)))
    print(
        formatStat(
            "Average Nodes Generated: ", "%.2f", meanOrNone(numNodesGeneratedList)
        )
    )
    print("Number Solved: %d" % isSolvedList.count(True))
    print(formatStat("Average Time: ", "%.2f seconds", meanOrNone(timeList)))
    print(
        formatStat(
            "Average Time of Successful Solves is ",
            "%.2f seconds",
            meanOrNone(solvedTimes),
        )
    )
    print(
        formatStat(
            "Average Search Iterations of Successful Solves is ",
            "%.2f",
            meanOrNone(solvedItrs),
        )
    )
    print(
        formatStat(
            "Max Search Iterations of Successful Solve is ",
            "%d",
            max(solvedItrs, default=None),
        )
    )


if __name__ == "__main__":
    main()