Skip to content

Commit

Permalink
feat[femutils]: start working on applying dirichlet via elimination f…
Browse files Browse the repository at this point in the history
…or bsr format
  • Loading branch information
toutane committed Dec 20, 2024
1 parent 2355ba9 commit 7950de9
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 28 deletions.
6 changes: 6 additions & 0 deletions femutils/AlephDoFLinearSystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <arcane/aleph/AlephTypesSolver.h>
#include <arcane/aleph/Aleph.h>
#include <arccore/base/NotImplementedException.h>

#include "FemUtils.h"
#include "IDoFLinearSystemFactory.h"
Expand Down Expand Up @@ -245,6 +246,11 @@ class AlephDoFLinearSystemImpl
info() << "EliminateRowColumn row=" << row.localId() << " v=" << value;
}

void setEliminationArrays(VariableDoFByte& dof_elimination_info, VariableDoFReal& dof_elimination_value) override

Check warning on line 249 in femutils/AlephDoFLinearSystem.cc

View check run for this annotation

Codecov / codecov/patch

femutils/AlephDoFLinearSystem.cc#L249

Added line #L249 was not covered by tests
{
ARCANE_THROW(NotImplementedException, "");

Check warning on line 251 in femutils/AlephDoFLinearSystem.cc

View check run for this annotation

Codecov / codecov/patch

femutils/AlephDoFLinearSystem.cc#L251

Added line #L251 was not covered by tests
};

