Skip to content

Commit

Permalink
[Trifocal+P2Pt] custom lu solid
Browse files Browse the repository at this point in the history
  • Loading branch information
rfabbri committed Oct 24, 2023
1 parent a7e7e31 commit e060188
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 87 deletions.
2 changes: 1 addition & 1 deletion minus/minus.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class minus_core { // fully static, not to be instantiated - just used for templ
static constexpr unsigned NVE2 = f::nve*f::nve;
static void evaluate_Hxt(const C<F> * __restrict x /*x, t*/, const C<F> * __restrict params, C<F> * __restrict y /*HxH*/);
static void evaluate_HxH(const C<F> * __restrict x /*x and t*/, const C<F> * __restrict params, C<F> * __restrict y /*HxH*/);
static void lsolve(Eigen::Map<const Eigen::Matrix<C<F>, f::nve, f::nve>,Eigen::Aligned> &matrix, Eigen::Map<const Eigen::Matrix<C<F>, f::nve, 1>, Eigen::Aligned > &b, Eigen::Map<Eigen::Matrix<C<F>, f::nve, 1>,Eigen::Aligned> &x);
static void lsolve(Eigen::Map<Eigen::Matrix<C<F>, f::nve, f::nve>,Eigen::Aligned> &matrix, Eigen::Map<const Eigen::Matrix<C<F>, f::nve, 1>, Eigen::Aligned > &b, Eigen::Map<Eigen::Matrix<C<F>, f::nve, 1>,Eigen::Aligned> &x);
};

