diff --git a/src/main.cpp b/src/main.cpp index 9016a1b..db484a3 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,10 +1,7 @@ -#include - #include "suboptimal/solvers/linear/SimplexPivotRule.h" #include "suboptimal/solvers/linear/SimplexSolverConfig.h" #include "suboptimal/solvers/linear/simplex.h" -// using namespace std; using namespace Eigen; void solveBasicProblem() { diff --git a/src/solvers/linear/simplex.cpp b/src/solvers/linear/simplex.cpp index ffd3377..c4b0141 100644 --- a/src/solvers/linear/simplex.cpp +++ b/src/solvers/linear/simplex.cpp @@ -2,10 +2,8 @@ #include #include -#include #include #include -#include #include #include @@ -19,7 +17,7 @@ using namespace suboptimal; using namespace Eigen; using namespace std; -int findPivotPosition(const MatrixXd& tableau, const span basic_vars, const SimplexPivotRule pivot_rule, +int findPivotPosition(const MatrixXd& tableau, const VectorX& basic_vars, const SimplexPivotRule pivot_rule, Index& pivot_row, Index& pivot_col) { pivot_col = -1; pivot_row = -1; @@ -63,7 +61,7 @@ int findPivotPosition(const MatrixXd& tableau, const span basic_vars, con } } else if (ratio == min_ratio && pivot_rule == SimplexPivotRule::kBland) { // Bland's rule: select the basic variable with the smallest index - if (basic_vars[i] < basic_vars[pivot_row]) pivot_row = i; + if (basic_vars(i) < basic_vars(pivot_row)) pivot_row = i; } } } @@ -75,7 +73,7 @@ int findPivotPosition(const MatrixXd& tableau, const span basic_vars, con return 0; } -void pivot(MatrixXd& tableau, const span basic_vars, const Index pivot_row, const Index pivot_col) { +void pivot(MatrixXd& tableau, VectorX& basic_vars, const Index pivot_row, const Index pivot_col) { const double pivot_element = tableau(pivot_row, pivot_col); tableau.row(pivot_row) /= pivot_element; tableau(pivot_row, pivot_col) = 1; // Account for floating point errors @@ -86,20 +84,20 @@ void pivot(MatrixXd& tableau, const span basic_vars, const Index pivot_ro } // Update basic variables - basic_vars[pivot_row] = pivot_col; + basic_vars(pivot_row) = pivot_col; } -vector findBasicVars(const MatrixXd& tableau) { - vector basic_vars(tableau.rows() - 1); +VectorX findBasicVars(const MatrixXd& tableau) { + VectorX basic_vars = VectorX::Zero(tableau.rows() - 1); for (Index i = 0; i < tableau.cols(); i++) { const auto col = tableau.col(i); Index max_index; - if (col.lpNorm<1>() == 1 && col.maxCoeff(&max_index) == 1) basic_vars[max_index] = i; + if (col.lpNorm<1>() == 1 && col.maxCoeff(&max_index) == 1) basic_vars(max_index) = i; } return basic_vars; } -SolverExitStatus solveTableau(MatrixXd& tableau, const span basic_vars, SolverProfiler& profiler, +SolverExitStatus solveTableau(MatrixXd& tableau, VectorX& basic_vars, SolverProfiler& profiler, const SimplexSolverConfig& config) { int num_iterations = 0; auto exit_status = SolverExitStatus::kSuccess; @@ -181,7 +179,7 @@ SolverExitStatus solveSimplex(const LinearProblem& problem, VectorXd& solution, tableau.row(tableau.rows() - 1) = auxiliary_objective; // Find basic variables - vector basic_vars = findBasicVars(tableau); + VectorX basic_vars = findBasicVars(tableau); // Perform simplex iterations to find initial BFS SolverProfiler aux_profiler{}; @@ -206,13 +204,13 @@ SolverExitStatus solveSimplex(const LinearProblem& problem, VectorXd& solution, auto objective_row = tableau.row(tableau.rows() - 1); objective_row.head(problem.numDecisionVars()) = -problem.getObjectiveCoeffs().transpose(); - for (size_t i = 0; i < basic_vars.size(); i++) { - objective_row -= tableau.row(gsl::narrow(i)) * objective_row(basic_vars[i]); + for (Index i = 0; i < basic_vars.size(); i++) { + objective_row -= tableau.row(i) * objective_row(basic_vars(i)); } } // Initialize basic variables - vector basic_vars = findBasicVars(tableau); + VectorX basic_vars = findBasicVars(tableau); // Perform simplex iterations SolverProfiler profiler{}; @@ -230,9 +228,9 @@ SolverExitStatus solveSimplex(const LinearProblem& problem, VectorXd& solution, // Extract solution and objective value solution = VectorXd::Zero(problem.numDecisionVars()); const VectorXd rhs = tableau.col(tableau.cols() - 1); - for (size_t i = 0; i < basic_vars.size(); i++) { - if (const Index var_index = basic_vars[i]; var_index < problem.numDecisionVars()) { - solution(var_index) = rhs(gsl::narrow(i)); + for (Index i = 0; i < basic_vars.size(); i++) { + if (const Index var_index = basic_vars(i); var_index < problem.numDecisionVars()) { + solution(var_index) = rhs(i); } } objective_value = tableau(tableau.rows() - 1, tableau.cols() - 1);