-
Notifications
You must be signed in to change notification settings - Fork 0
/
dpll.py
159 lines (145 loc) · 6.68 KB
/
dpll.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Uses CNF input as described here:
# http://www.satcompetition.org/2011/format-benchmarks2011.html
import copy
import math
def _get_assignment(symbol, assignments):
    """Evaluate a signed literal under the current partial assignment.

    ``symbol`` is a DIMACS-style literal: ``k`` stands for variable ``k`` and
    ``-k`` for its negation. ``assignments[k - 1]`` holds the value of
    variable ``k`` (True, False, or None while unassigned).

    Returns True if the literal is satisfied, False if it is falsified, and
    None if its variable has no value yet.
    """
    current = assignments[abs(symbol) - 1]
    if symbol > 0:
        satisfied_by = True
    elif symbol < 0:
        satisfied_by = False
    else:
        # Literal 0 has no polarity, so it can never be evaluated.
        return None
    if current == satisfied_by:
        return True
    if current == (not satisfied_by):
        return False
    return None
def _set_assignment(value, symbol, assignments):
    """Make literal ``symbol`` evaluate to ``value`` by updating its variable.

    For a positive literal ``k`` variable ``k`` is set to ``value``; for a
    negative literal ``-k`` it is set to ``not value``. ``assignments`` is
    modified in place and nothing is returned.
    """
    slot = abs(symbol) - 1
    assignments[slot] = (not value) if symbol < 0 else value
def _reduce(assignments, clauses):
    """Simplify ``clauses`` under the partial assignment ``assignments``.

    Each clause is a list of signed literals, or the marker ``[True]`` for a
    clause already known to be satisfied. In the returned list:

    * a clause containing a satisfied literal becomes ``[True]``;
    * falsified literals are dropped, so a clause whose literals are all
      false becomes ``[]`` (a conflict);
    * unassigned literals are kept, in their original order;
    * ``[True]`` markers and empty clauses are carried over unchanged.

    Neither ``clauses`` nor the clauses inside it are modified; every
    returned clause is a fresh list.
    """
    reduced_clauses = []
    for clause in clauses:
        # An empty clause has nothing to reduce. A [True] marker is already
        # satisfied (its first element is a bool, never an int literal).
        if not clause or isinstance(clause[0], bool):
            reduced_clauses.append(list(clause))
            continue
        remaining = []
        for symbol in clause:
            value = _get_assignment(symbol, assignments)
            if value is True:
                # One true literal satisfies the whole clause.
                remaining = [True]
                break
            if value is None:
                remaining.append(symbol)
            # value is False: this literal can never satisfy the clause, so
            # it is dropped.
        reduced_clauses.append(remaining)
    return reduced_clauses
def _find_pure_symbols(symbols, assignments, clauses):
    """Assign every unassigned pure symbol and return the list of them.

    A symbol is pure when, across the clauses not yet satisfied, it occurs
    only positively or only negated. Such a symbol is set in place in
    ``assignments`` to the value that makes all its occurrences true: True if
    it only appears positively, False if it only appears negated. Symbols
    that occur in no unsatisfied clause are left unassigned.

    Returns the pure symbols in the order in which they appear in
    ``symbols``.
    """
    # Gather every literal still present in an unsatisfied clause in a single
    # pass, instead of rescanning all clauses once per symbol.
    occurring = set()
    for clause in clauses:
        # Skip empty clauses and the [True] marker of satisfied clauses.
        if clause and not isinstance(clause[0], bool):
            occurring.update(clause)

    pure_symbols = []
    for symbol in symbols:
        # Symbols that already have a value need no decision.
        if _get_assignment(symbol, assignments) is not None:
            continue
        seen_positive = symbol in occurring
        seen_negative = -symbol in occurring
        # Pure means exactly one polarity occurs.
        if seen_positive != seen_negative:
            pure_symbols.append(symbol)
            # Make the literal that actually occurs true.
            _set_assignment(True, -symbol if seen_negative else symbol, assignments)
    return pure_symbols
def _find_unit_clauses(symbols, assignments, clauses):
    """Make the literal of every unit clause true and return those literals.

    A unit clause is an unsatisfied clause with exactly one literal left.
    That literal must be true in any satisfying assignment, so it is
    recorded in ``assignments`` in place. ``symbols`` is not used by this
    function.

    If two unit clauses disagree (``[k]`` and ``[-k]``), the later one
    determines the value in ``assignments``. Reducing the clauses against the
    result then turns the other clause into ``[]``, exposing the conflict.

    Returns the unit literals in clause order.
    """
    unit_clause_symbols = []
    for clause in clauses:
        # Test the length first so an empty clause ([]) is skipped instead of
        # raising IndexError on clause[0]. The isinstance check skips the
        # [True] marker of satisfied clauses.
        if len(clause) == 1 and not isinstance(clause[0], bool):
            literal = clause[0]
            _set_assignment(True, literal, assignments)
            unit_clause_symbols.append(literal)
    return unit_clause_symbols
