Skip to content

Commit

Permalink
feat[femutils]: apply dirichlet by penalty in femutils using gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
toutane committed Dec 19, 2024
1 parent 6852eca commit 2355ba9
Showing 1 changed file with 146 additions and 9 deletions.
155 changes: 146 additions & 9 deletions femutils/BSRFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

#include <array>
#include <ios>
#include <iomanip>

Expand Down Expand Up @@ -49,19 +50,56 @@
#include <arcane/accelerator/RunCommandLoop.h>
#include <arcane/accelerator/core/RunQueue.h>
#include <arcane/accelerator/NumArrayViews.h>
#include "arcane/accelerator/VariableViews.h"
#include <arcane/accelerator/Atomic.h>
#include <arcane/accelerator/Scan.h>

#include "DoFLinearSystem.h"
#include "CsrFormatMatrix.h"
#include "FemDoFsOnNodes.h"
#include "arcane/accelerator/ViewsCommon.h"
#include "arcane/core/ItemTypes.h"
#include "arcane/core/ItemLocalId.h"
#include "arcane/core/VariableTypedef.h"

#include "arcane/accelerator/core/ViewBuildInfo.h"
#include "arcane/accelerator/AcceleratorGlobal.h"
#include "arcane/accelerator/ViewsCommon.h"

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

