Skip to content

Commit

Permalink
steiner trees: fix hash collisions due to overflows
Browse files Browse the repository at this point in the history
Also ensure k is small enough: we need to fit the connectivity set of a
hyperedge into 64 bits, which implies k <= 64.
  • Loading branch information
N-Maas committed Nov 21, 2024
1 parent f34dfa9 commit 525cd5c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
3 changes: 3 additions & 0 deletions mt-kahypar/partition/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ namespace mt_kahypar {
if ( partition.preset_type == PresetType::large_k ) {
// steiner trees scale really badly with k (cubic with no parallelization), so we don't want to support this
ERR("Large k partitioning is not supported for steiner tree metric.");
} else if ( partition.k > 64 && partition.instance_type == InstanceType::hypergraph ) {
// larger values of k currently don't work correctly due to collisions in the hash table
ERR("Steiner tree metric on hypergraphs is currently only supported for k <= 64.");
}
if ( !target_graph ) {
partition.objective = Objective::km1;
Expand Down
2 changes: 1 addition & 1 deletion mt-kahypar/partition/mapping/kerninghan_lin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ void KerninghanLin<CommunicationHypergraph>::improve(CommunicationHypergraph& co
const TargetGraph& target_graph) {
ASSERT(communication_hg.initialNumNodes() == target_graph.graph().initialNumNodes());

HyperedgeWeight current_objective = metrics::quality(communication_hg, Objective::steiner_tree);
HyperedgeWeight current_objective = metrics::quality(communication_hg, Objective::steiner_tree, false);
vec<bool> marked_hes(communication_hg.initialNumEdges(), false);
bool found_improvement = true;
size_t fruitless_rounds = 0;
Expand Down
11 changes: 6 additions & 5 deletions mt-kahypar/partition/mapping/target_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,17 @@ void TargetGraph::precomputeDistances(const size_t max_connectivity) {

HyperedgeWeight TargetGraph::distance(const ds::StaticBitset& connectivity_set) const {
const PartitionID connectivity = connectivity_set.popcount();
const size_t idx = index(connectivity_set);
if ( likely(connectivity <= _max_precomputed_connectitivty) ) {
const size_t idx = index(connectivity_set);
ASSERT(idx < _distances.size());
if constexpr ( TRACK_STATS ) ++_stats.precomputed;
return _distances[idx];
} else {
const uint64_t hash_key = computeHash(connectivity_set);
// We have not precomputed the optimal steiner tree for the connectivity set.
#ifdef __linux__
HashTableHandle& handle = _handles.local();
auto res = handle.find(idx);
auto res = handle.find(hash_key);
if ( likely( res != handle.end() ) ) {
if constexpr ( TRACK_STATS ) ++_stats.cache_hits;
return (*res).second;
Expand All @@ -71,11 +72,11 @@ HyperedgeWeight TargetGraph::distance(const ds::StaticBitset& connectivity_set)
// Entry is not cached => Compute 2-approximation of optimal steiner tree
const HyperedgeWeight mst_weight =
computeWeightOfMSTOnMetricCompletion(connectivity_set);
handle.insert(idx, mst_weight);
handle.insert(hash_key, mst_weight);
return mst_weight;
}
#elif defined(_WIN32) or defined(__APPLE__)
auto res = _cache.find(idx);
auto res = _cache.find(hash_key);
if ( likely ( res != _cache.end() ) ) {
if constexpr ( TRACK_STATS ) ++_stats.cache_hits;
return res->second;
Expand All @@ -84,7 +85,7 @@ HyperedgeWeight TargetGraph::distance(const ds::StaticBitset& connectivity_set)
// Entry is not cached => Compute 2-approximation of optimal steiner tree
const HyperedgeWeight mst_weight =
computeWeightOfMSTOnMetricCompletion(connectivity_set);
_cache.insert(std::make_pair(idx, mst_weight));
_cache.insert(std::make_pair(hash_key, mst_weight));
return mst_weight;
}
#endif
Expand Down
13 changes: 11 additions & 2 deletions mt-kahypar/partition/mapping/target_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ class TargetGraph {
using hasher_type = utils_tm::hash_tm::murmur2_hash;
using allocator_type = growt::AlignedAllocator<>;
using ConcurrentHashTable = typename growt::table_config<
size_t, size_t, hasher_type, allocator_type, hmod::growable, hmod::sync>::table_type;
uint64_t, uint64_t, hasher_type, allocator_type, hmod::growable, hmod::sync>::table_type;
using HashTableHandle = typename ConcurrentHashTable::handle_type;
#elif defined(_WIN32) or defined(__APPLE__)
using ConcurrentHashTable = tbb::concurrent_unordered_map<size_t, size_t>;
using ConcurrentHashTable = tbb::concurrent_unordered_map<uint64_t, uint64_t>;
#endif

struct MSTData {
Expand Down Expand Up @@ -241,6 +241,15 @@ class TargetGraph {
(multiplier == UL(_k) ? last_block * _k : 0) : 0;
}

// ! Encodes the given connectivity set as a 64-bit mask: for every block b
// ! contained in the set, bit b of the returned value is set. The mask is used
// ! as the lookup key for cached steiner tree weights. Two different connectivity
// ! sets yield different masks as long as every contained block ID is smaller
// ! than 64; this precondition is only checked via ASSERT.
MT_KAHYPAR_ATTRIBUTE_ALWAYS_INLINE uint64_t computeHash(const ds::StaticBitset& connectivity_set) const {
  constexpr uint64_t lowest_bit = 1;
  uint64_t bitmask = 0;
  for ( const PartitionID block : connectivity_set ) {
    ASSERT(block != kInvalidPartition && block < _k && block < 64);
    // Widen to 64 bit before shifting so that blocks >= 32 do not overflow
    bitmask = bitmask | (lowest_bit << block);
  }
  return bitmask;
}

// ! This function computes an MST on the metric completion of the target graph
// ! restricted to the blocks in the connectivity set. The metric completion is
// ! the complete graph where each edge {u,v} has a weight equal to the shortest path
Expand Down

0 comments on commit 525cd5c

Please sign in to comment.