void solve() override
{
UniqueArray<Real> aleph_result;
Expand Down
63 changes: 36 additions & 27 deletions femutils/BSRFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "arcane/accelerator/VariableViews.h"
#include <arcane/accelerator/Atomic.h>
#include <arcane/accelerator/Scan.h>
#include <type_traits>

#include "DoFLinearSystem.h"
#include "CsrFormatMatrix.h"
Expand Down Expand Up @@ -824,8 +825,8 @@ class BSRFormat : public TraceAccessor
/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template <bool IS_WEAK, bool USE_CSR_IN_LINEAR_SYSTEM>
void applyDirichletByPenalty(Real penalty, NumArray<Real, MDDim1>& rhs_vect, std::array<VariableNodeByte, NB_DOF>& u_dirichlet_var_arr, VariableNodeReal u)
template <bool IS_WEAK, bool USE_CSR_IN_LINEAR_SYSTEM, typename DataType>
void applyDirichletViaPenalty(Real penalty, NumArray<Real, MDDim1>& rhs_vect, std::array<VariableNodeByte, NB_DOF>& u_dirichlet_var_arr, MeshVariableScalarRefT<Node, DataType> u)

Check warning on line 829 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L829

Added line #L829 was not covered by tests
{
auto node_dof(m_dofs_on_nodes.nodeDoFConnectivityView());
auto nb_row = m_bsr_matrix.nbRow();
Expand Down Expand Up @@ -876,54 +877,62 @@ class BSRFormat : public TraceAccessor
/*---------------------------------------------------------------------------*/

template <bool ELIMINATE_BY_ROW_N_COLUMN, typename DataType>
void applyDirichletByElimination(std::array<VariableNodeByte, NB_DOF>& u_dirichlet_var_arr, MeshVariableScalarRefT<Node, DataType> u, DoFLinearSystem& linear_system)
void applyDirichletViaElimination(std::array<VariableNodeByte, NB_DOF>& u_dirichlet_var_arr, MeshVariableScalarRefT<Node, DataType> u, DoFLinearSystem* linear_system)

Check warning on line 880 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L880

Added line #L880 was not covered by tests
{
auto node_dof(m_dofs_on_nodes.nodeDoFConnectivityView());
auto command = makeCommand(m_queue);
auto in_u = viewIn(command, u);
VariableDoFByte m_dof_elimination_info(VariableBuildInfo(m_dofs_on_nodes.dofFamily(), "BSRFormatDoFEliminationInfo"));
VariableDoFReal m_dof_elimination_value(VariableBuildInfo(m_dofs_on_nodes.dofFamily(), "BSRFormatDoFEliminationValue"));
auto in_out_dof_elimination_info = viewInOut(command, m_dof_elimination_info);
auto in_out_dof_elimination_value = viewInOut(command, m_dof_elimination_value);

Check warning on line 888 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L882-L888

Added lines #L882 - L888 were not covered by tests

NumArray<Accelerator::ItemVariableScalarInViewT<Node, Byte>, MDDim1> in_u_dirichlet_view_arr;
in_u_dirichlet_view_arr.resize(NB_DOF);

Check warning on line 891 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L890-L891

Added lines #L890 - L891 were not covered by tests
for (auto i = 0; i < NB_DOF; ++i)
in_u_dirichlet_view_arr[i] = viewIn(command, u_dirichlet_var_arr[i]);
auto in_u_dirichlet_arr = viewIn(command, in_u_dirichlet_view_arr);

Check warning on line 894 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L893-L894

Added lines #L893 - L894 were not covered by tests

/*auto eliminateInLinearSystem = [&linear_system] ARCCORE_HOST_DEVICE(DoFLocalId dof_lid, Real u_dirichlet) {
if constexpr (ELIMINATE_BY_ROW_N_COLUMN)
linear_system.eliminateRowColumn(dof_lid, u_dirichlet);
else
linear_system.eliminateRow(dof_lid, u_dirichlet);
};*/

command << RUNCOMMAND_ENUMERATE(NodeLocalId, node_lid, m_mesh->ownNodes())
{
for (auto i = 0; i < NB_DOF; ++i) {
if ((in_u_dirichlet_arr[i])[node_lid]) {
DoFLocalId dof_lid = node_dof.dofId(node_lid, i);
Real u_dirichlet = in_u[node_lid][i];
/*if (ELIMINATE_BY_ROW_N_COLUMN)
linear_system.eliminateRowColumn(dof_lid, u_dirichlet);
else
linear_system.eliminateRow(dof_lid, u_dirichlet);*/
if constexpr (std::is_floating_point<DataType>::value) {
command << RUNCOMMAND_ENUMERATE(NodeLocalId, node_lid, m_mesh->ownNodes())

Check warning on line 897 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L897

Added line #L897 was not covered by tests
{
for (auto i = 0; i < NB_DOF; ++i) {
if ((in_u_dirichlet_arr[i])[node_lid]) {
DoFLocalId dof_lid = node_dof.dofId(node_lid, i);
Real u_dirichlet = in_u[node_lid];

Check warning on line 902 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L901-L902

Added lines #L901 - L902 were not covered by tests
if (!ELIMINATE_BY_ROW_N_COLUMN) {
in_out_dof_elimination_info[dof_lid] = 1; // How to define constant on device code ? ELIMINATE_ROW;
in_out_dof_elimination_value[dof_lid] = u_dirichlet;

Check warning on line 905 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L904-L905

Added lines #L904 - L905 were not covered by tests
}
}
}
}
};
};
}
else
ARCANE_THROW(NotImplementedException, "BSRFormat(applyDirichletByElimination): Method not supported!");

linear_system->setEliminationArrays(m_dof_elimination_info, m_dof_elimination_value);

Check warning on line 914 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L914

Added line #L914 was not covered by tests
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

void assembleLinearOperator(Arcane::String method, Real penalty, NumArray<Real, MDDim1>& rhs_vect, std::array<VariableNodeByte, NB_DOF>& u_dirichlet_arr, VariableNodeReal u)
template <typename DataType>
void applyDirichlet(Arcane::String method, Real penalty, NumArray<Real, MDDim1>& rhs_vect, std::array<VariableNodeByte, NB_DOF>& u_dirichlet_arr, MeshVariableScalarRefT<Node, DataType> u, DoFLinearSystem* linear_system)

Check warning on line 921 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L921

Added line #L921 was not covered by tests
{

if (method == "Penalty") {
m_use_csr_in_linear_system ? applyDirichletByPenalty<false, true>(penalty, rhs_vect, u_dirichlet_arr, u)
: applyDirichletByPenalty<false, false>(penalty, rhs_vect, u_dirichlet_arr, u);
m_use_csr_in_linear_system ? applyDirichletViaPenalty<false, true, DataType>(penalty, rhs_vect, u_dirichlet_arr, u)
: applyDirichletViaPenalty<false, false, DataType>(penalty, rhs_vect, u_dirichlet_arr, u);

Check warning on line 926 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L925-L926

Added lines #L925 - L926 were not covered by tests
}
else if (method == "WeaKPenalty") {
m_use_csr_in_linear_system ? applyDirichletByPenalty<true, true>(penalty, rhs_vect, u_dirichlet_arr, u)
: applyDirichletByPenalty<true, false>(penalty, rhs_vect, u_dirichlet_arr, u);
m_use_csr_in_linear_system ? applyDirichletViaPenalty<true, true, DataType>(penalty, rhs_vect, u_dirichlet_arr, u)
: applyDirichletViaPenalty<true, false, DataType>(penalty, rhs_vect, u_dirichlet_arr, u);

Check warning on line 930 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L928-L930

Added lines #L928 - L930 were not covered by tests
}
else if (method == "RowElimination")
applyDirichletViaElimination<false, DataType>(u_dirichlet_arr, u, linear_system);
else if (method == "RowColumnElimination")
applyDirichletViaElimination<true, DataType>(u_dirichlet_arr, u, linear_system);

Check warning on line 935 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L932-L935

Added lines #L932 - L935 were not covered by tests
else
ARCANE_THROW(NotImplementedException, "BSRFormat(assembleLinearOperator): Method not supported!");

Check warning on line 937 in femutils/BSRFormat.h

View check run for this annotation

Codecov / codecov/patch

femutils/BSRFormat.h#L937

Added line #L937 was not covered by tests
}
Expand Down
11 changes: 11 additions & 0 deletions femutils/DoFLinearSystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ class SequentialDoFLinearSystemImpl
ARCANE_THROW(NotImplementedException, "");
}

void setEliminationArrays(VariableDoFByte& dof_elimination_info, VariableDoFReal& dof_elimination_value) override

Check warning on line 106 in femutils/DoFLinearSystem.cc

View check run for this annotation

Codecov / codecov/patch

femutils/DoFLinearSystem.cc#L106

Added line #L106 was not covered by tests
{
ARCANE_THROW(NotImplementedException, "");

Check warning on line 108 in femutils/DoFLinearSystem.cc

View check run for this annotation

Codecov / codecov/patch

femutils/DoFLinearSystem.cc#L108

Added line #L108 was not covered by tests
};

void solve() override
{
_fillRHSVector();
Expand Down Expand Up @@ -379,6 +384,12 @@ eliminateRowColumn(DoFLocalId row, Real value)
m_p->eliminateRowColumn(row, value);
}

void DoFLinearSystem::setEliminationArrays(VariableDoFByte& dof_elimination_info, VariableDoFReal& dof_elimination_value)

Check warning on line 387 in femutils/DoFLinearSystem.cc

View check run for this annotation

Codecov / codecov/patch

femutils/DoFLinearSystem.cc#L387

Added line #L387 was not covered by tests
{
_checkInit();
m_p->setEliminationArrays(dof_elimination_info, dof_elimination_value);

Check warning on line 390 in femutils/DoFLinearSystem.cc

View check run for this annotation

Codecov / codecov/patch

femutils/DoFLinearSystem.cc#L389-L390

Added lines #L389 - L390 were not covered by tests
};

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

Expand Down
3 changes: 3 additions & 0 deletions femutils/DoFLinearSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class DoFLinearSystemImpl
virtual void matrixSetValue(DoFLocalId row, DoFLocalId column, Real value) = 0;
virtual void eliminateRow(DoFLocalId row, Real value) = 0;
virtual void eliminateRowColumn(DoFLocalId row, Real value) = 0;
virtual void setEliminationArrays(VariableDoFByte &dof_elimination_info, VariableDoFReal &dof_elimination_value) = 0;
virtual void solve() = 0;
virtual VariableDoFReal& solutionVariable() = 0;
virtual VariableDoFReal& rhsVariable() = 0;
Expand Down Expand Up @@ -173,6 +174,8 @@ class DoFLinearSystem
*/
void eliminateRowColumn(DoFLocalId row, Real value);

void setEliminationArrays(VariableDoFByte &dof_elimination_info, VariableDoFReal &dof_elimination_value);

/*!
* \brief Solve the current linear system.
*/
Expand Down
5 changes: 5 additions & 0 deletions femutils/HypreDoFLinearSystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ class HypreDoFLinearSystemImpl
ARCANE_THROW(NotImplementedException, "");
}

void setEliminationArrays(VariableDoFByte& dof_elimination_info, VariableDoFReal& dof_elimination_value) override

Check warning on line 144 in femutils/HypreDoFLinearSystem.cc

View check run for this annotation

Codecov / codecov/patch

femutils/HypreDoFLinearSystem.cc#L144

Added line #L144 was not covered by tests
{
ARCANE_THROW(NotImplementedException, "");

Check warning on line 146 in femutils/HypreDoFLinearSystem.cc

View check run for this annotation

Codecov / codecov/patch

femutils/HypreDoFLinearSystem.cc#L146

Added line #L146 was not covered by tests
};

void solve() override;

VariableDoFReal& solutionVariable() override
Expand Down
14 changes: 13 additions & 1 deletion poisson/FemModule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,20 @@ _assembleLinearOperator(BSRMatrix<1>* bsr_matrix)
// Temporary variable to keep values for the RHS part of the linear system
VariableDoFReal& rhs_values(m_linear_system.rhsVariable());
rhs_values.fill(0.0);

auto node_dof(m_dofs_on_nodes.nodeDoFConnectivityView());

if (m_use_csr) {
m_rhs_vect.resize(nbNode());
m_rhs_vect.fill(0.0);

Check warning on line 647 in poisson/FemModule.cc

View check run for this annotation

Codecov / codecov/patch

poisson/FemModule.cc#L646-L647

Added lines #L646 - L647 were not covered by tests

std::array<VariableNodeByte, 1> u_dirichlet_arr = { m_u_dirichlet };

Check warning on line 649 in poisson/FemModule.cc

View check run for this annotation

Codecov / codecov/patch

poisson/FemModule.cc#L649

Added line #L649 was not covered by tests

auto method = options()->enforceDirichletMethod();
m_bsr_format.applyDirichlet(options()->enforceDirichletMethod(), options()->penalty(), m_rhs_vect, u_dirichlet_arr, m_u, &m_linear_system);

Check warning on line 652 in poisson/FemModule.cc

View check run for this annotation

Codecov / codecov/patch

poisson/FemModule.cc#L651-L652

Added lines #L651 - L652 were not covered by tests

_translateRhs();

Check warning on line 654 in poisson/FemModule.cc

View check run for this annotation

Codecov / codecov/patch

poisson/FemModule.cc#L654

Added line #L654 was not covered by tests
}
else {
if (options()->enforceDirichletMethod() == "Penalty") {

Timer::Action timer_action(m_time_stats, "Penalty");
Expand Down Expand Up @@ -761,6 +772,7 @@ _assembleLinearOperator(BSRMatrix<1>* bsr_matrix)
<< " - RowElimination\n"
<< " - RowColumnElimination\n";
}
}

{
Timer::Action timer_action(m_time_stats, "ConstantSourceTermAssembly");
Expand Down
1 change: 1 addition & 0 deletions poisson/FemModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
#include <arcane/core/MeshUtils.h>
#include "BSRFormat.h"
#include "ArcaneFemFunctionsGpu.h"
#include "arcane/core/VariableTypedef.h"

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/
Expand Down

0 comments on commit 7950de9

Please sign in to comment.