Skip to content

Commit

Permalink
refactor: use Eigen vector type to store basic vars
Browse files Browse the repository at this point in the history
  • Loading branch information
mimizh2418 committed Oct 20, 2024
1 parent f34c985 commit 1307037
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 20 deletions.
3 changes: 0 additions & 3 deletions src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#include <iostream>

#include "suboptimal/solvers/linear/SimplexPivotRule.h"
#include "suboptimal/solvers/linear/SimplexSolverConfig.h"
#include "suboptimal/solvers/linear/simplex.h"

// using namespace std;
using namespace Eigen;

void solveBasicProblem() {
Expand Down
32 changes: 15 additions & 17 deletions src/solvers/linear/simplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

#include <Eigen/Core>
#include <algorithm>
#include <gsl/narrow>
#include <iostream>
#include <limits>
#include <span>
#include <stdexcept>
#include <vector>

Expand All @@ -19,7 +17,7 @@ using namespace suboptimal;
using namespace Eigen;
using namespace std;

int findPivotPosition(const MatrixXd& tableau, const span<Index> basic_vars, const SimplexPivotRule pivot_rule,
int findPivotPosition(const MatrixXd& tableau, const VectorX<Index>& basic_vars, const SimplexPivotRule pivot_rule,
Index& pivot_row, Index& pivot_col) {
pivot_col = -1;
pivot_row = -1;
Expand Down Expand Up @@ -63,7 +61,7 @@ int findPivotPosition(const MatrixXd& tableau, const span<Index> basic_vars, con
}
} else if (ratio == min_ratio && pivot_rule == SimplexPivotRule::kBland) {
// Bland's rule: select the basic variable with the smallest index
if (basic_vars[i] < basic_vars[pivot_row]) pivot_row = i;
if (basic_vars(i) < basic_vars(pivot_row)) pivot_row = i;
}
}
}
Expand All @@ -75,7 +73,7 @@ int findPivotPosition(const MatrixXd& tableau, const span<Index> basic_vars, con
return 0;
}

void pivot(MatrixXd& tableau, const span<Index> basic_vars, const Index pivot_row, const Index pivot_col) {
void pivot(MatrixXd& tableau, VectorX<Index>& basic_vars, const Index pivot_row, const Index pivot_col) {
const double pivot_element = tableau(pivot_row, pivot_col);
tableau.row(pivot_row) /= pivot_element;
tableau(pivot_row, pivot_col) = 1; // Account for floating point errors
Expand All @@ -86,20 +84,20 @@ void pivot(MatrixXd& tableau, const span<Index> basic_vars, const Index pivot_ro
}

// Update basic variables
basic_vars[pivot_row] = pivot_col;
basic_vars(pivot_row) = pivot_col;
}

vector<Index> findBasicVars(const MatrixXd& tableau) {
vector<Index> basic_vars(tableau.rows() - 1);
VectorX<Index> findBasicVars(const MatrixXd& tableau) {
VectorX<Index> basic_vars = VectorX<Index>::Zero(tableau.rows() - 1);
for (Index i = 0; i < tableau.cols(); i++) {
const auto col = tableau.col(i);
Index max_index;
if (col.lpNorm<1>() == 1 && col.maxCoeff(&max_index) == 1) basic_vars[max_index] = i;
if (col.lpNorm<1>() == 1 && col.maxCoeff(&max_index) == 1) basic_vars(max_index) = i;
}
return basic_vars;
}

SolverExitStatus solveTableau(MatrixXd& tableau, const span<Index> basic_vars, SolverProfiler& profiler,
SolverExitStatus solveTableau(MatrixXd& tableau, VectorX<Index>& basic_vars, SolverProfiler& profiler,
const SimplexSolverConfig& config) {
int num_iterations = 0;
auto exit_status = SolverExitStatus::kSuccess;
Expand Down Expand Up @@ -181,7 +179,7 @@ SolverExitStatus solveSimplex(const LinearProblem& problem, VectorXd& solution,
tableau.row(tableau.rows() - 1) = auxiliary_objective;

// Find basic variables
vector<Index> basic_vars = findBasicVars(tableau);
VectorX<Index> basic_vars = findBasicVars(tableau);

// Perform simplex iterations to find initial BFS
SolverProfiler aux_profiler{};
Expand All @@ -206,13 +204,13 @@ SolverExitStatus solveSimplex(const LinearProblem& problem, VectorXd& solution,

auto objective_row = tableau.row(tableau.rows() - 1);
objective_row.head(problem.numDecisionVars()) = -problem.getObjectiveCoeffs().transpose();
for (size_t i = 0; i < basic_vars.size(); i++) {
objective_row -= tableau.row(gsl::narrow<Index>(i)) * objective_row(basic_vars[i]);
for (Index i = 0; i < basic_vars.size(); i++) {
objective_row -= tableau.row(i) * objective_row(basic_vars(i));
}
}

// Initialize basic variables
vector<Index> basic_vars = findBasicVars(tableau);
VectorX<Index> basic_vars = findBasicVars(tableau);

// Perform simplex iterations
SolverProfiler profiler{};
Expand All @@ -230,9 +228,9 @@ SolverExitStatus solveSimplex(const LinearProblem& problem, VectorXd& solution,
// Extract solution and objective value
solution = VectorXd::Zero(problem.numDecisionVars());
const VectorXd rhs = tableau.col(tableau.cols() - 1);
for (size_t i = 0; i < basic_vars.size(); i++) {
if (const Index var_index = basic_vars[i]; var_index < problem.numDecisionVars()) {
solution(var_index) = rhs(gsl::narrow<Index>(i));
for (Index i = 0; i < basic_vars.size(); i++) {
if (const Index var_index = basic_vars(i); var_index < problem.numDecisionVars()) {
solution(var_index) = rhs(i);
}
}
objective_value = tableau(tableau.rows() - 1, tableau.cols() - 1);
Expand Down

0 comments on commit 1307037

Please sign in to comment.