diff --git a/.gitlab/os.yml b/.gitlab/os.yml index 9bc4b8146..e7a6a8fda 100644 --- a/.gitlab/os.yml +++ b/.gitlab/os.yml @@ -18,7 +18,7 @@ .on_blueos_3_ppc64: variables: ARCH: 'blueos_3_ppc64le_ib_p9' - GCC_VERSION: '8.3.1' + GCC_VERSION: '10.2.1' CLANG_VERSION: '9.0.0' SPHERAL_BUILDS_DIR: /p/gpfs1/sphapp/spheral-ci-builds extends: [.sys_config] diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 85809ca49..9e5d37f99 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -28,6 +28,9 @@ Notable changes include: * Physics packages can indicate if they require Voronoi cell information be available. If so, a new package which computes and updates the Voronoi information is automatically added to the package list by the SpheralController (similar to how the Reproducing Kernel corrections are handled). + * Cleaned up use of std::any in State objects using a visitor pattern to be rigorous ensuring all state entries are handled properly + during assignement, equality, and cloning operations. This is intended to help ensure our Physics advance during time integration + is correct. * Build changes / improvements: * Distributed source directory must always be built now. diff --git a/scripts/devtools/spec-list.json b/scripts/devtools/spec-list.json index 59f681eb8..395d5d0cd 100644 --- a/scripts/devtools/spec-list.json +++ b/scripts/devtools/spec-list.json @@ -7,9 +7,9 @@ ] , "blueos_3_ppc64le_ib_p9": [ - "gcc@8.3.1", - "gcc@8.3.1+cuda~mpi cuda_arch=70", - "gcc@8.3.1+cuda cuda_arch=70" + "gcc@10.2.1", + "gcc@10.2.1+cuda~mpi cuda_arch=70", + "gcc@10.2.1+cuda cuda_arch=70" ] } } diff --git a/scripts/spack/configs/blueos_3_ppc64le_ib/compilers.yaml b/scripts/spack/configs/blueos_3_ppc64le_ib/compilers.yaml index b876f15bc..84c1b8b70 100644 --- a/scripts/spack/configs/blueos_3_ppc64le_ib/compilers.yaml +++ b/scripts/spack/configs/blueos_3_ppc64le_ib/compilers.yaml @@ -13,12 +13,12 @@ compilers: environment: {} extra_rpaths: [] - compiler: - spec: gcc@8.3.1 + spec: gcc@10.2.1 paths: - cc: /usr/tce/packages/gcc/gcc-8.3.1/bin/gcc - cxx: /usr/tce/packages/gcc/gcc-8.3.1/bin/g++ - f77: /usr/tce/packages/gcc/gcc-8.3.1/bin/gfortran - fc: /usr/tce/packages/gcc/gcc-8.3.1/bin/gfortran + cc: /usr/tce/packages/gcc/gcc-10.2.1/bin/gcc + cxx: /usr/tce/packages/gcc/gcc-10.2.1/bin/g++ + f77: /usr/tce/packages/gcc/gcc-10.2.1/bin/gfortran + fc: /usr/tce/packages/gcc/gcc-10.2.1/bin/gfortran flags: {} operating_system: rhel7 target: ppc64le diff --git a/scripts/spack/configs/blueos_3_ppc64le_ib/packages.yaml b/scripts/spack/configs/blueos_3_ppc64le_ib/packages.yaml index 4f4b7f697..ea43681ac 100644 --- a/scripts/spack/configs/blueos_3_ppc64le_ib/packages.yaml +++ b/scripts/spack/configs/blueos_3_ppc64le_ib/packages.yaml @@ -39,6 +39,8 @@ packages: - 10.1.243 buildable: false externals: + - spec: cuda@11.4.1+allow-unsupported-compilers + prefix: /usr/tce/packages/cuda/cuda-11.4.1 - spec: cuda@11.1.0~allow-unsupported-compilers prefix: /usr/tce/packages/cuda/cuda-11.1.0 - spec: cuda@11.0.2~allow-unsupported-compilers diff --git a/src/ArtificialConduction/ArtificialConduction.cc b/src/ArtificialConduction/ArtificialConduction.cc index 413eb19cd..318793efe 100644 --- a/src/ArtificialConduction/ArtificialConduction.cc +++ b/src/ArtificialConduction/ArtificialConduction.cc @@ -121,10 +121,10 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, ReproducingKernel WR; auto maxOrder = RKOrder::ZerothOrder; if (useRK) { - const auto& rkOrders = state.template getAny>(RKFieldNames::rkOrders); + const auto& rkOrders = state.template get>(RKFieldNames::rkOrders); CHECK(not rkOrders.empty()); const auto maxOrder = *rkOrders.rbegin(); - WR = state.template getAny>(RKFieldNames::reproducingKernel(maxOrder)); + WR = state.template get>(RKFieldNames::reproducingKernel(maxOrder)); } // The connectivity map diff --git a/src/ArtificialViscosity/TensorCRKSPHViscosity.cc b/src/ArtificialViscosity/TensorCRKSPHViscosity.cc index 67a4fa120..b6a805d1f 100644 --- a/src/ArtificialViscosity/TensorCRKSPHViscosity.cc +++ b/src/ArtificialViscosity/TensorCRKSPHViscosity.cc @@ -184,7 +184,7 @@ calculateSigmaAndGradDivV(const DataBase& dataBase, const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto rho = state.fields(HydroFieldNames::massDensity, 0.0); const auto H = state.fields(HydroFieldNames::H, SymTensor::zero); - const auto WR = state.template getAny>(RKFieldNames::reproducingKernel(order)); + const auto WR = state.template get>(RKFieldNames::reproducingKernel(order)); const auto corrections = state.fields(RKFieldNames::rkCorrections(order), RKCoefficients()); const auto& connectivityMap = dataBase.connectivityMap(); diff --git a/src/ArtificialViscosity/VonNeumanViscosity.cc b/src/ArtificialViscosity/VonNeumanViscosity.cc index 83bb3e99b..367313e07 100644 --- a/src/ArtificialViscosity/VonNeumanViscosity.cc +++ b/src/ArtificialViscosity/VonNeumanViscosity.cc @@ -83,7 +83,7 @@ initialize(const DataBase& dataBase, const auto pressure = state.fields(HydroFieldNames::pressure, 0.0); const auto soundSpeed = state.fields(HydroFieldNames::soundSpeed, 0.0); const auto vol = mass/massDensity; - const auto WR = state.template getAny>(RKFieldNames::reproducingKernel(order)); + const auto WR = state.template get>(RKFieldNames::reproducingKernel(order)); const auto corrections = state.fields(RKFieldNames::rkCorrections(order), RKCoefficients()); // We'll compute the higher-accuracy RK gradient. diff --git a/src/CRKSPH/CRKSPHEvaluateDerivatives.cc b/src/CRKSPH/CRKSPHEvaluateDerivatives.cc index fa1c72d26..2d99a67ae 100644 --- a/src/CRKSPH/CRKSPHEvaluateDerivatives.cc +++ b/src/CRKSPH/CRKSPHEvaluateDerivatives.cc @@ -16,7 +16,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto& Q = this->artificialViscosity(); // The kernels and such. - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(mOrder)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(mOrder)); // A few useful constants we'll use in the following loop. //const double tiny = 1.0e-30; @@ -65,7 +65,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(DxDt.size() == numNodeLists); CHECK(DrhoDt.size() == numNodeLists); diff --git a/src/CRKSPH/CRKSPHHydroBase.cc b/src/CRKSPH/CRKSPHHydroBase.cc index 06047630c..13f8b60c6 100644 --- a/src/CRKSPH/CRKSPHHydroBase.cc +++ b/src/CRKSPH/CRKSPHHydroBase.cc @@ -263,7 +263,7 @@ registerDerivatives(DataBase& dataBase, derivs.enroll(mDspecificThermalEnergyDt); derivs.enroll(mDvDx); derivs.enroll(mInternalDvDx); - derivs.enrollAny(HydroFieldNames::pairAccelerations, mPairAccelerations); + derivs.enroll(HydroFieldNames::pairAccelerations, mPairAccelerations); } //------------------------------------------------------------------------------ @@ -282,7 +282,7 @@ preStepInitialize(const DataBase& dataBase, if (mDensityUpdate == MassDensityType::RigorousSumDensity or mDensityUpdate == MassDensityType::VoronoiCellDensity) { auto massDensity = state.fields(HydroFieldNames::massDensity, 0.0); - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(mOrder)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(mOrder)); const auto& W = WR.kernel(); const auto& connectivityMap = dataBase.connectivityMap(); const auto mass = state.fields(HydroFieldNames::mass, 0.0); @@ -311,7 +311,7 @@ initialize(const typename Dimension::Scalar time, State& state, StateDerivatives& derivs) { // Initialize the artificial viscosity - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(mOrder)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(mOrder)); auto& Q = this->artificialViscosity(); Q.initialize(dataBase, state, diff --git a/src/CRKSPH/CRKSPHHydroBaseRZ.cc b/src/CRKSPH/CRKSPHHydroBaseRZ.cc index 53387f123..3dc37ca08 100644 --- a/src/CRKSPH/CRKSPHHydroBaseRZ.cc +++ b/src/CRKSPH/CRKSPHHydroBaseRZ.cc @@ -219,7 +219,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, // The kernels and such. //const auto order = this->correctionOrder(); - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(mOrder)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(mOrder)); // A few useful constants we'll use in the following loop. //const auto tiny = 1.0e-30; @@ -263,7 +263,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(DxDt.size() == numNodeLists); CHECK(DrhoDt.size() == numNodeLists); diff --git a/src/CRKSPH/SolidCRKSPHHydroBase.cc b/src/CRKSPH/SolidCRKSPHHydroBase.cc index be82ef925..6e9402385 100644 --- a/src/CRKSPH/SolidCRKSPHHydroBase.cc +++ b/src/CRKSPH/SolidCRKSPHHydroBase.cc @@ -258,7 +258,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, // The kernels and such. const auto order = this->correctionOrder(); - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(order)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(order)); // A few useful constants we'll use in the following loop. //const double tiny = 1.0e-30; @@ -318,7 +318,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); CHECK(DxDt.size() == numNodeLists); diff --git a/src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc b/src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc index 9a6068eb4..b8b3b951b 100644 --- a/src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc +++ b/src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc @@ -275,7 +275,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, // The kernels and such. const auto order = this->correctionOrder(); - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(order)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(order)); // A few useful constants we'll use in the following loop. //const double tiny = 1.0e-30; @@ -334,7 +334,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); CHECK(DxDt.size() == numNodeLists); diff --git a/src/DEM/IncrementPairFieldList.cc b/src/DEM/IncrementPairFieldList.cc index e6169f859..371116a60 100644 --- a/src/DEM/IncrementPairFieldList.cc +++ b/src/DEM/IncrementPairFieldList.cc @@ -50,7 +50,7 @@ update(const KeyType& key, // Find all the available matching derivative FieldList keys. const auto incrementKey = prefix() + fieldKey; // cerr << "IncrementPairFieldList: [" << fieldKey << "] [" << incrementKey << "] : " << endl; - const auto allkeys = derivs.fieldKeys(); + const auto allkeys = derivs.fullFieldKeys(); vector incrementKeys; for (const auto& key: allkeys) { // if (std::regex_search(key, std::regex("^" + incrementKey))) { diff --git a/src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc b/src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc index efda0cf5e..7569c16ca 100644 --- a/src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc +++ b/src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc @@ -63,9 +63,9 @@ registerState(DataBase& dataBase, const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; const auto normalKey = boundaryKey +"_normal"; - state.enrollAny(pointKey,mPoint); - state.enrollAny(pointKey,mVelocity); - state.enrollAny(pointKey,mNormal); + state.enroll(pointKey,mPoint); + state.enroll(pointKey,mVelocity); + state.enroll(pointKey,mNormal); } template diff --git a/src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc b/src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc index 0be718247..6ea0bfd7f 100644 --- a/src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc +++ b/src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc @@ -87,9 +87,9 @@ registerState(DataBase& dataBase, const auto clipPointKey = boundaryKey +"_clipPoint"; const auto velocityKey = boundaryKey +"_velocity"; - state.enrollAny(pointKey,mCenter); - state.enrollAny(clipPointKey,mClipPoint); - state.enrollAny(pointKey,mVelocity); + state.enroll(pointKey,mCenter); + state.enroll(clipPointKey,mClipPoint); + state.enroll(pointKey,mVelocity); } diff --git a/src/DEM/SolidBoundary/CylinderSolidBoundary.cc b/src/DEM/SolidBoundary/CylinderSolidBoundary.cc index 8280ef1c4..53ad556cc 100644 --- a/src/DEM/SolidBoundary/CylinderSolidBoundary.cc +++ b/src/DEM/SolidBoundary/CylinderSolidBoundary.cc @@ -64,9 +64,9 @@ registerState(DataBase& dataBase, const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; //const auto normalKey = boundaryKey +"_normal"; - state.enrollAny(pointKey,mPoint); - state.enrollAny(pointKey,mVelocity); - //state.enrollAny(pointKey,mNormal); + state.enroll(pointKey,mPoint); + state.enroll(pointKey,mVelocity); + //state.enroll(pointKey,mNormal); } template diff --git a/src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc b/src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc index eb29ba5c5..5276f1541 100644 --- a/src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc +++ b/src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc @@ -54,9 +54,9 @@ registerState(DataBase& dataBase, const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; const auto normalKey = boundaryKey +"_normal"; - state.enrollAny(pointKey,mPoint); - state.enrollAny(velocityKey,mVelocity); - state.enrollAny(normalKey,mNormal); + state.enroll(pointKey,mPoint); + state.enroll(velocityKey,mVelocity); + state.enroll(normalKey,mNormal); } template diff --git a/src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc b/src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc index 9e38fc6e2..eff75c3e6 100644 --- a/src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc +++ b/src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc @@ -57,8 +57,8 @@ registerState(DataBase& dataBase, const auto boundaryKey = "RectangularPlaneSolidBoundary_" + std::to_string(std::abs(this->uniqueIndex())); const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; - state.enrollAny(pointKey,mPoint); - state.enrollAny(velocityKey,mVelocity); + state.enroll(pointKey,mPoint); + state.enroll(velocityKey,mVelocity); } template void diff --git a/src/DEM/SolidBoundary/SphereSolidBoundary.cc b/src/DEM/SolidBoundary/SphereSolidBoundary.cc index 3a600a1a1..205aaee85 100644 --- a/src/DEM/SolidBoundary/SphereSolidBoundary.cc +++ b/src/DEM/SolidBoundary/SphereSolidBoundary.cc @@ -62,8 +62,8 @@ registerState(DataBase& dataBase, const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; - state.enrollAny(pointKey,mCenter); - state.enrollAny(pointKey,mVelocity); + state.enroll(pointKey,mCenter); + state.enroll(pointKey,mVelocity); } diff --git a/src/DataBase/CopyStateInline.hh b/src/DataBase/CopyStateInline.hh index 2782a067f..a70425343 100644 --- a/src/DataBase/CopyStateInline.hh +++ b/src/DataBase/CopyStateInline.hh @@ -42,10 +42,10 @@ update(const KeyType& key, REQUIRE(key == mCopyStateName); // The state we're updating - ValueType& f = state.template getAny(key); + ValueType& f = state.template get(key); // The master state we're copying - const ValueType& fmaster = state.template getAny(mMasterStateName); + const ValueType& fmaster = state.template get(mMasterStateName); // Copy the master state using the assignment operator f = fmaster; diff --git a/src/DataBase/DataBase.cc b/src/DataBase/DataBase.cc index 3ad7073e8..117ea69b0 100644 --- a/src/DataBase/DataBase.cc +++ b/src/DataBase/DataBase.cc @@ -13,7 +13,6 @@ #include "Material/EquationOfState.hh" #include "Utilities/testBoxIntersection.hh" #include "Utilities/safeInv.hh" -#include "State.hh" #include "Hydro/HydroFieldNames.hh" #include "Utilities/globalBoundingVolumes.hh" #include "Utilities/globalNodeIDs.hh" diff --git a/src/DataBase/State.cc b/src/DataBase/State.cc index 0859d195d..186c28d07 100644 --- a/src/DataBase/State.cc +++ b/src/DataBase/State.cc @@ -94,9 +94,7 @@ State(DataBase& dataBase, mPolicyMap(), mTimeAdvanceOnly(false) { // Iterate over the physics packages, and have them register their state. - for (PackageIterator itr = physicsPackages.begin(); - itr != physicsPackages.end(); - ++itr) (*itr)->registerState(dataBase, *this); + for (auto pkg: physicsPackages) pkg->registerState(dataBase, *this); } //------------------------------------------------------------------------------ @@ -111,9 +109,7 @@ State(DataBase& dataBase, mPolicyMap(), mTimeAdvanceOnly(false) { // Iterate over the physics packages, and have them register their state. - for (PackageIterator itr = physicsPackageBegin; - itr != physicsPackageEnd; - ++itr) (*itr)->registerState(dataBase, *this); + for (auto pkg: range(physicsPackageBegin, physicsPackageEnd)) pkg->registerState(dataBase, *this); } //------------------------------------------------------------------------------ @@ -160,6 +156,105 @@ operator==(const StateBase& rhs) const { return StateBase::operator==(rhs); } +//------------------------------------------------------------------------------ +// The set of keys for all registered policies. +//------------------------------------------------------------------------------ +template +vector::KeyType> +State:: +policyKeys() const { + vector result; + for (const auto itr: mPolicyMap) result.push_back(itr.first); + ENSURE(result.size() == mPolicyMap.size()); + return result; +} + +//------------------------------------------------------------------------------ +// Return the policy for the given key. +//------------------------------------------------------------------------------ +template +typename State::PolicyPointer +State:: +policy(const typename State::KeyType& key) const { + KeyType fieldKey, nodeKey; + this->splitFieldKey(key, fieldKey, nodeKey); + const auto outerItr = mPolicyMap.find(fieldKey); + if (outerItr == mPolicyMap.end()) return PolicyPointer(); + // VERIFY2(outerItr != mPolicyMap.end(), + // "State ERROR: attempted to retrieve non-existent policy for key " << key); + const auto& key2policies = outerItr->second; + const auto innerItr = key2policies.find(key); + if (innerItr == key2policies.end()) return PolicyPointer(); + // VERIFY2(innerItr != policies.end(), + // "State ERROR: attempted to retrieve non-existent policy for key " << key); + return innerItr->second; +} + +//------------------------------------------------------------------------------ +// Return all the policies for the given field key. +//------------------------------------------------------------------------------ +template +std::map::KeyType, typename State::PolicyPointer> +State:: +policies(const typename State::KeyType& fieldKey) const { + const auto outerItr = mPolicyMap.find(fieldKey); + if (outerItr == mPolicyMap.end()) return std::map(); + // VERIFY2(outerItr != mPolicyMap.end(), + // "State ERROR: attempted to retrieve non-existent policy for key " << key); + return outerItr->second; +} + +//------------------------------------------------------------------------------ +// Remove the policy associated with the given key. +//------------------------------------------------------------------------------ +template +void +State:: +removePolicy(const typename State::KeyType& key) { + KeyType fieldKey, nodeKey; + this->splitFieldKey(key, fieldKey, nodeKey); + typename PolicyMapType::iterator outerItr = mPolicyMap.find(fieldKey); + VERIFY2(outerItr != mPolicyMap.end(), + "State ERROR: attempted to remove non-existent policy for field key " << fieldKey); + std::map& policies = outerItr->second; + typename std::map::iterator innerItr = policies.find(key); + if (innerItr == policies.end()) { + cerr << "State ERROR: attempted to remove non-existent policy for inner key " << key << endl + << "Known keys are: " << endl; + for (auto itr = policies.begin(); itr != policies.end(); ++itr) cerr << " --> " << itr->first << endl; + VERIFY(innerItr != policies.end()); + } + policies.erase(innerItr); + if (policies.size() == 0) mPolicyMap.erase(outerItr); +} + +//------------------------------------------------------------------------------ +// Remove the policy associated with a Field. +//------------------------------------------------------------------------------ +template +void +State:: +removePolicy(FieldBase& field) { + this->removePolicy(StateBase::key(field)); +} + +//------------------------------------------------------------------------------ +// Remove the policy associated with a FieldList. +//------------------------------------------------------------------------------ +template +void +State:: +removePolicy(FieldListBase& fieldList, + const bool clonePerField) { + if (clonePerField) { + for (auto fieldPtrItr = fieldList.begin_base(); + fieldPtrItr < fieldList.end_base(); + ++fieldPtrItr) this->removePolicy(**fieldPtrItr); + } else { + this->removePolicy(StateBase::key(fieldList)); + } +} + //------------------------------------------------------------------------------ // Update the state with the given derivatives object, according to the per // state field policies. @@ -273,104 +368,5 @@ update(StateDerivatives& derivs, } } -//------------------------------------------------------------------------------ -// The set of keys for all registered policies. -//------------------------------------------------------------------------------ -template -vector::KeyType> -State:: -policyKeys() const { - vector result; - for (const auto itr: mPolicyMap) result.push_back(itr.first); - ENSURE(result.size() == mPolicyMap.size()); - return result; -} - -//------------------------------------------------------------------------------ -// Return the policy for the given key. -//------------------------------------------------------------------------------ -template -typename State::PolicyPointer -State:: -policy(const typename State::KeyType& key) const { - KeyType fieldKey, nodeKey; - this->splitFieldKey(key, fieldKey, nodeKey); - const auto outerItr = mPolicyMap.find(fieldKey); - if (outerItr == mPolicyMap.end()) return PolicyPointer(); - // VERIFY2(outerItr != mPolicyMap.end(), - // "State ERROR: attempted to retrieve non-existent policy for key " << key); - const auto& key2policies = outerItr->second; - const auto innerItr = key2policies.find(key); - if (innerItr == key2policies.end()) return PolicyPointer(); - // VERIFY2(innerItr != policies.end(), - // "State ERROR: attempted to retrieve non-existent policy for key " << key); - return innerItr->second; -} - -//------------------------------------------------------------------------------ -// Return all the policies for the given field key. -//------------------------------------------------------------------------------ -template -std::map::KeyType, typename State::PolicyPointer> -State:: -policies(const typename State::KeyType& fieldKey) const { - const auto outerItr = mPolicyMap.find(fieldKey); - if (outerItr == mPolicyMap.end()) return std::map(); - // VERIFY2(outerItr != mPolicyMap.end(), - // "State ERROR: attempted to retrieve non-existent policy for key " << key); - return outerItr->second; -} - -//------------------------------------------------------------------------------ -// Remove the policy associated with the given key. -//------------------------------------------------------------------------------ -template -void -State:: -removePolicy(const typename State::KeyType& key) { - KeyType fieldKey, nodeKey; - this->splitFieldKey(key, fieldKey, nodeKey); - typename PolicyMapType::iterator outerItr = mPolicyMap.find(fieldKey); - VERIFY2(outerItr != mPolicyMap.end(), - "State ERROR: attempted to remove non-existent policy for field key " << fieldKey); - std::map& policies = outerItr->second; - typename std::map::iterator innerItr = policies.find(key); - if (innerItr == policies.end()) { - cerr << "State ERROR: attempted to remove non-existent policy for inner key " << key << endl - << "Known keys are: " << endl; - for (auto itr = policies.begin(); itr != policies.end(); ++itr) cerr << " --> " << itr->first << endl; - VERIFY(innerItr != policies.end()); - } - policies.erase(innerItr); - if (policies.size() == 0) mPolicyMap.erase(outerItr); -} - -//------------------------------------------------------------------------------ -// Remove the policy associated with a Field. -//------------------------------------------------------------------------------ -template -void -State:: -removePolicy(FieldBase& field) { - this->removePolicy(StateBase::key(field)); -} - -//------------------------------------------------------------------------------ -// Remove the policy associated with a FieldList. -//------------------------------------------------------------------------------ -template -void -State:: -removePolicy(FieldListBase& fieldList, - const bool clonePerField) { - if (clonePerField) { - for (auto fieldPtrItr = fieldList.begin_base(); - fieldPtrItr < fieldList.end_base(); - ++fieldPtrItr) this->removePolicy(**fieldPtrItr); - } else { - this->removePolicy(StateBase::key(fieldList)); - } -} - } diff --git a/src/DataBase/State.hh b/src/DataBase/State.hh index 59b7dbb49..6f5c5f493 100644 --- a/src/DataBase/State.hh +++ b/src/DataBase/State.hh @@ -52,15 +52,22 @@ public: // Assignment. State& operator=(const State& rhs); - // Override the base method. + // Override the base equivalence operator virtual bool operator==(const StateBase& rhs) const override; - // Update the registered state according to the policies. - void update(StateDerivatives& derivs, - const double multiplier, - const double t, - const double dt); + //........................................................................... + // Enroll state with update policies + void enroll(FieldBase& field, PolicyPointer policy); + // Enroll the given FieldList and associated update policy + // This method queries the "clonePerField" method of the policy, and + // if true enrolls each Field in the FieldList with a copy of the policy. + // Otherwise the FieldList is enrolled as a single entity, and the policy is + // assumed to handle a FieldList as a whole. + void enroll(FieldListBase& fieldList, PolicyPointer policy); + + //........................................................................... + // Policies // Enroll a policy by itself. void enroll(const KeyType& key, PolicyPointer policy); @@ -70,23 +77,6 @@ public: void removePolicy(FieldListBase& field, const bool clonePerField); - // Enroll the given Field and associated update policy - void enroll(FieldBase& field, PolicyPointer policy); - - // Enroll the given FieldList and associated update policy - // This method queries the "clonePerField" method of the policy, and - // if true enrolls each Field in the FieldList with a copy of the policy. - // Otherwise the FieldList is enrolled directly as normal, and the policy is - // assumed to handle a FieldList directly. - void enroll(FieldListBase& fieldList, PolicyPointer policy); - - // The base class method for just registering a field. - virtual void enroll(FieldBase& field) override; - virtual void enroll(std::shared_ptr>& fieldPtr) override; - - // The base class method for just registering a field list. - virtual void enroll(FieldListBase& fieldList) override; - // The full set of keys for all policies. std::vector policyKeys() const; @@ -100,10 +90,25 @@ public: template PolicyPointer policy(const Field& field) const; + //........................................................................... + // Update the registered state according to the policies. + void update(StateDerivatives& derivs, + const double multiplier, + const double t, + const double dt); + // Optionally trip a flag indicating policies should time advance only -- no replacing state! // This is useful when you're trying to cheat and reuse derivatives from a prior advance. - bool timeAdvanceOnly() const; - void timeAdvanceOnly(const bool x); + bool timeAdvanceOnly() const { return mTimeAdvanceOnly; } + void timeAdvanceOnly(const bool x) { mTimeAdvanceOnly = x; } + + //........................................................................... + // Expose the StateBase enroll methods + using StateBase::enroll; + // virtual void enroll(FieldBase& field) override { StateBase::enroll(field); } + // virtual void enroll(std::shared_ptr>& fieldPtr) override { StateBase::enroll(fieldPtr); } + // virtual void enroll(FieldListBase& fieldList) override { StateBase::enroll(fieldList); } + template void enroll(const KeyType& key, T& thing); private: //--------------------------- Private Interface ---------------------------// diff --git a/src/DataBase/StateBase.cc b/src/DataBase/StateBase.cc index 4ae5e7c99..698a91d58 100644 --- a/src/DataBase/StateBase.cc +++ b/src/DataBase/StateBase.cc @@ -10,10 +10,14 @@ #include "Field/FieldList.hh" #include "Neighbor/ConnectivityMap.hh" #include "Mesh/Mesh.hh" +#include "RK/RKCorrectionParams.hh" +#include "RK/ReproducingKernel.hh" +#include "Utilities/AnyVisitor.hh" #include "Utilities/DBC.hh" #include #include + using std::vector; using std::cout; using std::cerr; @@ -21,25 +25,31 @@ using std::endl; using std::min; using std::max; using std::abs; +using std::sort; +using std::shared_ptr; +using std::make_shared; +using std::any; +using std::any_cast; namespace Spheral { -// namespace { -// //------------------------------------------------------------------------------ -// // Helper for copying a type, used in copyState -// //------------------------------------------------------------------------------ -// template -// T* -// extractType(boost::any& anyT) { -// try { -// T* result = boost::any_cast(anyT); -// return result; -// } catch (boost::any_cast_error) { -// return NULL; -// } -// } +namespace { -// } +//------------------------------------------------------------------------------ +// Template for generic cloning during copyState +//------------------------------------------------------------------------------ +template +void +genericClone(std::any& x, + const std::string& key, + typename std::map& storage, + typename std::list& cache) { + auto clone = std::make_shared(std::any_cast>(x).get()); + cache.push_back(clone); + storage[key] = std::ref(*clone); +} + +} //------------------------------------------------------------------------------ // Default constructor. @@ -49,120 +59,101 @@ StateBase:: StateBase(): mStorage(), mCache(), + mNodeListPtrs(), mConnectivityMapPtr(), mMeshPtr(new MeshType()) { } //------------------------------------------------------------------------------ -// Copy constructor. +// Test if the internal state is equal. //------------------------------------------------------------------------------ template +bool StateBase:: -StateBase(const StateBase& rhs): - mStorage(rhs.mStorage), - mCache(), - mNodeListPtrs(rhs.mNodeListPtrs), - mConnectivityMapPtr(rhs.mConnectivityMapPtr), - mMeshPtr(rhs.mMeshPtr) { +operator==(const StateBase& rhs) const { + + // Compare raw sizes + if (mStorage.size() != rhs.mStorage.size()) { + cerr << "Storage sizes don't match." << endl; + return false; + } + + // Keys + auto lhsKeys = keys(); + auto rhsKeys = rhs.keys(); + if (lhsKeys != rhsKeys) { + cerr << "Keys don't match." << endl; + return false; + } + + // Build up a visitor to compare each type of state data we support holding + AnyVisitor EQUAL; + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>(x).get() == std::any_cast>(y).get(); }); + EQUAL.addVisitor> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>(x).get() == std::any_cast>(y).get(); }); + EQUAL.addVisitor> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>(x).get() == std::any_cast>(y).get(); }); + EQUAL.addVisitor> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>(x).get() == std::any_cast>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>>([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + + // Apply the equality visitor to all the stored State data + auto lhsitr = mStorage.begin(); + auto rhsitr = rhs.mStorage.begin(); + for (; lhsitr != mStorage.end(); ++lhsitr, ++rhsitr) { + CHECK(rhsitr != rhs.mStorage.end()); + CHECK(lhsitr->first == rhsitr->first); + if (not EQUAL.visit(lhsitr->second, rhsitr->second)) { + cerr << "States don't match for key " << lhsitr->first << endl; + return false; + } + } + + return true; } //------------------------------------------------------------------------------ -// Destructor. +// Enroll a Field //------------------------------------------------------------------------------ template +void StateBase:: -~StateBase() { +enroll(FieldBase& field) { + const auto key = this->key(field); + mStorage[key] = std::ref(field); + mNodeListPtrs.insert(field.nodeListPtr()); + // std::cerr << "StateBase::enroll field: " << key << " at " << &field << std::endl; + ENSURE(std::find(mNodeListPtrs.begin(), mNodeListPtrs.end(), field.nodeListPtr()) != mNodeListPtrs.end()); } //------------------------------------------------------------------------------ -// Assignment. +// Enroll a Field (shared_ptr). //------------------------------------------------------------------------------ template -StateBase& +void StateBase:: -operator=(const StateBase& rhs) { - if (this != &rhs) { - mStorage = rhs.mStorage; - mCache = CacheType(); - mNodeListPtrs = rhs.mNodeListPtrs; - mConnectivityMapPtr = rhs.mConnectivityMapPtr; - mMeshPtr = rhs.mMeshPtr; - } - return *this; +enroll(std::shared_ptr>& fieldPtr) { + const auto key = this->key(*fieldPtr); + mStorage[key] = std::ref(*fieldPtr); + mNodeListPtrs.insert(fieldPtr->nodeListPtr()); + mCache.push_back(fieldPtr); + ENSURE(std::find(mNodeListPtrs.begin(), mNodeListPtrs.end(), fieldPtr->nodeListPtr()) != mNodeListPtrs.end()); } //------------------------------------------------------------------------------ -// Test if the internal state is equal. +// Add the fields from a FieldList. //------------------------------------------------------------------------------ template -bool +void StateBase:: -operator==(const StateBase& rhs) const { - if (mStorage.size() != rhs.mStorage.size()) { - cerr << "Storage sizes don't match." << endl; - return false; - } - vector lhsKeys = keys(); - vector rhsKeys = rhs.keys(); - if (lhsKeys.size() != rhsKeys.size()) { - cerr << "Keys sizes don't match." << endl; - return false; - } - sort(lhsKeys.begin(), lhsKeys.end()); - sort(rhsKeys.begin(), rhsKeys.end()); - if (lhsKeys != rhsKeys) { - cerr << "Keys don't match." << endl; - return false; - } - - // Walk the keys, and rely on the virtual overloaded - // Field::operator==(FieldBase) to do the right thing! - // We are also relying here on the fact that std::map with a given - // set of keys will always result in the same order. - bool result = true; - typename StorageType::const_iterator lhsItr, rhsItr; - for (rhsItr = rhs.mStorage.begin(), lhsItr = mStorage.begin(); - rhsItr != rhs.mStorage.end(); - ++rhsItr, ++lhsItr) { - try { - auto lhsPtr = boost::any_cast*>(lhsItr->second); - auto rhsPtr = boost::any_cast*>(rhsItr->second); - if (*lhsPtr != *rhsPtr) { - cerr << "Fields for " << lhsItr->first << " don't match." << endl; - result = false; - } - } catch (const boost::bad_any_cast&) { - try { - auto lhsPtr = boost::any_cast*>(lhsItr->second); - auto rhsPtr = boost::any_cast*>(rhsItr->second); - if (*lhsPtr != *rhsPtr) { - cerr << "vector for " << lhsItr->first << " don't match." << endl; - result = false; - } - } catch (const boost::bad_any_cast&) { - try { - auto lhsPtr = boost::any_cast(lhsItr->second); - auto rhsPtr = boost::any_cast(rhsItr->second); - if (*lhsPtr != *rhsPtr) { - cerr << "Vector for " << lhsItr->first << " don't match." << endl; - result = false; - } - } catch (const boost::bad_any_cast&) { - try { - auto lhsPtr = boost::any_cast(lhsItr->second); - auto rhsPtr = boost::any_cast(rhsItr->second); - if (*lhsPtr != *rhsPtr) { - cerr << "Scalar for " << lhsItr->first << " don't match." << endl; - result = false; - } - } catch (const boost::bad_any_cast&) { - std::cerr << "StateBase::operator== WARNING: unable to compare values for " << lhsItr->first << "\n"; - } - } - } - } +enroll(FieldListBase& fieldList) { + for (auto* fptr: range(fieldList.begin_base(), fieldList.end_base())) { + this->enroll(*fptr); } - return result; } //------------------------------------------------------------------------------ @@ -172,7 +163,7 @@ template bool StateBase:: registered(const StateBase::KeyType& key) const { - return (mStorage.find(key) != mStorage.end()); + return mStorage.find(key) != mStorage.end(); } //------------------------------------------------------------------------------ @@ -182,9 +173,8 @@ template bool StateBase:: registered(const FieldBase& field) const { - const KeyType key = this->key(field); - typename StorageType::const_iterator itr = mStorage.find(key); - return (itr != mStorage.end()); + const auto key = this->key(field); + return this->registered(key); } //------------------------------------------------------------------------------ @@ -206,73 +196,54 @@ bool StateBase:: fieldNameRegistered(const FieldName& name) const { KeyType fieldName, nodeListName; - auto itr = mStorage.begin(); - while (itr != mStorage.end()) { - splitFieldKey(itr->first, fieldName, nodeListName); + for (auto [key, valptr]: mStorage) { + splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) return true; - ++itr; } return false; } //------------------------------------------------------------------------------ -// Enroll a field. -//------------------------------------------------------------------------------ -template -void -StateBase:: -enroll(FieldBase& field) { - const KeyType key = this->key(field); - boost::any fieldptr; - fieldptr = &field; - mStorage[key] = fieldptr; - mNodeListPtrs.insert(field.nodeListPtr()); - // std::cerr << "StateBase::enroll field: " << key << " at " << &field << std::endl; - ENSURE(&(this->getAny>(key)) == &field); - ENSURE(find(mNodeListPtrs.begin(), mNodeListPtrs.end(), field.nodeListPtr()) != mNodeListPtrs.end()); -} - -//------------------------------------------------------------------------------ -// Enroll a field (shared_ptr). +// Return the full set of known keys. //------------------------------------------------------------------------------ template -void +std::vector::KeyType> StateBase:: -enroll(std::shared_ptr>& fieldPtr) { - const KeyType key = this->key(*fieldPtr); - mStorage[key] = fieldPtr.get(); - mNodeListPtrs.insert(fieldPtr->nodeListPtr()); - mFieldCache.push_back(fieldPtr); - ENSURE(find(mNodeListPtrs.begin(), mNodeListPtrs.end(), fieldPtr->nodeListPtr()) != mNodeListPtrs.end()); +keys() const { + vector result; + for (auto itr = mStorage.begin(); itr != mStorage.end(); ++itr) result.push_back(itr->first); + return result; } //------------------------------------------------------------------------------ -// Add the fields from a FieldList. +// Return the full set of Field Keys (mangled with NodeList names) //------------------------------------------------------------------------------ template -void +std::vector::KeyType> StateBase:: -enroll(FieldListBase& fieldList) { - for (auto itr = fieldList.begin_base(); - itr != fieldList.end_base(); - ++itr) { - this->enroll(**itr); +fullFieldKeys() const { + vector result; + for (auto [key, aref]: mStorage) { + if (aref.type() == typeid(std::reference_wrapper>)) { + result.push_back(key); + } } + return result; } //------------------------------------------------------------------------------ -// Return the full set of known keys. +// Return the set of non-field keys. //------------------------------------------------------------------------------ template std::vector::KeyType> StateBase:: -keys() const { +miscKeys() const { vector result; - result.reserve(mStorage.size()); - for (auto itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) result.push_back(itr->first); - ENSURE(result.size() == mStorage.size()); + for (auto [key, aref]: mStorage) { + if (aref.type() != typeid(std::reference_wrapper>)) { + result.push_back(key); + } + } return result; } @@ -282,15 +253,15 @@ keys() const { template std::vector::FieldName> StateBase:: -fieldKeys() const { +fieldNames() const { + vector result; KeyType fieldName, nodeListName; - vector::FieldName> result; - result.reserve(mStorage.size()); - for (typename StorageType::const_iterator itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) { - splitFieldKey(itr->first, fieldName, nodeListName); - if (fieldName != "" and nodeListName != "") result.push_back(fieldName); + for (auto [key, aref]: mStorage) { + if (aref.type() == typeid(std::reference_wrapper>)) { + auto fref = std::any_cast>>(aref); + splitFieldKey(fref.get().name(), fieldName, nodeListName); + result.push_back(fieldName); + } } // Remove any duplicates. This will happen when we've stored the same field @@ -384,58 +355,40 @@ void StateBase:: assign(const StateBase& rhs) { - // Extract the keys for each state, and verify they line up. - REQUIRE(mStorage.size() == rhs.mStorage.size()); - vector lhsKeys = keys(); - vector rhsKeys = rhs.keys(); - REQUIRE(lhsKeys.size() == rhsKeys.size()); - sort(lhsKeys.begin(), lhsKeys.end()); - sort(rhsKeys.begin(), rhsKeys.end()); - REQUIRE(lhsKeys == rhsKeys); - - // Walk the keys, and rely on the underlying type to know how to copy itself. - for (typename StorageType::const_iterator itr = rhs.mStorage.begin(); - itr != rhs.mStorage.end(); - ++itr) { - auto& anylhs = mStorage[itr->first]; - const auto& anyrhs = itr->second; - try { - auto lhsptr = boost::any_cast*>(anylhs); - const auto rhsptr = boost::any_cast*>(anyrhs); - *lhsptr = *rhsptr; - } catch(const boost::bad_any_cast&) { - try { - auto lhsptr = boost::any_cast*>(anylhs); - const auto rhsptr = boost::any_cast*>(anyrhs); - *lhsptr = *rhsptr; - } catch(const boost::bad_any_cast&) { - try { - auto lhsptr = boost::any_cast(anylhs); - const auto rhsptr = boost::any_cast(anyrhs); - *lhsptr = *rhsptr; - } catch(const boost::bad_any_cast&) { - try { - auto lhsptr = boost::any_cast(anylhs); - const auto rhsptr = boost::any_cast(anyrhs); - *lhsptr = *rhsptr; - } catch(const boost::bad_any_cast&) { - // We'll assume other things don't need to be assigned... - // VERIFY2(false, "StateBase::assign ERROR: unknown type for key " << itr->first << "\n"); - } - } - } - } + // Build a visitor that knows how to assign each of our datatypes + AnyVisitor ASSIGN; + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor> ([](std::any& x, const std::any& y) { std::any_cast>(x).get() = std::any_cast>(y).get(); }); + ASSIGN.addVisitor> ([](std::any& x, const std::any& y) { std::any_cast>(x).get() = std::any_cast>(y).get(); }); + ASSIGN.addVisitor> ([](std::any& x, const std::any& y) { std::any_cast>(x).get() = std::any_cast>(y).get(); }); + ASSIGN.addVisitor> ([](std::any& x, const std::any& y) { std::any_cast>(x).get() = std::any_cast>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>>([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + + // Apply the assignment visitor to all the stored State data + auto lhsitr = mStorage.begin(); + auto rhsitr = rhs.mStorage.begin(); + for (; lhsitr != mStorage.end(); ++lhsitr, ++rhsitr) { + CHECK(rhsitr != rhs.mStorage.end()); + CHECK(lhsitr->first == rhsitr->first); + ASSIGN.visit(lhsitr->second, rhsitr->second); } + // Copy the connectivity (by reference). This thing is too // big to carry around separate copies! - if (rhs.mConnectivityMapPtr != NULL) { + if (rhs.mConnectivityMapPtr != nullptr) { mConnectivityMapPtr = rhs.mConnectivityMapPtr; } else { mConnectivityMapPtr = ConnectivityMapPtr(); } // Copy the mesh. - if (rhs.mMeshPtr != NULL) { + if (rhs.mMeshPtr != nullptr) { mMeshPtr = MeshPtr(new MeshType()); *mMeshPtr = *(rhs.mMeshPtr); } else { @@ -453,40 +406,29 @@ copyState() { // Remove any pre-existing stuff. mCache = CacheType(); - mFieldCache = FieldCacheType(); - - // Walk the registered state and copy it to our local cache. - for (auto itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) { - boost::any anythingPtr = itr->second; - - // Is this a Field? - try { - auto ptr = boost::any_cast*>(anythingPtr); - mFieldCache.push_back(ptr->clone()); - itr->second = mFieldCache.back().get(); - - } catch (const boost::bad_any_cast&) { - try { - auto ptr = boost::any_cast*>(anythingPtr); - auto clone = std::shared_ptr>(new vector(*ptr)); - mCache.push_back(clone); - itr->second = clone.get(); - - } catch (const boost::bad_any_cast&) { - try { - auto ptr = boost::any_cast(anythingPtr); - auto clone = std::shared_ptr(new Vector(*ptr)); - mCache.push_back(clone); - itr->second = clone.get(); - - } catch (const boost::bad_any_cast&) { - // We'll assume other things don't need to be copied... - // VERIFY2(false, "StateBase::copyState ERROR: unrecognized type for " << itr->first << "\n"); - } - } - } + + // Build a visitor to clone each type of state data + AnyVisitor CLONE; + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { + auto clone = std::any_cast>>(x).get().clone(); + cache.push_back(clone); + storage[key] = std::ref(*clone); + }); + CLONE.addVisitor> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + + // Clone all our stored data to cache + for (auto& [key, anyval]: mStorage) { + CLONE.visit(anyval, key, mStorage, mCache); } } diff --git a/src/DataBase/StateBase.hh b/src/DataBase/StateBase.hh index 355b6fbf5..5dc899b13 100644 --- a/src/DataBase/StateBase.hh +++ b/src/DataBase/StateBase.hh @@ -18,8 +18,7 @@ #include "Field/FieldBase.hh" -#include "boost/any.hpp" - +#include #include #include #include @@ -28,17 +27,16 @@ #include #include -#include "Field/FieldBase.hh" - namespace Spheral { // Forward declaration. template class NodeList; -template class FieldListBase; -template class Field; -template class FieldList; +template class Field; +template class FieldList; template class ConnectivityMap; template class Mesh; +template class ReproducingKernel; +enum class RKOrder : int; template class StateBase { @@ -46,93 +44,80 @@ class StateBase { public: //--------------------------- Public Interface ---------------------------// // Useful typedefs - typedef typename Dimension::Scalar Scalar; - typedef typename Dimension::Vector Vector; - typedef typename Dimension::Vector3d Vector3d; - typedef typename Dimension::Tensor Tensor; - typedef typename Dimension::SymTensor SymTensor; - typedef typename Dimension::ThirdRankTensor ThirdRankTensor; - typedef typename Dimension::FourthRankTensor FourthRankTensor; - typedef typename Dimension::FifthRankTensor FifthRankTensor; - typedef typename Spheral::ConnectivityMap ConnectivityMapType; - typedef typename Spheral::Mesh MeshType; - - typedef std::shared_ptr ConnectivityMapPtr; - typedef std::shared_ptr MeshPtr; - - typedef std::string KeyType; - typedef typename FieldBase::FieldName FieldName; + using Scalar = typename Dimension::Scalar; + using Vector = typename Dimension::Vector; + using Vector3d = typename Dimension::Vector3d; + using Tensor = typename Dimension::Tensor; + using SymTensor = typename Dimension::SymTensor; + using ThirdRankTensor = typename Dimension::ThirdRankTensor; + using FourthRankTensor = typename Dimension::FourthRankTensor; + using FifthRankTensor = typename Dimension::FifthRankTensor; + using ConnectivityMapType = typename Spheral::ConnectivityMap; + using MeshType = typename Spheral::Mesh; + + using ConnectivityMapPtr = std::shared_ptr; + using MeshPtr = std::shared_ptr; + + using KeyType = std::string; + using FieldName = typename FieldBase::FieldName; // Constructors, destructor. StateBase(); - StateBase(const StateBase& rhs); - virtual ~StateBase(); - - // Assignment. - StateBase& operator=(const StateBase& rhs); + StateBase(const StateBase& rhs) = default; + StateBase& operator=(const StateBase& rhs) = default; + virtual ~StateBase() {} // Test if two StateBases have equivalent fields. virtual bool operator==(const StateBase& rhs) const; //............................................................................ - // Test if the specified Field or key is currently registered. - bool registered(const KeyType& key) const; - bool registered(const FieldBase& field) const; - bool registered(const FieldListBase& fieldList) const; - bool fieldNameRegistered(const FieldName& fieldName) const; + // Enroll state + virtual void enroll(FieldBase& field); + virtual void enroll(std::shared_ptr>& fieldPtr); + virtual void enroll(FieldListBase& fieldList); + template void enroll(const KeyType& key, T& thing); //............................................................................ - // Enroll a Field. - virtual void enroll(FieldBase& field); - virtual void enroll(std::shared_ptr>& fieldPtr); - - // Return the field for the given key. - template - Field& field(const KeyType& key, - const Value& dummy) const; + // Access Fields + template Field& field(const KeyType& key) const; + template Field& field(const KeyType& key, + const Value& dummy) const; - // Return all the fields of the given Value. - template - std::vector*> allFields(const Value& dummy) const; - - // This version is for when providing a dummy Value type is not possible/practical. - // Using this form however meand using the cumbersome syntax: state.template field(key) - template - Field& field(const KeyType& key) const; + // Get all registered fields of the given data type + template std::vector*> allFields() const; + template std::vector*> allFields(const Value& dummy) const; //............................................................................ - // Enroll a FieldList. - virtual void enroll(FieldListBase& fieldList); - - // Return FieldLists constructed from all registered Fields with the given name. - template - FieldList fields(const std::string& name, - const Value& dummy) const; - - // This version is for when providing a dummy Value type is not possible/practical. - // Using this form however meand using the cumbersome syntax: state.template fields(key) - template - FieldList fields(const std::string& name) const; + // Access FieldLists + template FieldList fields(const std::string& name) const; + template FieldList fields(const std::string& name, + const Value& dummy) const; //............................................................................ - // Enroll an arbitrary type - template - void enrollAny(const KeyType& key, Value& thing); + // Access an arbitrary type + template Value& get(const KeyType& key) const; + template Value& get(const KeyType& key, const Value& dummy) const; - // Return an arbitrary type (held by any) - template - Value& getAny(const KeyType& key) const; - - template - Value& getAny(const KeyType& key, const Value& dummy) const; + //............................................................................ + // Test if the specified Field or key is currently registered. + bool registered(const KeyType& key) const; + bool registered(const FieldBase& field) const; + bool registered(const FieldListBase& fieldList) const; + bool fieldNameRegistered(const FieldName& fieldName) const; //............................................................................ - // Return the complete set of keys registered. + // Return the complete set of keys registered std::vector keys() const; + // The field keys including mangling with NodeList names + std::vector fullFieldKeys() const; + + // The non-field (miscellaneous) keys + std::vector miscKeys() const; + // Return the set of known field names (unencoded from our internal mangling // convention with the NodeList name). - std::vector fieldKeys() const; + std::vector fieldNames() const; //............................................................................ // A state object can carry around a reference to a ConnectivityMap. @@ -172,14 +157,12 @@ public: protected: //--------------------------- Protected Interface ---------------------------// - typedef std::map StorageType; - typedef std::list>> FieldCacheType; - typedef std::list CacheType; + using StorageType = std::map; + using CacheType = std::list; // Protected data. - StorageType mStorage; - CacheType mCache; - FieldCacheType mFieldCache; + StorageType mStorage; + CacheType mCache; std::set*> mNodeListPtrs; ConnectivityMapPtr mConnectivityMapPtr; MeshPtr mMeshPtr; diff --git a/src/DataBase/StateBaseInline.hh b/src/DataBase/StateBaseInline.hh index 6a8bb41cf..888e1cd82 100644 --- a/src/DataBase/StateBaseInline.hh +++ b/src/DataBase/StateBaseInline.hh @@ -1,27 +1,45 @@ #include "boost/algorithm/string.hpp" #include "DataBase/UpdatePolicyBase.hh" +#include "RK/RKCorrectionParams.hh" +#include "RK/ReproducingKernel.hh" #include "Mesh/Mesh.hh" +#include "Utilities/range.hh" #include "Utilities/DBC.hh" namespace Spheral { +//------------------------------------------------------------------------------ +// Enroll an arbitrary type +//------------------------------------------------------------------------------ +template +template +inline +void +StateBase:: +enroll(const KeyType& key, T& thing) { + // std::cerr << "StateBase::enroll " << key << std::endl; + mStorage[key] = std::ref(thing); +} + //------------------------------------------------------------------------------ // Return the Field for the given key. //------------------------------------------------------------------------------ template template +inline Field& StateBase:: -field(const typename StateBase::KeyType& key) const { - try { - return dynamic_cast&>(this->getAny>(key)); - } catch (...) { - VERIFY2(false,"StateBase ERROR: unable to extract field for key " << key << "\n"); - } +field(const KeyType& key) const { + FieldBase& fb = this->template get>(key); + auto* fptr = dynamic_cast*>(&fb); + VERIFY2(fptr != nullptr, + "StateBase::field ERROR: field type incorrect for key " << key); + return *fptr; } template template +inline Field& StateBase:: field(const typename StateBase::KeyType& key, @@ -37,22 +55,28 @@ template inline std::vector*> StateBase:: -allFields(const Value&) const { +allFields() const { std::vector*> result; KeyType fieldName, nodeListName; - for (auto itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) { - try { - Field* ptr = dynamic_cast*>(boost::any_cast*>(itr->second)); - if (ptr != 0) result.push_back(ptr); - } catch (...) { - // The field must have been the wrong type. + for (auto [key, aref]: mStorage) { + if (aref.type() == typeid(std::reference_wrapper>)) { + auto fb = std::any_cast>>(aref); + auto* fptr = dynamic_cast*>(&fb.get()); + if (fptr != nullptr) result.push_back(fptr); } } return result; } +template +template +inline +std::vector*> +StateBase:: +allFields(const Value&) const { + return this->template allFields(); +} + //------------------------------------------------------------------------------ // Return a FieldList containing all registered fields of the given name. //------------------------------------------------------------------------------ @@ -64,13 +88,15 @@ StateBase:: fields(const std::string& name) const { FieldList result; KeyType fieldName, nodeListName; - for (auto itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) { - splitFieldKey(itr->first, fieldName, nodeListName); + for (auto [key, aref]: mStorage) { + splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) { CHECK(nodeListName != ""); - result.appendField(this->template field(itr->first)); + if (aref.type() == typeid(std::reference_wrapper>)) { + auto fb = std::any_cast>>(aref); + auto* fptr = dynamic_cast*>(&fb.get()); + if (fptr != nullptr) result.appendField(*fptr); + } } } return result; @@ -86,40 +112,46 @@ fields(const std::string& name, const Value& dummy) const { } //------------------------------------------------------------------------------ -// Enroll an arbitrary type +// Extract an arbitrary type //------------------------------------------------------------------------------ template template -void +inline +Value& StateBase:: -enrollAny(const typename StateBase::KeyType& key, Value& thing) { - mStorage[key] = &thing; +get(const typename StateBase::KeyType& key) const { + auto itr = mStorage.find(key); + VERIFY2(itr != mStorage.end(), "StateBase ERROR: failed lookup for key " << key); + if (itr->second.type() == typeid(std::reference_wrapper)) { + return std::any_cast>(itr->second); + } + VERIFY2(false, "StateBase::get ERROR: unable to extract Value for " << key << "\n"); } -//------------------------------------------------------------------------------ -// Extract an arbitrary type -//------------------------------------------------------------------------------ +// Same thing passing a dummy argument to help with template type template template +inline Value& StateBase:: -getAny(const typename StateBase::KeyType& key) const { - try { - Value& result = *boost::any_cast(mStorage.find(key)->second); - return result; - } catch (const boost::bad_any_cast&) { - VERIFY2(false, "StateBase::getAny ERROR: unable to extract Value for " << key << "\n"); - } +get(const typename StateBase::KeyType& key, + const Value&) const { + return this->get(key); } -// Same thing passing a dummy argument to help with template type +//------------------------------------------------------------------------------ +// Assign the Fields matching the given name of this State object to be equal to +// the values in another. +//------------------------------------------------------------------------------ template template -Value& +inline +void StateBase:: -getAny(const typename StateBase::KeyType& key, - const Value&) const { - return this->getAny(key); +assignFields(const StateBase& rhs, const std::string name) { + auto lhsfields = this->fields(name, Value()); + auto rhsfields = rhs.fields(name, Value()); + lhsfields.assignFields(rhsfields); } //------------------------------------------------------------------------------ @@ -145,21 +177,6 @@ key(const FieldListBase& fieldList) { return buildFieldKey((*fieldList.begin_base())->name(), UpdatePolicyBase::wildcard()); } -//------------------------------------------------------------------------------ -// Assign the Fields matching the given name of this State object to be equal to -// the values in another. -//------------------------------------------------------------------------------ -template -template -inline -void -StateBase:: -assignFields(const StateBase& rhs, const std::string name) { - auto lhsfields = this->fields(name, Value()); - auto rhsfields = rhs.fields(name, Value()); - lhsfields.assignFields(rhsfields); -} - //------------------------------------------------------------------------------ // Internal methods to encode the convention for combining Field and NodeList // names into a single unique key. diff --git a/src/DataBase/StateDerivatives.cc b/src/DataBase/StateDerivatives.cc index e7e54ee23..ef898013c 100644 --- a/src/DataBase/StateDerivatives.cc +++ b/src/DataBase/StateDerivatives.cc @@ -9,6 +9,7 @@ #include "DataBase.hh" #include "Physics/Physics.hh" #include "Field/Field.hh" +#include "Utilities/AnyVisitor.hh" using std::vector; using std::cout; @@ -41,11 +42,7 @@ StateDerivatives(DataBase& dataBase, StateBase(), mCalculatedNodePairs(), mNumSignificantNeighbors() { - - // Iterate over the physics packages, and have them register their derivatives. - for (PackageIterator itr = physicsPackages.begin(); - itr != physicsPackages.end(); - ++itr) (*itr)->registerDerivatives(dataBase, *this); + for (auto pkg: physicsPackages) pkg->registerDerivatives(dataBase, *this); } //------------------------------------------------------------------------------ @@ -59,11 +56,7 @@ StateDerivatives(DataBase& dataBase, StateBase(), mCalculatedNodePairs(), mNumSignificantNeighbors() { - - // Iterate over the physics packages, and have them register their derivatives. - for (PackageIterator itr = physicsPackageBegin; - itr != physicsPackageEnd; - ++itr) (*itr)->registerDerivatives(dataBase, *this); + for (auto pkg: range(physicsPackageBegin, physicsPackageEnd)) pkg->registerDerivatives(dataBase, *this); } //------------------------------------------------------------------------------ @@ -158,30 +151,24 @@ void StateDerivatives:: Zero() { - // Walk the state fields and zero them. - for (typename StateBase::StorageType::iterator itr = this->mStorage.begin(); - itr != this->mStorage.end(); - ++itr) { - - try { - auto ptr = boost::any_cast*>(itr->second); - ptr->Zero(); - - } catch (const boost::bad_any_cast&) { - try { - auto ptr = boost::any_cast*>(itr->second); - ptr->clear(); - - } catch (const boost::bad_any_cast&) { - try { - auto ptr = boost::any_cast*>(itr->second); - ptr->clear(); - - } catch (const boost::bad_any_cast&) { - VERIFY2(false, "StateDerivatives::Zero ERROR: unknown type for key " << itr->first << "\n"); - } - } - } + // Build a visitor to zero each data type + AnyVisitor ZERO; + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().Zero(); }); + ZERO.addVisitor> ([](const std::any& x) { std::any_cast>(x).get() = 0.0; }); + ZERO.addVisitor> ([](const std::any& x) { std::any_cast>(x).get() = Vector::zero; }); + ZERO.addVisitor> ([](const std::any& x) { std::any_cast>(x).get() = Tensor::zero; }); + ZERO.addVisitor> ([](const std::any& x) { std::any_cast>(x).get() = SymTensor::zero; }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { } ); + + // Walk the state values and zero them + for (auto itr: mStorage) { + ZERO.visit(itr.second); } // Reinitialize the node pair interaction information. diff --git a/src/DataBase/StateDerivatives.hh b/src/DataBase/StateDerivatives.hh index 64a4dc729..5bf40ccc1 100644 --- a/src/DataBase/StateDerivatives.hh +++ b/src/DataBase/StateDerivatives.hh @@ -26,16 +26,16 @@ class StateDerivatives: public StateBase { public: //--------------------------- Public Interface ---------------------------// // Useful typedefs - typedef typename Dimension::Scalar Scalar; - typedef typename Dimension::Vector Vector; - typedef typename Dimension::Vector3d Vector3d; - typedef typename Dimension::Tensor Tensor; - typedef typename Dimension::SymTensor SymTensor; + using Scalar = typename Dimension::Scalar; + using Vector = typename Dimension::Vector; + using Vector3d = typename Dimension::Vector3d; + using Tensor = typename Dimension::Tensor; + using SymTensor = typename Dimension::SymTensor; - typedef std::vector*> PackageList; - typedef typename PackageList::iterator PackageIterator; + using PackageList = std::vector*>; + using PackageIterator = typename PackageList::iterator; - typedef typename StateBase::KeyType KeyType; + using KeyType = typename StateBase::KeyType; // Constructors, destructor. StateDerivatives(); @@ -73,16 +73,14 @@ private: //--------------------------- Private Interface ---------------------------// // Map for storing information about pairs of nodes that have already been // calculated. - typedef std::map, - std::vector > > CalculatedPairType; + using CalculatedPairType = std::map, + std::vector>>; CalculatedPairType mCalculatedNodePairs; // Map for maintaining the number of significant neighbors per node. - typedef std::map, int> SignificantNeighborMapType; - + using SignificantNeighborMapType = std::map, int>; SignificantNeighborMapType mNumSignificantNeighbors; - using typename StateBase::StorageType; using StateBase::mStorage; }; diff --git a/src/DataBase/StateInline.hh b/src/DataBase/StateInline.hh index e8b5e58c3..8ea18e4c1 100644 --- a/src/DataBase/StateInline.hh +++ b/src/DataBase/StateInline.hh @@ -1,17 +1,33 @@ namespace Spheral { //------------------------------------------------------------------------------ -// Enroll the given policy. -//------------------------------------------------------------------------------ -template -inline -void -State:: -enroll(const typename State::KeyType& key, - typename State::PolicyPointer polptr) { - KeyType fieldKey, nodeKey; - this->splitFieldKey(key, fieldKey, nodeKey); - mPolicyMap[fieldKey][key] = polptr; +// Functors in a detail namespace to help with partial specialization +//------------------------------------------------------------------------------ +namespace Detail { + +template +struct EnrollAny { + void operator()(State& state, + const typename State::KeyType& key, + T& thing) { + dynamic_cast*>(&state)->enroll(key, thing); + } +}; + +template +struct EnrollAny> { + void operator()(State& state, + const typename State::KeyType& key, + std::shared_ptr& thing) { + auto UPP = std::dynamic_pointer_cast>(thing); + if (UPP) { + state.enroll(key, UPP); + } else { + dynamic_cast*>(&state)->enroll(key, thing); + } + } +}; + } //------------------------------------------------------------------------------ @@ -22,9 +38,9 @@ inline void State:: enroll(FieldBase& field, - typename State::PolicyPointer polptr) { + typename State::PolicyPointer policy) { this->enroll(field); - this->enroll(this->key(field), polptr); + this->enroll(this->key(field), policy); } //------------------------------------------------------------------------------ @@ -35,51 +51,32 @@ inline void State:: enroll(FieldListBase& fieldList, - typename State::PolicyPointer polptr) { - if (polptr->clonePerField()) { + typename State::PolicyPointer policy) { + if (policy->clonePerField()) { // std::cerr << "Registering FieldList " << this->key(fieldList) << " with cloning policy" << std::endl; - for (auto bitr = fieldList.begin_base(); bitr < fieldList.end_base(); ++bitr) { - this->enroll(**bitr, polptr); + for (auto fptr: range(fieldList.begin_base(), fieldList.end_base())) { + this->enroll(*fptr, policy); } } else { // std::cerr << "Registering FieldList " << this->key(fieldList) << " with SINGLE policy" << std::endl; // this->enroll(this->key(fieldList), fieldList); this->enroll(fieldList); // enrolls each field without a policy - this->enroll(this->key(fieldList), polptr); + this->enroll(this->key(fieldList), policy); } } //------------------------------------------------------------------------------ -// Enroll the given field. -//------------------------------------------------------------------------------ -template -inline -void -State:: -enroll(FieldBase& field) { - StateBase::enroll(field); -} - -//------------------------------------------------------------------------------ -// Enroll the given field shared_pointer. -//------------------------------------------------------------------------------ -template -inline -void -State:: -enroll(std::shared_ptr>& fieldPtr) { - StateBase::enroll(fieldPtr); -} - -//------------------------------------------------------------------------------ -// Enroll the given field list. +// Enroll the given policy. //------------------------------------------------------------------------------ template inline void State:: -enroll(FieldListBase& fieldList) { - StateBase::enroll(fieldList); +enroll(const typename State::KeyType& key, + typename State::PolicyPointer policy) { + KeyType fieldKey, nodeKey; + this->splitFieldKey(key, fieldKey, nodeKey); + mPolicyMap[fieldKey][key] = policy; } //------------------------------------------------------------------------------ @@ -96,23 +93,15 @@ policy(const Field& field) const { } //------------------------------------------------------------------------------ -// Optionally trip a flag indicating policies should time advance only -- no replacing state! -// This is useful when you're trying to cheat and reuse derivatives from a prior advance. +// Enroll an arbitrary type //------------------------------------------------------------------------------ template -inline -bool -State:: -timeAdvanceOnly() const { - return mTimeAdvanceOnly; -} - -template +template inline void State:: -timeAdvanceOnly(const bool x) { - mTimeAdvanceOnly = x; +enroll(const KeyType& key, T& thing) { + Detail::EnrollAny()(*this, key, thing); } } diff --git a/src/FSISPH/SolidFSISPHEvaluateDerivatives.cc b/src/FSISPH/SolidFSISPHEvaluateDerivatives.cc index fde9b5bd5..6f568dda2 100644 --- a/src/FSISPH/SolidFSISPHEvaluateDerivatives.cc +++ b/src/FSISPH/SolidFSISPHEvaluateDerivatives.cc @@ -147,8 +147,8 @@ secondDerivativesLoop(const typename Dimension::Scalar time, auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); - auto& pairDepsDt = derivatives.getAny(HydroFieldNames::pairWork, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); + auto& pairDepsDt = derivatives.get(HydroFieldNames::pairWork, vector()); CHECK(M.size() == numNodeLists); CHECK(localM.size() == numNodeLists); diff --git a/src/FSISPH/SolidFSISPHHydroBase.cc b/src/FSISPH/SolidFSISPHHydroBase.cc index d7c920e73..0afac53b4 100644 --- a/src/FSISPH/SolidFSISPHHydroBase.cc +++ b/src/FSISPH/SolidFSISPHHydroBase.cc @@ -445,8 +445,8 @@ registerDerivatives(DataBase& dataBase, CHECK(not derivs.registered(mDvDt)); - derivs.enrollAny(HydroFieldNames::pairAccelerations, mPairAccelerations); - derivs.enrollAny(HydroFieldNames::pairWork, mPairDepsDt); + derivs.enroll(HydroFieldNames::pairAccelerations, mPairAccelerations); + derivs.enroll(HydroFieldNames::pairWork, mPairDepsDt); derivs.enroll(plasticStrainRate); derivs.enroll(mXSPHDeltaV); diff --git a/src/GSPH/GSPHEvaluateDerivatives.cc b/src/GSPH/GSPHEvaluateDerivatives.cc index 4644684d4..6573c7223 100644 --- a/src/GSPH/GSPHEvaluateDerivatives.cc +++ b/src/GSPH/GSPHEvaluateDerivatives.cc @@ -72,8 +72,8 @@ evaluateDerivatives(const typename Dimension::Scalar time, auto DvDt = derivatives.fields(HydroFieldNames::hydroAcceleration, Vector::zero); auto DepsDt = derivatives.fields(IncrementState::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); auto DvDx = derivatives.fields(HydroFieldNames::velocityGradient, Tensor::zero); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); - auto& pairDepsDt = derivatives.getAny(HydroFieldNames::pairWork, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); + auto& pairDepsDt = derivatives.get(HydroFieldNames::pairWork, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto newRiemannDpDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannPressureGradient,Vector::zero); auto newRiemannDvDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannVelocityGradient,Tensor::zero); diff --git a/src/GSPH/GenericRiemannHydro.cc b/src/GSPH/GenericRiemannHydro.cc index bd05e621e..af593b15a 100644 --- a/src/GSPH/GenericRiemannHydro.cc +++ b/src/GSPH/GenericRiemannHydro.cc @@ -302,8 +302,8 @@ registerDerivatives(DataBase& dataBase, derivs.enroll(mDspecificThermalEnergyDt); derivs.enroll(mDvDx); derivs.enroll(mM); - derivs.enrollAny(HydroFieldNames::pairAccelerations, mPairAccelerations); - derivs.enrollAny(HydroFieldNames::pairWork, mPairDepsDt); + derivs.enroll(HydroFieldNames::pairAccelerations, mPairAccelerations); + derivs.enroll(HydroFieldNames::pairWork, mPairDepsDt); } //------------------------------------------------------------------------------ diff --git a/src/GSPH/MFMEvaluateDerivatives.cc b/src/GSPH/MFMEvaluateDerivatives.cc index c68d4adfd..32dc20d3d 100644 --- a/src/GSPH/MFMEvaluateDerivatives.cc +++ b/src/GSPH/MFMEvaluateDerivatives.cc @@ -71,8 +71,8 @@ evaluateDerivatives(const typename Dimension::Scalar time, auto DvDt = derivatives.fields(HydroFieldNames::hydroAcceleration, Vector::zero); auto DepsDt = derivatives.fields(IncrementState::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); auto DvDx = derivatives.fields(HydroFieldNames::velocityGradient, Tensor::zero); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); - auto& pairDepsDt = derivatives.getAny(HydroFieldNames::pairWork, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); + auto& pairDepsDt = derivatives.get(HydroFieldNames::pairWork, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto newRiemannDpDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannPressureGradient,Vector::zero); auto newRiemannDvDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannVelocityGradient,Tensor::zero); diff --git a/src/GSPH/MFVEvaluateDerivatives.cc b/src/GSPH/MFVEvaluateDerivatives.cc index 53a17a73b..1e48ae27b 100644 --- a/src/GSPH/MFVEvaluateDerivatives.cc +++ b/src/GSPH/MFVEvaluateDerivatives.cc @@ -89,9 +89,9 @@ secondDerivativesLoop(const typename Dimension::Scalar time, //auto HStretchTensor = derivatives.fields("HStretchTensor", SymTensor::zero); auto newRiemannDpDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannPressureGradient,Vector::zero); auto newRiemannDvDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannVelocityGradient,Tensor::zero); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); - auto& pairDepsDt = derivatives.getAny(HydroFieldNames::pairWork, vector()); - auto& pairMassFlux = derivatives.getAny(GSPHFieldNames::pairMassFlux, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); + auto& pairDepsDt = derivatives.get(HydroFieldNames::pairWork, vector()); + auto& pairMassFlux = derivatives.get(GSPHFieldNames::pairMassFlux, vector()); CHECK(DrhoDx.size() == numNodeLists); CHECK(M.size() == numNodeLists); diff --git a/src/GSPH/MFVHydroBase.cc b/src/GSPH/MFVHydroBase.cc index be68ef1fa..e60aea910 100644 --- a/src/GSPH/MFVHydroBase.cc +++ b/src/GSPH/MFVHydroBase.cc @@ -228,7 +228,7 @@ registerDerivatives(DataBase& dataBase, derivs.enroll(mDmomentumDt); derivs.enroll(mDvolumeDt); //derivs.enroll(mHStretchTensor); - derivs.enrollAny(GSPHFieldNames::pairMassFlux, mPairMassFlux); + derivs.enroll(GSPHFieldNames::pairMassFlux, mPairMassFlux); } //------------------------------------------------------------------------------ diff --git a/src/GSPH/Policies/CompatibleMFVSpecificThermalEnergyPolicy.cc b/src/GSPH/Policies/CompatibleMFVSpecificThermalEnergyPolicy.cc index eb7a33a27..4f2a4a958 100644 --- a/src/GSPH/Policies/CompatibleMFVSpecificThermalEnergyPolicy.cc +++ b/src/GSPH/Policies/CompatibleMFVSpecificThermalEnergyPolicy.cc @@ -84,9 +84,9 @@ update(const KeyType& key, const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto DmassDt = derivs.fields(IncrementState::prefix() + HydroFieldNames::mass, 0.0); const auto DmomentumDt = derivs.fields(IncrementState::prefix() + GSPHFieldNames::momentum, Vector::zero); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); - const auto& pairDepsDt = derivs.getAny(HydroFieldNames::pairWork, vector()); - const auto& pairMassFlux = derivs.getAny(GSPHFieldNames::pairMassFlux, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); + const auto& pairDepsDt = derivs.get(HydroFieldNames::pairWork, vector()); + const auto& pairMassFlux = derivs.get(GSPHFieldNames::pairMassFlux, vector()); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); diff --git a/src/Hydro/CompatibleDifferenceSpecificThermalEnergyPolicy.cc b/src/Hydro/CompatibleDifferenceSpecificThermalEnergyPolicy.cc index 154900215..c761a4ac9 100644 --- a/src/Hydro/CompatibleDifferenceSpecificThermalEnergyPolicy.cc +++ b/src/Hydro/CompatibleDifferenceSpecificThermalEnergyPolicy.cc @@ -79,8 +79,8 @@ update(const KeyType& key, const auto mass = state.fields(HydroFieldNames::mass, Scalar()); const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto acceleration = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); - const auto& pairDepsDt = derivs.getAny(HydroFieldNames::pairWork, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); + const auto& pairDepsDt = derivs.get(HydroFieldNames::pairWork, vector()); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); const auto npairs = pairs.size(); diff --git a/src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc b/src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc index 8646392df..66827862f 100644 --- a/src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc +++ b/src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc @@ -82,7 +82,7 @@ update(const KeyType& key, const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto acceleration = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); const auto eps0 = state.fields(HydroFieldNames::specificThermalEnergy + "0", Scalar()); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); const auto DepsDt0 = derivs.fields(IncrementState >::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); diff --git a/src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc b/src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc index 90ff42f2b..0519846c2 100644 --- a/src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc +++ b/src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc @@ -94,7 +94,7 @@ update(const KeyType& key, const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto acceleration = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); const auto eps0 = state.fields(HydroFieldNames::specificThermalEnergy + "0", Scalar()); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); const auto DepsDt0 = derivs.fields(IncrementState >::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); diff --git a/src/Hydro/SpecificThermalEnergyPolicy.cc b/src/Hydro/SpecificThermalEnergyPolicy.cc index c856e579f..51fd1691d 100644 --- a/src/Hydro/SpecificThermalEnergyPolicy.cc +++ b/src/Hydro/SpecificThermalEnergyPolicy.cc @@ -79,7 +79,7 @@ update(const KeyType& key, const auto mass = state.fields(HydroFieldNames::mass, Scalar()); const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto DvDt = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); const auto DepsDt0 = derivs.fields(IncrementState >::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); diff --git a/src/Hydro/SphericalPositionPolicy.cc b/src/Hydro/SphericalPositionPolicy.cc index 392d3b6b4..42193a4ae 100644 --- a/src/Hydro/SphericalPositionPolicy.cc +++ b/src/Hydro/SphericalPositionPolicy.cc @@ -56,33 +56,27 @@ update(const KeyType& key, // Get the field name portion of the key. KeyType fieldKey, nodeListKey; StateBase::splitFieldKey(key, fieldKey, nodeListKey); - REQUIRE(nodeListKey == UpdatePolicyBase::wildcard()); // Get the state we're updating. - auto f = state.fields(fieldKey, Vector::zero); - const auto numNodeLists = f.size(); + auto f = state.field(key, Vector::zero); // Find all the available matching derivative Field keys. const auto incrementKey = prefix() + fieldKey; - const auto allkeys = derivs.fieldKeys(); - vector incrementKeys; + const auto allkeys = derivs.keys(); + KeyType dfKey, dfNodeListKey; for (const auto& key: allkeys) { - if (key.compare(0, incrementKey.size(), incrementKey) == 0) { - incrementKeys.push_back(key); - } - } - CHECK(not incrementKeys.empty()); + StateBase::splitFieldKey(key, dfKey, dfNodeListKey); + if (dfNodeListKey == nodeListKey and + dfKey.compare(0, incrementKey.size(), incrementKey) == 0) { - // Update by each of our derivative fields. - for (const auto& key: incrementKeys) { - const auto df = derivs.fields(key, Vector::zero); - CHECK(df.size() == f.size()); - for (auto k = 0u; k != numNodeLists; ++k) { - const auto n = f[k]->numInternalElements(); - for (auto i = 0u; i != n; ++i) { + // This delta field matches the base of increment key, so apply it. + const auto& df = derivs.field(key, Vector::zero); + const auto n = f.numInternalElements(); +#pragma omp parallel for + for (auto i = 0u; i < n; ++i) { // This is where we diverge from the standard IncrementState. Ensure we cannot cross to // negative radius. - f(k,i) = std::max(0.5*f(k,i), f(k,i) + multiplier*(df(k, i))); + f(i) = std::max(0.5*f(i), f(i) + multiplier*(df(i))); } } } diff --git a/src/Hydro/SphericalPositionPolicy.hh b/src/Hydro/SphericalPositionPolicy.hh index 8b121b057..b8b8aeb5c 100644 --- a/src/Hydro/SphericalPositionPolicy.hh +++ b/src/Hydro/SphericalPositionPolicy.hh @@ -35,6 +35,9 @@ public: const double t, const double dt); + // Should this policy be cloned per Field when registering for a FieldList? + virtual bool clonePerField() const { return true; } + // Equivalence. virtual bool operator==(const UpdatePolicyBase& rhs) const; diff --git a/src/PYB11/DataBase/StateBase.py b/src/PYB11/DataBase/StateBase.py index bf774d9a8..03de427d1 100644 --- a/src/PYB11/DataBase/StateBase.py +++ b/src/PYB11/DataBase/StateBase.py @@ -7,17 +7,17 @@ class StateBase: PYB11typedefs = """ - typedef typename %(Dimension)s::Scalar Scalar; - typedef typename %(Dimension)s::Vector Vector; - typedef typename %(Dimension)s::Tensor Tensor; - typedef typename %(Dimension)s::SymTensor SymTensor; - typedef typename %(Dimension)s::ThirdRankTensor ThirdRankTensor; - typedef typename %(Dimension)s::FourthRankTensor FourthRankTensor; - typedef typename %(Dimension)s::FifthRankTensor FifthRankTensor; - typedef typename %(Dimension)s::FacetedVolume FacetedVolume; - typedef typename StateBase<%(Dimension)s>::KeyType KeyType; - typedef typename StateBase<%(Dimension)s>::FieldName FieldName; - typedef typename StateBase<%(Dimension)s>::MeshPtr MeshPtr; + using Scalar = typename %(Dimension)s::Scalar; + using Vector = typename %(Dimension)s::Vector; + using Tensor = typename %(Dimension)s::Tensor; + using SymTensor = typename %(Dimension)s::SymTensor; + using ThirdRankTensor = typename %(Dimension)s::ThirdRankTensor; + using FourthRankTensor = typename %(Dimension)s::FourthRankTensor; + using FifthRankTensor = typename %(Dimension)s::FifthRankTensor; + using FacetedVolume = typename %(Dimension)s::FacetedVolume; + using KeyType = typename StateBase<%(Dimension)s>::KeyType; + using FieldName = typename StateBase<%(Dimension)s>::FieldName; + using MeshPtr = typename StateBase<%(Dimension)s>::MeshPtr; """ #........................................................................... @@ -100,9 +100,19 @@ def keys(self): return "std::vector" @PYB11const - def fieldKeys(self): - "The set of Field names for the state in the StateBase" - return "std::vector" + def fullFieldKeys(self): + "The set of Field names (with NodeList mangling) for the state in the StateBase" + return "std::vector" + + @PYB11const + def fieldNames(self): + "The set of unique Field names for the state in the StateBase (no NodeList mangling)" + return "std::vector" + + @PYB11const + def miscKeys(self): + "The set of names for non-Fields in the StateBase" + return "std::vector" def enrollConnectivityMap(self, connectivityMapPtr = "std::shared_ptr>"): @@ -228,9 +238,9 @@ def allFields(self, allRKCoefficientsFields = PYB11TemplateMethod(allFields, "RKCoefficients<%(Dimension)s>") #........................................................................... - # enrollAny/getAny + # enroll/get @PYB11template("Value") - def enrollAny(self, + def enroll(self, key = "const KeyType&", thing = "%(Value)s&"): "Enroll a type of %(Value)s." @@ -239,13 +249,13 @@ def enrollAny(self, @PYB11template("Value") @PYB11const @PYB11returnpolicy("reference_internal") - def getAny(self, + def get(self, key = "const KeyType&"): "Return a stored type of %(Value)s" return "%(Value)s&" - enrollVectorVector = PYB11TemplateMethod(enrollAny, "std::vector", pyname="enrollAny") - getVectorVector = PYB11TemplateMethod(getAny, "std::vector", pyname="getAny") + enrollVectorVector = PYB11TemplateMethod(enroll, "std::vector", pyname="enroll") + getVectorVector = PYB11TemplateMethod(get, "std::vector", pyname="get") #........................................................................... # assignFields diff --git a/src/RK/RKCorrections.cc b/src/RK/RKCorrections.cc index 5efdaff8d..a29314140 100644 --- a/src/RK/RKCorrections.cc +++ b/src/RK/RKCorrections.cc @@ -155,9 +155,9 @@ RKCorrections:: registerState(DataBase& dataBase, State& state) { // Stuff RKCorrections owns - state.enrollAny(RKFieldNames::rkOrders, mOrders); + state.enroll(RKFieldNames::rkOrders, mOrders); for (auto order: mOrders) { - state.enrollAny(RKFieldNames::reproducingKernel(order), mWR[order]); + state.enroll(RKFieldNames::reproducingKernel(order), mWR[order]); state.enroll(mCorrections[order]); } state.enroll(mVolume); diff --git a/src/RK/ReproducingKernel.cc b/src/RK/ReproducingKernel.cc index 09c1a7a3b..c2c3ad26c 100644 --- a/src/RK/ReproducingKernel.cc +++ b/src/RK/ReproducingKernel.cc @@ -46,11 +46,14 @@ operator=(const ReproducingKernel& rhs) { } //------------------------------------------------------------------------------ -// Destructor +// Equivalence //------------------------------------------------------------------------------ template +bool ReproducingKernel:: -~ReproducingKernel() { -} +operator==(const ReproducingKernel& rhs) const { + return (ReproducingKernelMethods::operator==(rhs) and + *mWptr == *(rhs.mWptr)); +} } diff --git a/src/RK/ReproducingKernel.hh b/src/RK/ReproducingKernel.hh index cfaf073ec..1296b32c6 100644 --- a/src/RK/ReproducingKernel.hh +++ b/src/RK/ReproducingKernel.hh @@ -24,7 +24,8 @@ public: ReproducingKernel(); ReproducingKernel(const ReproducingKernel& rhs); ReproducingKernel& operator=(const ReproducingKernel& rhs); - ~ReproducingKernel(); + virtual ~ReproducingKernel() {} + bool operator==(const ReproducingKernel& rhs) const; // Base kernel calls Scalar evaluateBaseKernel(const Vector& x, diff --git a/src/RK/ReproducingKernelMethods.cc b/src/RK/ReproducingKernelMethods.cc index ef85e2339..99470955b 100644 --- a/src/RK/ReproducingKernelMethods.cc +++ b/src/RK/ReproducingKernelMethods.cc @@ -226,11 +226,15 @@ operator=(const ReproducingKernelMethods& rhs) { } //------------------------------------------------------------------------------ -// Destructor +// Equivalence //------------------------------------------------------------------------------ template +bool ReproducingKernelMethods:: -~ReproducingKernelMethods() { -} +operator==(const ReproducingKernelMethods& rhs) const { + return (mOrder == rhs.mOrder and + mGradCorrectionsSize == rhs.mGradCorrectionsSize and + mHessCorrectionsSize == rhs.mHessCorrectionsSize); +} } diff --git a/src/RK/ReproducingKernelMethods.hh b/src/RK/ReproducingKernelMethods.hh index d7ffd41e0..d4a9908d8 100644 --- a/src/RK/ReproducingKernelMethods.hh +++ b/src/RK/ReproducingKernelMethods.hh @@ -26,7 +26,8 @@ public: ReproducingKernelMethods(); ReproducingKernelMethods(const ReproducingKernelMethods& rhs); ReproducingKernelMethods& operator=(const ReproducingKernelMethods& rhs); - ~ReproducingKernelMethods(); + virtual ~ReproducingKernelMethods() {} + bool operator==(const ReproducingKernelMethods& rhs) const; // Build a transformation operator TransformationMatrix transformationMatrix(const Tensor& T, diff --git a/src/SPH/PSPHHydroBase.cc b/src/SPH/PSPHHydroBase.cc index c6b4901c1..d5794eba8 100644 --- a/src/SPH/PSPHHydroBase.cc +++ b/src/SPH/PSPHHydroBase.cc @@ -296,7 +296,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(rhoSum.size() == numNodeLists); diff --git a/src/SPH/SPHHydroBase.cc b/src/SPH/SPHHydroBase.cc index 41a06699b..2c38438f0 100644 --- a/src/SPH/SPHHydroBase.cc +++ b/src/SPH/SPHHydroBase.cc @@ -354,7 +354,7 @@ registerDerivatives(DataBase& dataBase, derivs.enroll(mGradRho); derivs.enroll(mM); derivs.enroll(mLocalM); - derivs.enrollAny(HydroFieldNames::pairAccelerations, mPairAccelerations); + derivs.enroll(HydroFieldNames::pairAccelerations, mPairAccelerations); TIME_END("SPHregisterDerivs"); } @@ -645,7 +645,7 @@ evaluateDerivatives(const typename Dimension::Scalar time, auto maxViscousPressure = derivs.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivs.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivs.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivs.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivs.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(rhoSum.size() == numNodeLists); diff --git a/src/SPH/SPHHydroBaseRZ.cc b/src/SPH/SPHHydroBaseRZ.cc index ecf3b36b7..6484480da 100644 --- a/src/SPH/SPHHydroBaseRZ.cc +++ b/src/SPH/SPHHydroBaseRZ.cc @@ -244,7 +244,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(rhoSum.size() == numNodeLists); diff --git a/src/SPH/SolidSPHHydroBase.cc b/src/SPH/SolidSPHHydroBase.cc index 9d996eace..27afce77f 100644 --- a/src/SPH/SolidSPHHydroBase.cc +++ b/src/SPH/SolidSPHHydroBase.cc @@ -371,7 +371,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto rhoSumCorrection = derivatives.fields(HydroFieldNames::massDensityCorrection, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); diff --git a/src/SPH/SolidSPHHydroBaseRZ.cc b/src/SPH/SolidSPHHydroBaseRZ.cc index 0a4eeb52e..7ea9c1d86 100644 --- a/src/SPH/SolidSPHHydroBaseRZ.cc +++ b/src/SPH/SolidSPHHydroBaseRZ.cc @@ -295,7 +295,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto rhoSumCorrection = derivatives.fields(HydroFieldNames::massDensityCorrection, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); diff --git a/src/SPH/SolidSphericalSPHHydroBase.cc b/src/SPH/SolidSphericalSPHHydroBase.cc index 31e3e0813..b78cbc1b2 100644 --- a/src/SPH/SolidSphericalSPHHydroBase.cc +++ b/src/SPH/SolidSphericalSPHHydroBase.cc @@ -294,7 +294,7 @@ evaluateDerivatives(const Dim<1>::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto rhoSumCorrection = derivatives.fields(HydroFieldNames::massDensityCorrection, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); diff --git a/src/SPH/SphericalSPHHydroBase.cc b/src/SPH/SphericalSPHHydroBase.cc index 843ed7ef6..d6591f2e6 100644 --- a/src/SPH/SphericalSPHHydroBase.cc +++ b/src/SPH/SphericalSPHHydroBase.cc @@ -248,7 +248,7 @@ evaluateDerivatives(const Dim<1>::Scalar time, auto localM = derivs.fields("local " + HydroFieldNames::M_SPHCorrection, Tensor::zero); auto maxViscousPressure = derivs.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivs.fields(HydroFieldNames::effectiveViscousPressure, 0.0); - auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivs.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivs.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(rhoSum.size() == numNodeLists); diff --git a/src/Utilities/AnyVisitor.hh b/src/Utilities/AnyVisitor.hh new file mode 100644 index 000000000..66e63c89b --- /dev/null +++ b/src/Utilities/AnyVisitor.hh @@ -0,0 +1,39 @@ +//---------------------------------Spheral++----------------------------------// +// Collect visitor methods to apply to std::any object holders +// +// This allows us to use the visitor pattern with containers of std::any +// obfuscated objects similarly to the std::variant pattern. +//----------------------------------------------------------------------------// +#ifndef __Spheral_AnyVisitor__ +#define __Spheral_AnyVisitor__ + +#include + +namespace Spheral { + +template +class AnyVisitor { +public: + using VisitorFunc = std::function; + + template + RETURNT visit(T value, EXTRAARGS&&... extraargs) const { + auto it = mVisitors.find(std::type_index(value.type())); + if (it != mVisitors.end()) { + return it->second(value, extraargs...); + } + VERIFY2(false, "AnyVisitor ERROR: unable to process unknown data type " << std::quoted(value.type().name())); + } + + template + void addVisitor(VisitorFunc visitor) { + mVisitors[std::type_index(typeid(T))] = visitor; + } + +private: + std::unordered_map mVisitors; +}; + +} + +#endif diff --git a/src/Utilities/CMakeLists.txt b/src/Utilities/CMakeLists.txt index 75439b880..3c73b46f3 100644 --- a/src/Utilities/CMakeLists.txt +++ b/src/Utilities/CMakeLists.txt @@ -131,6 +131,8 @@ set(Utilities_headers timingUtilities.hh uniform_random.hh uniform_random_Inline.hh + range.hh + AnyVisitor.hh ) spheral_install_python_files(fitspline.py) diff --git a/src/VoronoiCells/SubPointPressureHourglassControl.cc b/src/VoronoiCells/SubPointPressureHourglassControl.cc index 9e5174d47..578202237 100644 --- a/src/VoronoiCells/SubPointPressureHourglassControl.cc +++ b/src/VoronoiCells/SubPointPressureHourglassControl.cc @@ -341,7 +341,7 @@ evaluateDerivatives(const Scalar time, auto DvDt = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); auto DepsDt = derivs.fields(IncrementState::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); auto DxDt = derivs.fields(IncrementState::prefix() + HydroFieldNames::position, Vector::zero); - auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); CHECK(DvDt.size() == numNodeLists); CHECK(DepsDt.size() == numNodeLists);