namespace Arcane::FemUtils
{

template <int BLOCK_SIZE, bool ORDER_VALUES_PER_BLOCK>
ARCCORE_HOST_DEVICE Int32 findValueIndexHostDevice(DoFLocalId row, DoFLocalId col, Accelerator::NumArrayView<Accelerator::DataViewGetter<Int32>, MDDim1, DefaultLayout> in_columns, Accelerator::NumArrayView<Accelerator::DataViewGetter<Int32>, MDDim1, DefaultLayout> in_row_index, Accelerator::NumArrayView<Accelerator::DataViewGetter<Int32>, MDDim1, DefaultLayout> in_nb_nz_per_row, Int32 nb_col, Int32 nb_row)
{
auto block_row = row / BLOCK_SIZE;
auto block_col = col / BLOCK_SIZE;

auto block_start = in_row_index[block_row];
auto block_end = (block_row == nb_row - 1) ? nb_col : in_row_index[block_row + 1];

auto row_offset = row % BLOCK_SIZE;
auto col_offset = col % BLOCK_SIZE;

constexpr int BLOCK_SIZE_SQ = BLOCK_SIZE * BLOCK_SIZE;
auto block_start_in_value = block_start * BLOCK_SIZE_SQ;
auto col_index = 0;
while (block_start < block_end) {
if (in_columns[block_start] == block_col) {
if constexpr (!ORDER_VALUES_PER_BLOCK)
return block_start_in_value + (BLOCK_SIZE * col_index) + (row_offset * BLOCK_SIZE * in_nb_nz_per_row[block_row]);
else
return (block_start * BLOCK_SIZE_SQ) + ((row_offset * BLOCK_SIZE) + col_offset);
}
++block_start;
++col_index;
}
return -1;
}

/*---------------------------------------------------------------------------*/
/**
* @brief A class representing a Block Sparse Row (BSR) matrix.
Expand Down Expand Up @@ -786,18 +824,116 @@ class BSRFormat : public TraceAccessor
/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

BSRMatrix<NB_DOF>& matrix()
template <bool IS_WEAK, bool USE_CSR_IN_LINEAR_SYSTEM>
void applyDirichletByPenalty(Real penalty, NumArray<Real, MDDim1>& rhs_vect, std::array<VariableNodeByte, NB_DOF>& u_dirichlet_var_arr, VariableNodeReal u)
{
return m_bsr_matrix;
};
void resetMatrixValues()
auto node_dof(m_dofs_on_nodes.nodeDoFConnectivityView());
auto nb_row = m_bsr_matrix.nbRow();
auto nb_col = m_bsr_matrix.nbCol();

auto command = makeCommand(m_queue);
auto in_out_values = viewInOut(command, m_bsr_matrix.values());
auto in_columns = viewIn(command, m_bsr_matrix.columns());
auto in_row_index = viewIn(command, m_bsr_matrix.rowIndex());
auto in_nb_nz_per_row = viewIn(command, m_bsr_matrix.nbNzPerRow());
auto in_out_rhs_vect = viewInOut(command, rhs_vect);
auto in_u = viewIn(command, u);

NumArray<Accelerator::ItemVariableScalarInViewT<Node, Byte>, MDDim1> in_u_dirichlet_view_arr;
in_u_dirichlet_view_arr.resize(NB_DOF);
for (auto i = 0; i < NB_DOF; ++i)
in_u_dirichlet_view_arr[i] = viewIn(command, u_dirichlet_var_arr[i]);
auto in_u_dirichlet_arr = viewIn(command, in_u_dirichlet_view_arr);

auto updateValue = [in_out_values, penalty] ARCCORE_HOST_DEVICE(Int32 value_idx) {
if constexpr (IS_WEAK)
in_out_values[value_idx] += penalty;
else
in_out_values[value_idx] = penalty;
};

auto findValueIndex = [in_columns, in_row_index, in_nb_nz_per_row, nb_col, nb_row] ARCCORE_HOST_DEVICE(DoFLocalId dof_lid) {
if constexpr (USE_CSR_IN_LINEAR_SYSTEM)
return findValueIndexHostDevice<NB_DOF, false>(dof_lid, dof_lid, in_columns, in_row_index, in_nb_nz_per_row, nb_col, nb_row);
else
return findValueIndexHostDevice<NB_DOF, true>(dof_lid, dof_lid, in_columns, in_row_index, in_nb_nz_per_row, nb_col, nb_row);
};

command << RUNCOMMAND_ENUMERATE(NodeLocalId, node_lid, m_mesh->ownNodes())
{
for (auto i = 0; i < NB_DOF; ++i) {
if ((in_u_dirichlet_arr[i])[node_lid]) {
DoFLocalId dof_lid = node_dof.dofId(node_lid, i);
auto value_idx = findValueIndex(dof_lid);
updateValue(value_idx);
in_out_rhs_vect[dof_lid] = in_u[node_lid] * penalty;
}
}
};
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

template <bool ELIMINATE_BY_ROW_N_COLUMN, typename DataType>
void applyDirichletByElimination(std::array<VariableNodeByte, NB_DOF>& u_dirichlet_var_arr, MeshVariableScalarRefT<Node, DataType> u, DoFLinearSystem& linear_system)
{
m_bsr_matrix.values().fill(0, m_queue);
};
void dumpMatrix(std::string filename) const
auto node_dof(m_dofs_on_nodes.nodeDoFConnectivityView());
auto command = makeCommand(m_queue);
auto in_u = viewIn(command, u);

NumArray<Accelerator::ItemVariableScalarInViewT<Node, Byte>, MDDim1> in_u_dirichlet_view_arr;
in_u_dirichlet_view_arr.resize(NB_DOF);
for (auto i = 0; i < NB_DOF; ++i)
in_u_dirichlet_view_arr[i] = viewIn(command, u_dirichlet_var_arr[i]);
auto in_u_dirichlet_arr = viewIn(command, in_u_dirichlet_view_arr);

/*auto eliminateInLinearSystem = [&linear_system] ARCCORE_HOST_DEVICE(DoFLocalId dof_lid, Real u_dirichlet) {
if constexpr (ELIMINATE_BY_ROW_N_COLUMN)
linear_system.eliminateRowColumn(dof_lid, u_dirichlet);
else
linear_system.eliminateRow(dof_lid, u_dirichlet);
};*/

command << RUNCOMMAND_ENUMERATE(NodeLocalId, node_lid, m_mesh->ownNodes())
{
for (auto i = 0; i < NB_DOF; ++i) {
if ((in_u_dirichlet_arr[i])[node_lid]) {
DoFLocalId dof_lid = node_dof.dofId(node_lid, i);
Real u_dirichlet = in_u[node_lid][i];
/*if (ELIMINATE_BY_ROW_N_COLUMN)
linear_system.eliminateRowColumn(dof_lid, u_dirichlet);
else
linear_system.eliminateRow(dof_lid, u_dirichlet);*/
}
}
};
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

void assembleLinearOperator(Arcane::String method, Real penalty, NumArray<Real, MDDim1>& rhs_vect, std::array<VariableNodeByte, NB_DOF>& u_dirichlet_arr, VariableNodeReal u)
{
m_bsr_matrix.dump(filename);
};

if (method == "Penalty") {
m_use_csr_in_linear_system ? applyDirichletByPenalty<false, true>(penalty, rhs_vect, u_dirichlet_arr, u)
: applyDirichletByPenalty<false, false>(penalty, rhs_vect, u_dirichlet_arr, u);
}
else if (method == "WeaKPenalty") {
m_use_csr_in_linear_system ? applyDirichletByPenalty<true, true>(penalty, rhs_vect, u_dirichlet_arr, u)
: applyDirichletByPenalty<true, false>(penalty, rhs_vect, u_dirichlet_arr, u);
}
else
ARCANE_THROW(NotImplementedException, "BSRFormat(assembleLinearOperator): Method not supported!");
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

BSRMatrix<NB_DOF>& matrix() { return m_bsr_matrix; };
void resetMatrixValues() { m_bsr_matrix.values().fill(0, m_queue); };
void dumpMatrix(std::string filename) const { m_bsr_matrix.dump(filename); };

private:

Expand All @@ -810,6 +946,7 @@ class BSRFormat : public TraceAccessor
RunQueue& m_queue;
const FemDoFsOnNodes& m_dofs_on_nodes;
};

}; // namespace Arcane::FemUtils

#endif // ! BSRFORMAT_H

0 comments on commit 2355ba9

Please sign in to comment.