Skip to content

Commit

Permalink
Reorder listMatrix; Some renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
GollokG committed Dec 23, 2024
1 parent b8ba2bb commit 60f7b67
Showing 1 changed file with 100 additions and 104 deletions.
204 changes: 100 additions & 104 deletions src/cdomains/affineEquality/sparseImplementation/listMatrix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,42 @@ module ListMatrix: AbstractMatrix =
let show x =
List.fold_left (^) "" (List.map (fun x -> (V.show x)) x)

let copy m = m
(* Lists are immutable, so this should suffice? A.t is mutuable currently, but is treated like its not in ArrayMatrix*)

let equal m1 m2 = Timing.wrap "equal" (equal m1) m2

let copy m =
Timing.wrap "copy" (copy) m

let empty () = []

let is_empty = List.is_empty

let num_rows = List.length

let is_empty m =
num_rows m = 0
(*This should be different if the implimentation is sound*)
(*m.column_count = 0*)
let num_cols m =
if m = [] then 0 else V.length (hd m)

let init_with_vec v =
[v]

let num_cols m = if m = [] then 0 else V.length (hd m)
let append_row m row =
m @ [row]

let copy m = m
(* Lists are immutable, so this should suffice? A.t is mutuable currently, but is treated like its not in ArrayMatrix*)
let get_row = List.nth

let copy m =
Timing.wrap "copy" (copy) m
let remove_row m n =
List.remove_at n m

let remove_zero_rows m =
List.filter (fun row -> not (V.is_zero_vec row)) m

let swap_rows m j k =
List.mapi (fun i row -> if i = j then List.nth m k else if i = k then List.nth m j else row) m

(* This only works if Array.modifyi has been removed from dim_add *)
let add_empty_columns (m : t) (cols : int array) : t =
let add_empty_columns m cols =
let cols = Array.to_list cols in
let sorted_cols = List.sort Stdlib.compare cols in
let rec count_sorted_occ acc cols last count =
Expand All @@ -55,40 +72,86 @@ module ListMatrix: AbstractMatrix =
let add_empty_columns m cols =
Timing.wrap "add_empty_cols" (add_empty_columns m) cols

let append_row m row =
m @ [row]

let get_row m n =
List.nth m n

let remove_row m n =
List.remove_at n m

let get_col m n =
(*let () = Printf.printf "get_col %i of m:\n%s\n%s\n" n (show m) (V.show (V.of_list @@ List.map (fun row -> V.nth row n) m)) in*)
V.of_list @@ List.map (fun row -> V.nth row n) m (* builds full col including zeros, maybe use sparselist instead? *)

let get_col m n =
Timing.wrap "get_col" (get_col m) n

let set_col m new_col n =
let set_col m new_col n = (* TODO: Optimize! AND CURRENTLY WRONG SEMANTICS IF VECTOR LENGTH <> NUM_ROWS! *)
(* List.mapi (fun row_idx row -> V.set_nth row n (V.nth new_col row_idx)) m *)
List.map2 (fun row value -> V.set_nth row n value) m (V.to_list new_col)

let del_col m j =
if num_cols m = 1 then empty ()
else
List.map (fun row -> V.remove_nth row j) m

let del_cols m cols =
let cols = Array.to_list cols in (* TODO: Is it possible to use list for Apron dimchange? *)
let sorted_cols = List.sort_uniq Stdlib.compare cols in (* Apron Docs: Repetitions are meaningless (and are not correct specification) *)
if (List.length sorted_cols) = num_cols m then empty()
else
List.map (fun row -> V.remove_at_indices row sorted_cols) m

let del_cols m cols = Timing.wrap "del_cols" (del_cols m) cols

let map2 f m v =
let vector_length = V.length v in
List.mapi (fun index row -> if index < vector_length then f row (V.nth v index) else row ) m

let map2i f m v = (* TODO: Optimize! We should probably do it like in map2 *)
let rec map2i_min i acc m v =
match m, v with
| [], _ -> List.rev acc
| row :: rs, [] -> List.rev_append (row :: acc) rs
| row :: rs, value :: vs -> map2i_min (i + 1) (f i row value :: acc) rs vs
in
map2i_min 0 [] m (V.to_list v)

let find_opt = List.find_opt

let append_matrices m1 m2 = (* keeps dimensions of first matrix, what if dimensions differ?*)
m1 @ m2

let equal m1 m2 = Timing.wrap "equal" (equal m1) m2

let div_row (row : V.t) (pivot : A.t) : V.t =
V.map_f_preserves_zero (fun a -> a /: pivot) row

