Skip to content

Commit

Permalink
Merge pull request swiftlang#77174 from slavapestov/fix-undo-order
Browse files Browse the repository at this point in the history
Sema: Undo changes in chronological order in SolverTrail::undo()
  • Loading branch information
slavapestov authored Nov 20, 2024
2 parents e6b4e0f + a19c92a commit a7b1e78
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 82 deletions.
40 changes: 28 additions & 12 deletions include/swift/Sema/ConstraintGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,12 @@ class ConstraintGraphNode {
/// gets removed for a constraint graph.
void retractFromInference(Constraint *constraint);

/// Re-evaluate the given constraint. This happens when there are changes
/// in associated type variables e.g. bound/unbound to/from a fixed type,
/// equivalence class changes.
void reintroduceToInference(Constraint *constraint);

/// Similar to \c introduceToInference(Constraint *, ...) this method is going
/// to notify inference that this type variable has been bound to a concrete
/// type.
/// Perform graph updates that must be undone after we bind a fixed type
/// to a type variable.
void retractFromInference(Type fixedType);

/// Perform graph updates that must be undone before we bind a fixed type
/// to a type variable.
///
/// The reason why this can't simplify be a part of \c bindTypeVariable
/// is related to the fact that it's sometimes expensive to re-compute
Expand All @@ -161,12 +159,18 @@ class ConstraintGraphNode {
///
/// This is useful in situations when type variable gets bound and unbound,
/// or equivalence class changes.
void notifyReferencingVars() const;
void notifyReferencingVars(
llvm::function_ref<void(ConstraintGraphNode &,
Constraint *)> notification) const;

/// Notify all of the type variables referenced by this one about a change.
void notifyReferencedVars(
llvm::function_ref<void(ConstraintGraphNode &)> notification);
llvm::function_ref<void(ConstraintGraphNode &)> notification) const;

void updateFixedType(
Type fixedType,
llvm::function_ref<void (ConstraintGraphNode &,
Constraint *)> notification) const;
/// }

/// The constraint graph this node belongs to.
Expand Down Expand Up @@ -261,16 +265,28 @@ class ConstraintGraph {
/// Primitive form for SolverTrail::Change::undo().
void removeConstraint(TypeVariableType *typeVar, Constraint *constraint);

/// Prepare to merge the given node into some other node.
///
/// This records graph changes that must be undone after the merge has
/// been undone.
void mergeNodesPre(TypeVariableType *typeVar2);

/// Merge the two nodes for the two given type variables.
///
/// The type variables must actually have been merged already; this
/// operation merges the two nodes.
/// operation merges the two nodes. This also records graph changes
/// that must be undone before the merge can be undone.
void mergeNodes(TypeVariableType *typeVar1, TypeVariableType *typeVar2);

/// Bind the given type variable to the given fixed type.
void bindTypeVariable(TypeVariableType *typeVar, Type fixedType);

/// Introduce the type variable's fixed type to inference.
/// Perform graph updates that must be undone after we bind a fixed type
/// to a type variable.
void retractFromInference(TypeVariableType *typeVar, Type fixedType);

/// Perform graph updates that must be undone before we bind a fixed type
/// to a type variable.
void introduceToInference(TypeVariableType *typeVar, Type fixedType);

/// Describes which constraints \c gatherConstraints should gather.
Expand Down
7 changes: 1 addition & 6 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,7 @@ class TypeVariableType::Implementation {
/// \param trail The record of state changes.
void mergeEquivalenceClasses(TypeVariableType *other,
constraints::SolverTrail *trail) {
// Merge the equivalence classes corresponding to these two type
// variables. Always merge 'up' the constraint stack, because it is simpler.
if (getID() > other->getImpl().getID()) {
other->getImpl().mergeEquivalenceClasses(getTypeVariable(), trail);
return;
}
ASSERT(getID() < other->getImpl().getID());

auto otherRep = other->getImpl().getRepresentative(trail);
if (trail)
Expand Down
15 changes: 2 additions & 13 deletions lib/Sema/CSTrail.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,21 +730,10 @@ void SolverTrail::undo(unsigned toIndex) {
ASSERT(!UndoActive);
UndoActive = true;

// FIXME: Undo all changes in the correct order!
for (unsigned i = Changes.size(); i > toIndex; i--) {
auto change = Changes[i - 1];
if (change.Kind == ChangeKind::UpdatedTypeVariable) {
LLVM_DEBUG(llvm::dbgs() << "- "; change.dump(llvm::dbgs(), CS, 0));
change.undo(CS);
}
}

for (unsigned i = Changes.size(); i > toIndex; i--) {
auto change = Changes[i - 1];
if (change.Kind != ChangeKind::UpdatedTypeVariable) {
LLVM_DEBUG(llvm::dbgs() << "- "; change.dump(llvm::dbgs(), CS, 0));
change.undo(CS);
}
LLVM_DEBUG(llvm::dbgs() << "- "; change.dump(llvm::dbgs(), CS, 0));
change.undo(CS);
}

Changes.resize(toIndex);
Expand Down
120 changes: 76 additions & 44 deletions lib/Sema/ConstraintGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ ConstraintGraph::lookupNode(TypeVariableType *typeVar) {
// If this type variable is not the representative of its equivalence class,
// add it to its representative's set of equivalences.
auto typeVarRep = CS.getRepresentative(typeVar);
if (typeVar != typeVarRep)
mergeNodes(typeVar, typeVarRep);
if (typeVar != typeVarRep) {
mergeNodesPre(typeVar);
mergeNodes(typeVarRep, typeVar);
}
else if (auto fixed = CS.getFixedType(typeVarRep)) {
// Bind the type variable.
bindTypeVariable(typeVar, fixed);
Expand Down Expand Up @@ -177,7 +179,9 @@ void ConstraintGraphNode::removeConstraint(Constraint *constraint) {
Constraints.pop_back();
}

void ConstraintGraphNode::notifyReferencingVars() const {
void ConstraintGraphNode::notifyReferencingVars(
llvm::function_ref<void(ConstraintGraphNode &,
Constraint *)> notification) const {
SmallVector<TypeVariableType *, 4> stack;

stack.push_back(TypeVar);
Expand All @@ -199,7 +203,7 @@ void ConstraintGraphNode::notifyReferencingVars() const {
affectedVar->getImpl().getRepresentative(/*record=*/nullptr);

if (!repr->getImpl().getFixedType(/*record=*/nullptr))
CG[repr].reintroduceToInference(constraint);
notification(CG[repr], constraint);
}
}
};
Expand Down Expand Up @@ -236,7 +240,7 @@ void ConstraintGraphNode::notifyReferencingVars() const {
}

void ConstraintGraphNode::notifyReferencedVars(
llvm::function_ref<void(ConstraintGraphNode &)> notification) {
llvm::function_ref<void(ConstraintGraphNode &)> notification) const {
for (auto *fixedBinding : getReferencedVars()) {
notification(CG[fixedBinding]);
}
Expand All @@ -249,25 +253,6 @@ void ConstraintGraphNode::addToEquivalenceClass(
if (EquivalenceClass.empty())
EquivalenceClass.push_back(getTypeVariable());
EquivalenceClass.append(typeVars.begin(), typeVars.end());

{
for (auto *newMember : typeVars) {
auto &node = CG[newMember];

for (auto *constraint : node.getConstraints()) {
introduceToInference(constraint);

if (!isUsefulForReferencedVars(constraint))
continue;

notifyReferencedVars([&](ConstraintGraphNode &referencedVar) {
referencedVar.introduceToInference(constraint);
});
}

node.notifyReferencingVars();
}
}
}

void ConstraintGraphNode::truncateEquivalenceClass(unsigned prevSize) {
Expand Down Expand Up @@ -343,19 +328,17 @@ void ConstraintGraphNode::retractFromInference(Constraint *constraint) {
}
}

void ConstraintGraphNode::reintroduceToInference(Constraint *constraint) {
retractFromInference(constraint);
introduceToInference(constraint);
}

void ConstraintGraphNode::introduceToInference(Type fixedType) {
void ConstraintGraphNode::updateFixedType(
Type fixedType,
llvm::function_ref<void (ConstraintGraphNode &,
Constraint *)> notification) const {
// Notify all of the type variables that reference this one.
//
// Since this type variable has been replaced with a fixed type
// all of the concrete types that reference it are going to change,
// which means that all of the not-yet-attempted bindings should
// change as well.
notifyReferencingVars();
notifyReferencingVars(notification);

if (!fixedType->hasTypeVariable())
return;
Expand All @@ -371,11 +354,27 @@ void ConstraintGraphNode::introduceToInference(Type fixedType) {
// all of the constraints that reference bound type variable.
for (auto *constraint : getConstraints()) {
if (isUsefulForReferencedVars(constraint))
node.reintroduceToInference(constraint);
notification(node, constraint);
}
}
}

void ConstraintGraphNode::retractFromInference(Type fixedType) {
return updateFixedType(
fixedType,
[](ConstraintGraphNode &node, Constraint *constraint) {
node.retractFromInference(constraint);
});
}

void ConstraintGraphNode::introduceToInference(Type fixedType) {
return updateFixedType(
fixedType,
[](ConstraintGraphNode &node, Constraint *constraint) {
node.introduceToInference(constraint);
});
}

#pragma mark Graph mutation

void ConstraintGraph::removeNode(TypeVariableType *typeVar) {
Expand Down Expand Up @@ -486,31 +485,60 @@ void ConstraintGraph::removeConstraint(TypeVariableType *typeVar,
OrphanedConstraints.pop_back();
}

void ConstraintGraph::mergeNodesPre(TypeVariableType *typeVar2) {
// Merge equivalence class from the non-representative type variable.
auto &nonRepNode = (*this)[typeVar2];

for (auto *newMember : nonRepNode.getEquivalenceClassUnsafe()) {
auto &node = (*this)[newMember];

node.notifyReferencingVars(
[&](ConstraintGraphNode &node, Constraint *constraint) {
node.retractFromInference(constraint);
});
}
}

void ConstraintGraph::mergeNodes(TypeVariableType *typeVar1,
TypeVariableType *typeVar2) {
assert(CS.getRepresentative(typeVar1) == CS.getRepresentative(typeVar2) &&
"type representatives don't match");

// Retrieve the node for the representative that we're merging into.
auto typeVarRep = CS.getRepresentative(typeVar1);
auto &repNode = (*this)[typeVarRep];
ASSERT(CS.getRepresentative(typeVar1) == typeVar1);

// Retrieve the node for the non-representative.
assert((typeVar1 == typeVarRep || typeVar2 == typeVarRep) &&
"neither type variable is the new representative?");
auto typeVarNonRep = typeVar1 == typeVarRep? typeVar2 : typeVar1;
auto &repNode = (*this)[typeVar1];

// Record the change, if there are active scopes.
if (CS.isRecordingChanges()) {
CS.recordChange(
SolverTrail::Change::ExtendedEquivalenceClass(
typeVarRep,
typeVar1,
repNode.getEquivalenceClass().size()));
}

// Merge equivalence class from the non-representative type variable.
auto &nonRepNode = (*this)[typeVarNonRep];
repNode.addToEquivalenceClass(nonRepNode.getEquivalenceClassUnsafe());
auto &nonRepNode = (*this)[typeVar2];

auto typeVars = nonRepNode.getEquivalenceClassUnsafe();
repNode.addToEquivalenceClass(typeVars);

for (auto *newMember : typeVars) {
auto &node = (*this)[newMember];

for (auto *constraint : node.getConstraints()) {
repNode.introduceToInference(constraint);

if (!isUsefulForReferencedVars(constraint))
continue;

repNode.notifyReferencedVars([&](ConstraintGraphNode &referencedVar) {
referencedVar.introduceToInference(constraint);
});
}

node.notifyReferencingVars(
[&](ConstraintGraphNode &node, Constraint *constraint) {
node.introduceToInference(constraint);
});
}
}

void ConstraintGraph::bindTypeVariable(TypeVariableType *typeVar, Type fixed) {
Expand All @@ -537,6 +565,10 @@ void ConstraintGraph::bindTypeVariable(TypeVariableType *typeVar, Type fixed) {
}
}

void ConstraintGraph::retractFromInference(TypeVariableType *typeVar, Type fixed) {
(*this)[typeVar].retractFromInference(fixed);
}

void ConstraintGraph::introduceToInference(TypeVariableType *typeVar, Type fixed) {
(*this)[typeVar].introduceToInference(fixed);
}
Expand Down
9 changes: 7 additions & 2 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,13 @@ void ConstraintSystem::mergeEquivalenceClasses(TypeVariableType *typeVar1,
assert(typeVar2 == getRepresentative(typeVar2) &&
"typeVar2 is not the representative");
assert(typeVar1 != typeVar2 && "cannot merge type with itself");
typeVar1->getImpl().mergeEquivalenceClasses(typeVar2, getTrail());

// Merge nodes in the constraint graph.
// Always merge 'up' the constraint stack, because it is simpler.
if (typeVar1->getImpl().getID() > typeVar2->getImpl().getID())
std::swap(typeVar1, typeVar2);

CG.mergeNodesPre(typeVar2);
typeVar1->getImpl().mergeEquivalenceClasses(typeVar2, getTrail());
CG.mergeNodes(typeVar1, typeVar2);

if (updateWorkList) {
Expand Down Expand Up @@ -205,6 +209,7 @@ void ConstraintSystem::assignFixedType(TypeVariableType *typeVar, Type type,
assert(!type->hasError() &&
"Should not be assigning a type involving ErrorType!");

CG.retractFromInference(typeVar, type);
typeVar->getImpl().assignFixedType(type, getTrail());

if (!updateState)
Expand Down
2 changes: 1 addition & 1 deletion test/Sema/issue-46000.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct Data {}
extension DispatchData {
func asFoundationData<T>(execute: (Data) throws -> T) rethrows -> T {
return try withUnsafeBytes { (ptr: UnsafePointer<Int8>) -> Void in
// expected-error@-1 {{cannot convert return expression of type 'Void' to return type 'T'}}
// expected-error@-1 {{declared closure result 'Void' is incompatible with contextual type 'T'}}
let data = Data()
return try execute(data) // expected-error {{cannot convert value of type 'T' to closure result type 'Void'}}
}
Expand Down
5 changes: 1 addition & 4 deletions test/type/opaque.swift
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,8 @@ func associatedTypeIdentity() {
sameType(cr, dr) // expected-error {{conflicting arguments to generic parameter 'T' ('(some R).S' (result type of 'candace') vs. '(some R).S' (result type of 'doug'))}}
sameType(gary(candace()).r_out(), gary(candace()).r_out())
sameType(gary(doug()).r_out(), gary(doug()).r_out())
// TODO(diagnostics): This is not great but the problem comes from the way solver discovers and attempts bindings, if we could detect that
// `(some R).S` from first reference to `gary()` in inconsistent with the second one based on the parent type of `S` it would be much easier to diagnose.
sameType(gary(doug()).r_out(), gary(candace()).r_out())
// expected-error@-1:12 {{conflicting arguments to generic parameter 'T' ('some R' (result type of 'doug') vs. 'some R' (result type of 'candace'))}}
// expected-error@-2:34 {{conflicting arguments to generic parameter 'T' ('some R' (result type of 'doug') vs. 'some R' (result type of 'candace'))}}
// expected-error@-1:39 {{cannot convert value of type 'some R' (result of 'candace()') to expected argument type 'some R' (result of 'doug()')}}
}

func redeclaration() -> some P { return 0 } // expected-note 2{{previously declared}}
Expand Down

0 comments on commit a7b1e78

Please sign in to comment.