// TODO: make these static
Expand Down
159 changes: 73 additions & 86 deletions minus/minus.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -16,97 +16,84 @@ namespace MiNuS {

// Makes Eigen names (Map, Matrix, Aligned, ...) usable unqualified in the
// definitions that follow.
using namespace Eigen; // only used for linear solve

// TODO: parameters restrict
// Specific to Chicago
// Solves the f::nve x f::nve linear system (coefficient matrix) * x = b using a
// hand-written LU factorization with row swaps, followed by manually unrolled
// forward and backward substitution.
//
// NOTE(review): this listing is a commit diff shown without +/- markers, so two
// versions of the function are interleaved: two first-parameter lines (`matrix`,
// a const map, and `m`, a mutable map) and two copies of the body. The first
// body copy builds a local `m` from `matrix`; the second copy works directly on
// the `m` parameter, i.e. it overwrites the caller's matrix with the factors.
// Presumably the first copy is the removed code and the second the added code.
//
// Parameters:
//   matrix / m  coefficient matrix. After the elimination loop, `m` holds the
//               multipliers of the unit-lower factor below the diagonal and the
//               upper factor on and above the diagonal.
//   b           right-hand side; only read.
//   x           output vector; receives the permuted right-hand side first and
//               the solution at the end.
//
// NOTE(review): both substitution phases are written out for indices 0..13, so
// the function presumably requires f::nve == 14 -- confirm against the problem
// definition.
template <problem P, typename F>
__attribute__((always_inline)) inline void
minus_core<P, F>::
lsolve(
Map<const Matrix<C<F>, f::nve, f::nve>,Aligned> & __restrict matrix,
Map<Matrix<C<F>, f::nve, f::nve>,Aligned> & __restrict m,
Map<const Matrix<C<F>, f::nve, 1>, Aligned > & __restrict b,
Map<Matrix<C<F>, f::nve, 1>,Aligned> & __restrict x)
{
// ---- first body copy: factorizes a local copy of `matrix` ----
typedef Matrix<C<F>, f::nve, f::nve> MatrixType;
typedef PermutationMatrix<f::nve, f::nve> PermutationType;
typedef Transpositions<f::nve, f::nve> TranspositionType;

MatrixType m(matrix); // matrix holding LU together TODO: in-place
PermutationType m_p;

TranspositionType m_rowsTranspositions;
// XXX modified by Fabbri to suit Chicago problem
// static __attribute__((always_inline)) void unblocked_lu(
// MatrixType &m,
// Raw pointer into the transposition storage, filled one entry per column below.
typename TranspositionType::StorageIndex* row_transpositions = &m_rowsTranspositions.coeffRef(0);
static constexpr Index rows = f::nve;
//Index first_zero_pivot = -1;
// Elimination loop, one column per iteration; the bound is hard-coded to 14 here.
for(Index k = 0; k < 14; ++k) {
// Number of rows (and columns) remaining below/right of the pivot.
Index rrows = rows-k-1;

// Restricted pivot search: walks column k from the last row upward and stops
// at the FIRST entry whose std::norm exceeds 1000 times that of the current
// diagonal entry. If no such entry exists the diagonal entry is kept as pivot.
// NOTE(review): if C<F> is std::complex<F>, std::norm is the squared
// magnitude, so the 1000 factor applies to squared magnitudes -- confirm.
Index row_of_biggest_in_col(k);
F biggest_in_corner = std::norm(m(k,k));// std::norm(m.coeff(k,k));
for (unsigned j=rows-1; j != k; --j) {
F tmp;
if ((tmp = std::norm(m(j,k))) > biggest_in_corner*1000) {
biggest_in_corner = tmp;
row_of_biggest_in_col = j;
break;
}
}

// Record which row was chosen for column k (k itself when no swap is needed).
row_transpositions[k] = typename TranspositionType::StorageIndex(row_of_biggest_in_col);

//if (biggest_in_corner != Score(0)) {
if (k != row_of_biggest_in_col) {
m.row(k).swap(m.row(row_of_biggest_in_col));
}

// Turn the sub-diagonal part of column k into multipliers. The zero-pivot
// check is commented out, so a zero pivot is divided by as-is.
m.col(k).tail(rrows) /= m(k,k);
// } else if (first_zero_pivot==-1)
// the pivot is exactly zero, we record the index of the first pivot which is exactly 0,
// and continue the factorization such we still have A = PLU
// first_zero_pivot = k;

// Rank-1 update of the trailing rrows x rrows block; skipped for the last column.
if (k < rows-1)
m.bottomRightCorner(rrows,rrows).noalias() -= m.col(k).tail(rrows) * m.row(k).tail(rrows);
}

// Convert the recorded row transpositions into a permutation.
m_p = m_rowsTranspositions;

// Step 1
// Apply the row permutation to b, then forward-substitute with the unit-lower
// factor (no division because its diagonal is implicitly 1).
x = m_p * b;

// TODO: use block indexing and std::vector-std::vector multiplication
x(1) -= m(1,0)*x(0);
x(2) -= m(2,0)*x(0)+m(2,1)*x(1);
x(3) -= m(3,0)*x(0)+m(3,1)*x(1)+m(3,2)*x(2);
x(4) -= m(4,0)*x(0)+m(4,1)*x(1)+m(4,2)*x(2)+m(4,3)*x(3);
x(5) -= m(5,0)*x(0)+m(5,1)*x(1)+m(5,2)*x(2)+m(5,3)*x(3)+m(5,4)*x(4);
x(6) -= m(6,0)*x(0)+m(6,1)*x(1)+m(6,2)*x(2)+m(6,3)*x(3)+m(6,4)*x(4)+m(6,5)*x(5);
x(7) -= m(7,0)*x(0)+m(7,1)*x(1)+m(7,2)*x(2)+m(7,3)*x(3)+m(7,4)*x(4)+m(7,5)*x(5)+m(7,6)*x(6);
x(8) -= m(8,0)*x(0)+m(8,1)*x(1)+m(8,2)*x(2)+m(8,3)*x(3)+m(8,4)*x(4)+m(8,5)*x(5)+m(8,6)*x(6)+m(8,7)*x(7);
x(9) -= m(9,0)*x(0)+m(9,1)*x(1)+m(9,2)*x(2)+m(9,3)*x(3)+m(9,4)*x(4)+m(9,5)*x(5)+m(9,6)*x(6)+m(9,7)*x(7)+m(9,8)*x(8);
x(10) -= m(10,0)*x(0)+m(10,1)*x(1)+m(10,2)*x(2)+m(10,3)*x(3)+m(10,4)*x(4)+m(10,5)*x(5)+m(10,6)*x(6)+m(10,7)*x(7)+m(10,8)*x(8)+m(10,9)*x(9);
x(11) -= m(11,0)*x(0)+m(11,1)*x(1)+m(11,2)*x(2)+m(11,3)*x(3)+m(11,4)*x(4)+m(11,5)*x(5)+m(11,6)*x(6)+m(11,7)*x(7)+m(11,8)*x(8)+m(11,9)*x(9)+m(11,10)*x(10);
x(12) -= m(12,0)*x(0)+m(12,1)*x(1)+m(12,2)*x(2)+m(12,3)*x(3)+m(12,4)*x(4)+m(12,5)*x(5)+m(12,6)*x(6)+m(12,7)*x(7)+m(12,8)*x(8)+m(12,9)*x(9)+m(12,10)*x(10)+m(12,11)*x(11);
x(13) -= (m(13,0)*x(0)+m(13,1)*x(1)+m(13,2)*x(2)+m(13,3)*x(3)+m(13,4)*x(4)+m(13,5)*x(5)+m(13,6)*x(6)+m(13,7)*x(7)+m(13,8)*x(8)+m(13,9)*x(9)+m(13,10)*x(10)+m(13,11)*x(11)+m(13,12)*x(12));

// Step 2
// Back substitution with the upper factor, from the last unknown to the first;
// each row subtracts the already-solved unknowns and divides by its diagonal.
//m.template triangularView<UnitLower>().solveInPlace(x);

x(13) /= m(13,13);
x(12) -= m(12,13)*x(13); x(12) /= m(12,12);
x(11) -= (m(11,12)*x(12)+m(11,13)*x(13)); x(11) /= m(11,11);
x(10) -= (m(10,11)*x(11)+m(10,12)*x(12)+m(10,13)*x(13)); x(10) /= m(10,10);
x(9) -= (m(9,10)*x(10)+m(9,11)*x(11)+m(9,12)*x(12)+m(9,13)*x(13)); x(9) /= m(9,9);
x(8) -= (m(8,9)*x(9)+m(8,10)*x(10)+m(8,11)*x(11)+m(8,12)*x(12)+m(8,13)*x(13)); x(8) /= m(8,8);
x(7) -= (m(7,8)*x(8)+m(7,9)*x(9)+m(7,10)*x(10)+m(7,11)*x(11)+m(7,12)*x(12)+m(7,13)*x(13)); x(7) /= m(7,7);
x(6) -= (m(6,7)*x(7)+m(6,8)*x(8)+m(6,9)*x(9)+m(6,10)*x(10)+m(6,11)*x(11)+m(6,12)*x(12)+m(6,13)*x(13)); x(6) /= m(6,6);
x(5) -= (m(5,6)*x(6)+m(5,7)*x(7)+m(5,8)*x(8)+m(5,9)*x(9)+m(5,10)*x(10)+m(5,11)*x(11)+m(5,12)*x(12)+m(5,13)*x(13)); x(5) /= m(5,5);
x(4) -= (m(4,5)*x(5)+m(4,6)*x(6)+m(4,7)*x(7)+m(4,8)*x(8)+m(4,9)*x(9)+m(4,10)*x(10)+m(4,11)*x(11)+m(4,12)*x(12)+m(4,13)*x(13)); x(4) /= m(4,4);
x(3) -= (m(3,4)*x(4)+m(3,5)*x(5)+m(3,6)*x(6)+m(3,7)*x(7)+m(3,8)*x(8)+m(3,9)*x(9)+m(3,10)*x(10)+m(3,11)*x(11)+m(3,12)*x(12)+m(3,13)*x(13)); x(3) /= m(3,3);
x(2) -= (m(2,3)*x(3)+m(2,4)*x(4)+m(2,5)*x(5)+m(2,6)*x(6)+m(2,7)*x(7)+m(2,8)*x(8)+m(2,9)*x(9)+m(2,10)*x(10)+m(2,11)*x(11)+m(2,12)*x(12)+m(2,13)*x(13)); x(2) /= m(2,2);
x(1) -= (m(1,2)*x(2)+m(1,3)*x(3)+m(1,4)*x(4)+m(1,5)*x(5)+m(1,6)*x(6)+m(1,7)*x(7)+m(1,8)*x(8)+m(1,9)*x(9)+m(1,10)*x(10)+m(1,11)*x(11)+m(1,12)*x(12)+m(1,13)*x(13)); x(1) /= m(1,1);
x(0) -= (m(0,1)*x(1)+m(0,2)*x(2)+m(0,3)*x(3)+m(0,4)*x(4)+m(0,5)*x(5)+m(0,6)*x(6)+m(0,7)*x(7)+m(0,8)*x(8)+m(0,9)*x(9)+m(0,10)*x(10)+m(0,11)*x(11)+m(0,12)*x(12)+m(0,13)*x(13)); x(0) /= m(0,0);
// ---- second body copy: same algorithm, but operating on the `m` parameter
// itself (no local copy) and with the loop bound written as f::nve ----
typedef Matrix<C<F>, f::nve, f::nve> MatrixType;
typedef PermutationMatrix<f::nve, f::nve> PermutationType;
typedef Transpositions<f::nve, f::nve> TranspositionType;
PermutationType m_p;
TranspositionType m_rowsTranspositions;
// Raw pointer into the transposition storage, filled one entry per column below.
typename TranspositionType::StorageIndex* row_transpositions = &m_rowsTranspositions.coeffRef(0);
static constexpr Index rows = f::nve;
// Elimination loop, one column per iteration.
for(Index k = 0; k < f::nve; ++k) {
// Number of rows (and columns) remaining below/right of the pivot.
Index rrows = rows-k-1;

// Restricted pivot search: scans column k from the last row upward and takes
// the first entry whose std::norm exceeds 1000 times that of the diagonal
// entry; otherwise the diagonal entry stays the pivot.
Index row_of_biggest_in_col(k);
F biggest_in_corner = std::norm(m(k,k));
for (unsigned j=rows-1; j != k; --j) {
F tmp;
if ((tmp = std::norm(m(j,k))) > biggest_in_corner*1000) {
biggest_in_corner = tmp;
row_of_biggest_in_col = j;
break;
}
}

// Record which row was chosen for column k (k itself when no swap is needed).
row_transpositions[k] = typename TranspositionType::StorageIndex(row_of_biggest_in_col);

if (k != row_of_biggest_in_col)
m.row(k).swap(m.row(row_of_biggest_in_col));

// Turn the sub-diagonal part of column k into multipliers; there is no guard
// against a zero pivot.
m.col(k).tail(rrows) /= m(k,k);

// Rank-1 update of the trailing rrows x rrows block; skipped for the last column.
if (k < rows-1)
m.bottomRightCorner(rrows,rrows).noalias() -= m.col(k).tail(rrows) * m.row(k).tail(rrows);
}

// Convert the recorded row transpositions into a permutation.
m_p = m_rowsTranspositions;

// Step 1
// Apply the row permutation to b, then forward-substitute with the unit-lower
// factor (no division because its diagonal is implicitly 1).
x = m_p * b;

// TODO: use block indexing and std::vector-std::vector multiplication
x(1) -= m(1,0)*x(0);
x(2) -= m(2,0)*x(0)+m(2,1)*x(1);
x(3) -= m(3,0)*x(0)+m(3,1)*x(1)+m(3,2)*x(2);
x(4) -= m(4,0)*x(0)+m(4,1)*x(1)+m(4,2)*x(2)+m(4,3)*x(3);
x(5) -= m(5,0)*x(0)+m(5,1)*x(1)+m(5,2)*x(2)+m(5,3)*x(3)+m(5,4)*x(4);
x(6) -= m(6,0)*x(0)+m(6,1)*x(1)+m(6,2)*x(2)+m(6,3)*x(3)+m(6,4)*x(4)+m(6,5)*x(5);
x(7) -= m(7,0)*x(0)+m(7,1)*x(1)+m(7,2)*x(2)+m(7,3)*x(3)+m(7,4)*x(4)+m(7,5)*x(5)+m(7,6)*x(6);
x(8) -= m(8,0)*x(0)+m(8,1)*x(1)+m(8,2)*x(2)+m(8,3)*x(3)+m(8,4)*x(4)+m(8,5)*x(5)+m(8,6)*x(6)+m(8,7)*x(7);
x(9) -= m(9,0)*x(0)+m(9,1)*x(1)+m(9,2)*x(2)+m(9,3)*x(3)+m(9,4)*x(4)+m(9,5)*x(5)+m(9,6)*x(6)+m(9,7)*x(7)+m(9,8)*x(8);
x(10) -= m(10,0)*x(0)+m(10,1)*x(1)+m(10,2)*x(2)+m(10,3)*x(3)+m(10,4)*x(4)+m(10,5)*x(5)+m(10,6)*x(6)+m(10,7)*x(7)+m(10,8)*x(8)+m(10,9)*x(9);
x(11) -= m(11,0)*x(0)+m(11,1)*x(1)+m(11,2)*x(2)+m(11,3)*x(3)+m(11,4)*x(4)+m(11,5)*x(5)+m(11,6)*x(6)+m(11,7)*x(7)+m(11,8)*x(8)+m(11,9)*x(9)+m(11,10)*x(10);
x(12) -= m(12,0)*x(0)+m(12,1)*x(1)+m(12,2)*x(2)+m(12,3)*x(3)+m(12,4)*x(4)+m(12,5)*x(5)+m(12,6)*x(6)+m(12,7)*x(7)+m(12,8)*x(8)+m(12,9)*x(9)+m(12,10)*x(10)+m(12,11)*x(11);
x(13) -= (m(13,0)*x(0)+m(13,1)*x(1)+m(13,2)*x(2)+m(13,3)*x(3)+m(13,4)*x(4)+m(13,5)*x(5)+m(13,6)*x(6)+m(13,7)*x(7)+m(13,8)*x(8)+m(13,9)*x(9)+m(13,10)*x(10)+m(13,11)*x(11)+m(13,12)*x(12));

// Step 2
// Back substitution with the upper factor, from the last unknown to the first;
// each row subtracts the already-solved unknowns and divides by its diagonal.
//m.template triangularView<UnitLower>().solveInPlace(x);

x(13) /= m(13,13);
x(12) -= m(12,13)*x(13); x(12) /= m(12,12);
x(11) -= (m(11,12)*x(12)+m(11,13)*x(13)); x(11) /= m(11,11);
x(10) -= (m(10,11)*x(11)+m(10,12)*x(12)+m(10,13)*x(13)); x(10) /= m(10,10);
x(9) -= (m(9,10)*x(10)+m(9,11)*x(11)+m(9,12)*x(12)+m(9,13)*x(13)); x(9) /= m(9,9);
x(8) -= (m(8,9)*x(9)+m(8,10)*x(10)+m(8,11)*x(11)+m(8,12)*x(12)+m(8,13)*x(13)); x(8) /= m(8,8);
x(7) -= (m(7,8)*x(8)+m(7,9)*x(9)+m(7,10)*x(10)+m(7,11)*x(11)+m(7,12)*x(12)+m(7,13)*x(13)); x(7) /= m(7,7);
x(6) -= (m(6,7)*x(7)+m(6,8)*x(8)+m(6,9)*x(9)+m(6,10)*x(10)+m(6,11)*x(11)+m(6,12)*x(12)+m(6,13)*x(13)); x(6) /= m(6,6);
x(5) -= (m(5,6)*x(6)+m(5,7)*x(7)+m(5,8)*x(8)+m(5,9)*x(9)+m(5,10)*x(10)+m(5,11)*x(11)+m(5,12)*x(12)+m(5,13)*x(13)); x(5) /= m(5,5);
x(4) -= (m(4,5)*x(5)+m(4,6)*x(6)+m(4,7)*x(7)+m(4,8)*x(8)+m(4,9)*x(9)+m(4,10)*x(10)+m(4,11)*x(11)+m(4,12)*x(12)+m(4,13)*x(13)); x(4) /= m(4,4);
x(3) -= (m(3,4)*x(4)+m(3,5)*x(5)+m(3,6)*x(6)+m(3,7)*x(7)+m(3,8)*x(8)+m(3,9)*x(9)+m(3,10)*x(10)+m(3,11)*x(11)+m(3,12)*x(12)+m(3,13)*x(13)); x(3) /= m(3,3);
x(2) -= (m(2,3)*x(3)+m(2,4)*x(4)+m(2,5)*x(5)+m(2,6)*x(6)+m(2,7)*x(7)+m(2,8)*x(8)+m(2,9)*x(9)+m(2,10)*x(10)+m(2,11)*x(11)+m(2,12)*x(12)+m(2,13)*x(13)); x(2) /= m(2,2);
x(1) -= (m(1,2)*x(2)+m(1,3)*x(3)+m(1,4)*x(4)+m(1,5)*x(5)+m(1,6)*x(6)+m(1,7)*x(7)+m(1,8)*x(8)+m(1,9)*x(9)+m(1,10)*x(10)+m(1,11)*x(11)+m(1,12)*x(12)+m(1,13)*x(13)); x(1) /= m(1,1);
x(0) -= (m(0,1)*x(1)+m(0,2)*x(2)+m(0,3)*x(3)+m(0,4)*x(4)+m(0,5)*x(5)+m(0,6)*x(6)+m(0,7)*x(7)+m(0,8)*x(8)+m(0,9)*x(9)+m(0,10)*x(10)+m(0,11)*x(11)+m(0,12)*x(12)+m(0,13)*x(13)); x(0) /= m(0,0);
}


Expand Down Expand Up @@ -139,7 +126,7 @@ track(const track_settings &s, const C<F> s_sols[f::nve*f::nsols], const C<F> pa
Map<Matrix<C<F>, f::nve, 1>,Aligned> dxi_eigen(dxi);
Map<Matrix<C<F>, f::nve, 1>,Aligned> dx4_eigen(dx4);
Map<Matrix<C<F>, f::nve, 1>,Aligned> &dx_eigen = dx4_eigen;
Map<const Matrix<C<F>, f::nve, f::nve>,Aligned> AA((C<F> *)Hxt,f::nve,f::nve); // accessors for the data
Map<Matrix<C<F>, f::nve, f::nve>,Aligned> AA((C<F> *)Hxt,f::nve,f::nve); // accessors for the data
Map<const Matrix<C<F>, f::nve, 1>, Aligned > bb(RHS);
static constexpr F the_smallest_number = 1e-13; // XXX BENCHMARK THIS
typedef minus_array<f::nve,F> v; typedef minus_array<NVEPLUS1,F> vp;
Expand Down Expand Up @@ -225,7 +212,7 @@ track(const track_settings &s, const C<F> s_sols[f::nve*f::nsols], const C<F> pa
do {
++n_corr_steps;
evaluate_HxH(x1t1, params, HxH);
lsolve(AA, bb, dx_eigen); // TODO: always same AA, do not redo LU
lsolve(AA, bb, dx_eigen);
v::add_to_self(x1t1, dx);
is_successful = v::norm2(dx) < s.epsilon2_ * v::norm2(x1t1); // |dx|^2/|x1|^2 < eps2
} while (!is_successful && n_corr_steps < s.max_corr_steps_);
Expand Down

0 comments on commit e060188

Please sign in to comment.