let swap_rows m j k =
List.mapi (fun i row -> if i = j then List.nth m k else if i = k then List.nth m j else row) m
V.map_f_preserves_zero (fun a -> a /: pivot) row (* TODO: This is a case for apply_with_c *)

let sub_scaled_row row1 row2 s =
V.map2_f_preserves_zero (fun x y -> x -: (s *: y)) row1 row2

(* This function return a tuple of row index and pivot position (column) in m *)
(* TODO: maybe we could use a Hashmap instead of a list? *)
let get_pivot_positions m : (int * int) list =
List.rev @@ List.fold_lefti (
fun acc i row -> match V.find_first_non_zero row with
| None -> acc
| Some (pivot_col, _) -> (i, pivot_col) :: acc
) [] m

let assert_rref m =
let pivot_l = get_pivot_positions m in
let rec validate m i =
match m with
| [] -> ()
| v::vs when (V.is_zero_vec v) ->
if List.exists (fun v -> not @@ V.is_zero_vec v) vs
then raise (Invalid_argument "Matrix not in rref: zero row!")
else ()
| v::vs ->
let rec validate_vec pl =
match pivot_l with
| [] -> true
| (pr, pc)::ps ->
let target = if pr <> i then A.zero else A.one in
if V.nth v pc <>: target then false else validate_vec ps
in if validate_vec pivot_l then validate vs (i+1) else raise (Invalid_argument "Matrix not in rref: pivot column not empty!")
in validate m 0

(* TODO: Remove this! Just to suppress warning *)
let () = assert_rref (empty ())

Check warning

Code scanning / Semgrep OSS

Semgrep Finding: semgrep.list-length-compare-n Warning

computing list length is inefficient for length comparison, use compare_length_with instead
(* Reduces the jth column with the last row that has a non-zero element in this column. *)
let reduce_col m j =
if is_empty m then m
Expand All @@ -113,41 +176,12 @@ module ListMatrix: AbstractMatrix =
sub_scaled_row row pivot_row s)
) m

let del_col m j =
if num_cols m = 1 then empty ()
else
List.map (fun row -> V.remove_nth row j) m

let del_cols m cols =
let cols = Array.to_list cols in (* TODO: Is it possible to use list for Apron dimchange? *)
let sorted_cols = List.sort_uniq Stdlib.compare cols in (* Apron Docs: Repetitions are meaningless (and are not correct specification) *)
if (List.length sorted_cols) = num_cols m then empty()
else
List.map (fun row -> V.remove_at_indices row sorted_cols) m

let del_cols m cols = Timing.wrap "del_cols" (del_cols m) cols

let map2i f m v =
let rec map2i_min i acc m v =
match m, v with
| [], _ -> List.rev acc
| row :: rs, [] -> List.rev_append (row :: acc) rs
| row :: rs, value :: vs -> map2i_min (i + 1) (f i row value :: acc) rs vs
in
map2i_min 0 [] m (V.to_list v)

let remove_zero_rows m =
List.filter (fun row -> not (V.is_zero_vec row)) m

let init_with_vec v =
[v]