def dpll(symbols, assignments, clauses, recursion_depth, max_list):
    """Decide whether ``clauses`` can be satisfied by extending ``assignments``.

    Args:
        symbols: the variable numbers that may be assigned.
        assignments: current partial assignment; ``assignments[k - 1]`` is the
            value of variable ``k`` (True, False or None). It is not modified.
        clauses: list of clauses of signed literals. ``[True]`` marks a
            satisfied clause and ``[]`` a falsified one.
        recursion_depth: depth of this call, used only for statistics.
        max_list: two-element list ``[max clauses satisfied, max recursion
            depth]``, updated in place across the whole search.

    Every call prints one ``"<depth>, <satisfied clauses>"`` line to stdout.

    Returns True if a satisfying extension of ``assignments`` exists,
    otherwise False.
    """
    # Statistics: clauses currently satisfied (reduced to the [True] marker).
    clauses_matched = sum(
        1 for clause in clauses if clause and isinstance(clause[0], bool)
    )
    # Keep track of the maximum clauses satisfied
    # and maximum recursion depth reached
    if clauses_matched > max_list[0]:
        max_list[0] = clauses_matched
    if recursion_depth > max_list[1]:
        max_list[1] = recursion_depth
    print(str(recursion_depth) + ", " + str(clauses_matched))

    # Every clause satisfied: the current partial assignment is a model.
    if all(clause == [True] for clause in clauses):
        return True
    # An empty clause has no literal left that could become true.
    if any(clause == [] for clause in clauses):
        return False

    # Private copy so the caller's list is never modified. Entries are only
    # True/False/None, so a shallow copy is equivalent to copy.deepcopy and
    # far cheaper.
    updated_assignments = list(assignments)

    # Propagate pure literals first, then unit clauses; each step recurses on
    # the formula simplified by the forced assignments.
    if _find_pure_symbols(symbols, updated_assignments, clauses):
        return dpll(symbols, updated_assignments,
                    _reduce(updated_assignments, clauses),
                    recursion_depth + 1, max_list)
    if _find_unit_clauses(symbols, updated_assignments, clauses):
        return dpll(symbols, updated_assignments,
                    _reduce(updated_assignments, clauses),
                    recursion_depth + 1, max_list)

    # Branch on the first unassigned variable in `symbols`.
    guess_symbol = next(
        (symbol for symbol in symbols
         if _get_assignment(symbol, updated_assignments) is None),
        None,
    )
    if guess_symbol is None:
        # All variables are assigned, yet some clause is neither [True] nor
        # []. That only happens if `clauses` was not reduced against
        # `assignments`. Reduce once and decide, instead of letting
        # StopIteration escape from next().
        return all(clause == [True]
                   for clause in _reduce(updated_assignments, clauses))

    # Try True first, then False; each branch gets its own assignment list.
    for value in (True, False):
        guess = list(updated_assignments)
        guess[guess_symbol - 1] = value
        if dpll(symbols, guess, _reduce(guess, clauses),
                recursion_depth + 1, max_list):
            return True
    return False
# Provide a filename as the first command-line argument to log the recursion
# depth and clauses-satisfied values of every DPLL call to that file.
# If no filename is provided, all logging is printed to stdout.
if __name__ == "__main__":
    import contextlib
    import sys

    def _read_content_line():
        # Return the next non-blank input line, stripped. Blank lines would
        # otherwise crash the comment check (IndexError on an empty string)
        # or be parsed as empty clauses.
        line = input().strip()
        while not line:
            line = input().strip()
        return line

    # Skip the comment preamble ("c ..." lines) up to the "p cnf" line.
    header = _read_content_line()
    while header.startswith("c"):
        header = _read_content_line()
    _, _, nbvar, nbclauses = header.split()
    nbvar = int(nbvar)
    nbclauses = int(nbclauses)

    symbols = list(range(1, nbvar + 1))
    assignments = [None] * nbvar
    # One clause per line; the trailing 0 terminator is dropped.
    clauses = [
        [int(token) for token in _read_content_line().split()[:-1]]
        for _ in range(nbclauses)
    ]
    # [max clauses satisfied, max recursion depth], updated in place by dpll.
    max_list = [0, 0]

    # Credit for the stdout-swap idea: Jacob Stopak, Stack Abuse,
    # https://stackabuse.com/writing-to-a-file-with-pythons-print-function/
    # ExitStack restores sys.stdout and closes the log file even if dpll
    # raises (for example RecursionError, since every propagation step
    # recurses).
    with contextlib.ExitStack() as stack:
        if len(sys.argv) > 1:
            log_file = stack.enter_context(open(sys.argv[1], "w"))
            stack.enter_context(contextlib.redirect_stdout(log_file))
        satisfiable = dpll(symbols, assignments, clauses, 0, max_list)

    status = "SATISFIABLE" if satisfiable else "UNSATISFIABLE"
    print(", ".join(str(field) for field in
                    (len(symbols), len(clauses), status,
                     max_list[0], max_list[1])))