let normalize m =
let col_count = num_cols m in
let dec_mat_2D (m : t) (row_idx : int) (col_idx : int) : t =
let cut_front_matrix m row_idx col_idx =
List.filteri_map (fun i row -> if i < row_idx then None else Some (V.starting_from_nth col_idx row)) m
in
let dec_mat_2D m row_idx col_idx = Timing.wrap "dec_mat_2D" (dec_mat_2D m row_idx) col_idx in
let cut_front_matrix m row_idx col_idx = Timing.wrap "cut_front_matrix" (cut_front_matrix m row_idx) col_idx in
(* Function for finding first pivot in an extracted part of the matrix (row_idx and col_idx indicate which part of the original matrix) *)
(* The last column represents the constant in the affeq *)
let find_first_pivot m' row_idx col_idx =
Expand All @@ -158,26 +192,25 @@ module ListMatrix: AbstractMatrix =
let row_first_non_zero = V.find_first_non_zero row in
match row_first_non_zero with
| None -> (cur_row, cur_col, cur_val)
| Some (idx, value) -> (* let () = Printf.printf "We found first non-zero at index %i in row %i\n" idx i in *)
if idx < cur_col then (i, idx, value) else (cur_row, cur_col, cur_val)
) (num_rows m', max_piv_col_idx + 1, A.zero) m' (* Initializing with max, so num_cols m indicates that pivot is not found *)
| Some (idx, value) -> if idx < cur_col then (i, idx, value) else (cur_row, cur_col, cur_val)
) (num_rows m', max_piv_col_idx + 1, A.zero) m'
in
if piv_col = (max_piv_col_idx + 1) then None else Some (row_idx + piv_row, col_idx + piv_col, piv_val)
in
let find_first_pivot m' row_idx col_idx = Timing.wrap "find_first_pivot" (find_first_pivot m' row_idx) col_idx in
let affeq_rows_are_valid m = (* Check if the semantics of an rref-affeq matrix are correct *)
let col_count = num_cols m in
let row_is_valid row = (* TODO: Vector findi_opt *)
let row_is_valid row =
match V.find_first_non_zero row with
| Some (idx, _) -> if idx < col_count - 1 then true else false (* If all cofactors of the affeq are zero, but the constant is non-zero, the row is invalid *)
| None -> true (* Full zero row is valid *)
| None -> true
in
List.for_all row_is_valid m in
let rec main_loop m m' row_idx col_idx =
let rec find_piv_and_reduce m m' row_idx col_idx =
if col_idx >= (col_count - 1) then m (* In this case the whole bottom of the matrix starting from row_index is Zero, so it is normalized *)
else
match find_first_pivot m' row_idx col_idx with
| None -> m (* No pivot found means already normalized*)
| None -> m (* No pivot found means already normalized *)
| Some (piv_row_idx, piv_col_idx, piv_val) -> (
(* let () = Printf.printf "The current matrix is: \n%s and the pivot is (%i, %i, %s)\n" (show m) piv_row_idx piv_col_idx (A.to_string piv_val) in *)
let m = if piv_row_idx <> row_idx then swap_rows m row_idx piv_row_idx else m in
Expand All @@ -186,44 +219,13 @@ module ListMatrix: AbstractMatrix =
let subtracted_m = List.mapi (fun idx row -> if idx <> row_idx then
let scale = V.nth row piv_col_idx in
sub_scaled_row row piv_row scale else row) normalized_m in
let m' = dec_mat_2D subtracted_m (row_idx + 1) (piv_col_idx + 1) in
main_loop subtracted_m m' (row_idx + 1) (piv_col_idx + 1)) (* We start at piv_col_idx + 1 because every other col before that is zero at the bottom*)
let m' = cut_front_matrix subtracted_m (row_idx + 1) (piv_col_idx + 1) in
find_piv_and_reduce subtracted_m m' (row_idx + 1) (piv_col_idx + 1)) (* We start at piv_col_idx + 1 because every other col before that is zero at the bottom*)
in
let m' = main_loop m m 0 0 in
let m' = find_piv_and_reduce m m 0 0 in
if affeq_rows_are_valid m' then Some m' else None (* TODO: We can check this for each row, using the helper function row_is_invalid *)


(* This function return a tuple of row index and pivot position (column) in m *)
(* TODO: maybe we could use a Hashmap instead of a list? *)
let get_pivot_positions (m : t) : (int * int) list =
List.rev @@ List.fold_lefti (
fun acc i row -> match V.find_first_non_zero row with
| None -> acc
| Some (pivot_col, _) -> (i, pivot_col) :: acc
) [] m

let assert_rref m =
let pivot_l = get_pivot_positions m in
let rec validate m i =
match m with
| [] -> ()
| v::vs when (V.is_zero_vec v) ->
if List.exists (fun v -> not @@ V.is_zero_vec v) vs
then raise (Invalid_argument "Matrix not in rref: zero row!")
else ()
| v::vs ->
let rec validate_vec pl =
match pivot_l with
| [] -> true
| (pr, pc)::ps ->
let target = if pr <> i then A.zero else A.one in
if V.nth v pc <>: target then false else validate_vec ps
in if validate_vec pivot_l then validate vs (i+1) else raise (Invalid_argument "Matrix not in rref: pivot column not empty!")
in validate m 0

(* TODO: Remove this! Just to suppress warning *)
let () = assert_rref (empty ())

(* Sets the jth column to zero by subtracting multiples of v *)
let reduce_col_with_vec m j v =
let pivot_element = V.nth v j in
Expand Down Expand Up @@ -319,10 +321,4 @@ module ListMatrix: AbstractMatrix =

let is_covered_by m1 m2 = Timing.wrap "is_covered_by" (is_covered_by m1) m2

let find_opt f m =
List.find_opt f m

let map2 f m v =
let vector_length = V.length v in
List.mapi (fun index row -> if index < vector_length then f row (V.nth v index) else row ) m
end

0 comments on commit 60f7b67

Please sign in to comment.