Skip to content

Commit

Permalink
last level of stack proof
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Dec 21, 2024
1 parent bf85cc7 commit 2f3508c
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 32 deletions.
100 changes: 68 additions & 32 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2_stack.ec
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ module Mix = {
t <@ Jkem_avx2.M(Syscall)._poly_invntt(t);
mp <@ Jkem_avx2.M(Syscall)._poly_sub(mp, v, t);
mp <@ Jkem_avx2.M(Syscall).__poly_reduce(mp);
(msgp, mp) <@ M._i_poly_tomsg(msgp, mp);
(msgp, mp) <@ Jkem_avx2.M(Syscall)._poly_tomsg_1(msgp, mp);

return msgp;
}

Expand Down Expand Up @@ -429,6 +429,13 @@ proc.
by unroll for {1} ^while; unroll for {2} ^while; sim.
qed.

equiv aux_invntt2 :
Jkem_avx2.M(Syscall)._poly_invntt ~ Jkem_avx2.M(Syscall)._poly_invntt : ={arg} ==> ={res}.
proc.
by unroll for {1} ^while; unroll for {2} ^while; sim.
qed.


equiv mlkem_correct_enc_avx2_stack_mix :
M.__indcpa_enc ~ Mix.__indcpa_enc : ={arg} ==> ={res}.
proc => /=.
Expand All @@ -443,21 +450,7 @@ proc => /=.
sim (M._poly_invntt ~ Jkem_avx2.M(Syscall)._poly_invntt : (true)); by apply aux_invntt.
qed.

require import MLKEMFCLib.

op load_array1184 (m : global_mem_t) (p : address) : W8.t Array1184.t =
(Array1184.init (fun (i : int) => m.[p + i])).

lemma polyvec_to_bytes_stack_equiv _mem _pos:
0 <= _pos <= 1184 =>
equiv [ Jkem_avx2_stack.M.__i_polyvec_tobytes
~ Jkem_avx2.M(Syscall).__polyvec_tobytes :
Glob.mem{2} = _mem /\ arg{2}.`1 = W64.of_int _pos /\ arg{1}.`2 = arg{2}.`2 ==>
Glob.mem{2} =
stores _mem _pos
(take 1152
(to_list res{1}))].
admitted.
require import MLKEMFCLib MLKEM_PolyPolyVec_stack_bridges.


import InnerPKE.
Expand Down Expand Up @@ -585,19 +578,7 @@ case (1152 <= add < 1152 + 8*i{2}).
by rewrite get8_set64_directE 1,2:/# ifT 1:/# /get8 /#.
qed.

equiv aux_invntt2 :
Jkem_avx2.M(Syscall)._poly_invntt ~ Jkem_avx2.M(Syscall)._poly_invntt : ={arg} ==> ={res}.
proc.
by unroll for {1} ^while; unroll for {2} ^while; sim.
qed.


lemma polyvec_from_bytes_stack_equiv:
equiv [ Jkem_avx2_stack.M.__i_polyvec_frombytes
~ Jkem_avx2.M(Syscall).__polyvec_frombytes :
to_uint arg{2} = 0 /\
load_array1152 Glob.mem{2} 0 = arg{1} ==> ={res}].
admitted.


equiv mlkem_correct_enc_avx2_stack :
Expand Down Expand Up @@ -650,8 +631,23 @@ seq 1 1 : #pre.
rewrite /load_array1152 /load_array1184 tP => ??.
by rewrite initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=.

seq 3 3 : #pre.
+ admit.
seq 3 3 : #{/~pkp{2}}pre.
+ conseq />.
while (#{/~pkp{2}}pre /\ to_uint pkp{2} = 1152+8*w{1} /\ aux{1} = 4 /\ to_uint i{2} = w{1} /\ 0 <= w{1} <= 4); last by auto => />.
auto => /> &2 ????;rewrite !ultE /= => ?; do split.
+ rewrite tP => k kb; rewrite initiE 1:/# /= initiE 1:/# /=.
rewrite !get8_set64_directE 1..4:/#.
case (8 * to_uint i{2} <= k && k < 8 * to_uint i{2} + 8) => *; last by smt().
rewrite /get64_direct /loadW64; congr;congr.
apply W8u8.Pack.ext_eq => j jb.
rewrite /unpack8 !initiE 1,2:/# /= initiE 1:/#.
by rewrite /load_array1184 initiE /#.
+ by rewrite to_uintD_small /= /#.
+ by rewrite to_uintD_small /= /#.
+ by smt().
+ by smt().
+ by rewrite to_uintD_small /= /#.
+ by rewrite to_uintD_small /= /#.

by sim (Jkem_avx2.M(Syscall)._poly_invntt ~ Jkem_avx2.M(Syscall)._poly_invntt : (true))
(Jkem_avx2.M(Syscall).aBUFLEN____dumpstate_array_avx2 ~ Jkem_avx2.M(Syscall).aBUFLEN____dumpstate_array_avx2 : true) (M.a64____dumpstate_array_avx2 ~ Jkem_avx2.M(Syscall).a64____dumpstate_array_avx2 :true) => /=;[ apply aux_buflen_dumpstate1 | apply aux_invntt2].
Expand All @@ -672,4 +668,44 @@ Array960.init(fun i => ct{1}.[i]) = cph{2}.`1 /\
Array128.init(fun i => ct{1}.[i+960]) = cph{2}.`2 ==>
={r});1,2:smt().
+ by call mlkem_correct_dec_avx2_stack_mix;auto.
admitted.

pose _skp := W64.of_int 0.
pose _ctp1 := W64.of_int 1152.
pose _ctp2 := W64.of_int (1152+960).

transitivity {1} {r <@ Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_dec_1(msgp,_ctp1, _skp); }
(load_array1152 Glob.mem{2} 0 = sk{1} /\ ={msgp} /\
load_array1088 Glob.mem{2} 1152 = ct{1} ==> ={r})
(load_array1152 Glob.mem{1} 0 = sk{2} /\
load_array960 Glob.mem{1} 1152 = cph{2}.`1 /\
load_array128 Glob.mem{1} (1152+960) = cph{2}.`2
==>
={r}); 2: smt();last first.
+ ecall (mlkem_correct_dec (Glob.mem{1}) 1152 0).
auto => /> &1 &2; rewrite /load_array1152 /load_array960 /load_array128 !tP => ??x1 x2 ?; do split.
+ rewrite tP => k kb; rewrite initiE 1: /# /=;smt(Array960.initiE Array128.initiE).
+ rewrite tP => k kb; rewrite initiE 1: /# /=;smt(Array960.initiE Array128.initiE).
+ move => /> &1 &2; rewrite !tP => ??.
exists (stores (stores (stores witness 0 (to_list sk{2})) 1152 (to_list cph{2}.`1)) (1152+960) (to_list cph{2}.`2)) msgp{1}.
do split.
+ rewrite /load_array1152 tP => k kb.
rewrite initiE 1:/# /= !get_storesE size_to_list size_to_list; smt(@Array1152).
+ rewrite /load_array1088 tP => k kb.
rewrite initiE 1:/# /= !get_storesE size_to_list size_to_list; smt(@Array960 @Array128).
+ rewrite /load_array1152 tP => k kb.
rewrite initiE 1:/# /= !get_storesE size_to_list size_to_list; smt(@Array1152 @Array960 @Array128).
+ rewrite /load_array960 tP => k kb.
rewrite initiE 1:/# /= !get_storesE size_to_list size_to_list; smt(@Array1152 @Array960 @Array128).
+ rewrite /load_array128 tP => k kb.
rewrite initiE 1:/# /= !get_storesE size_to_list size_to_list; smt(@Array1152 @Array960 @Array128).

inline {1} 1; inline {2} 1.
sp 3 3.
seq 5 5 : (#pre /\ ={bp,mp,skpv,t,v}); 1: by conseq />;sim.
sim 3 4 (Jkem_avx2.M(Syscall)._poly_invntt ~ Jkem_avx2.M(Syscall)._poly_invntt : (true)); 2: by apply aux_invntt2.
wp;call polyvec_from_bytes_stack_equiv => /=.
wp;call poly_decompress_stack_equiv.
wp;call polyvec_decompress_stack_equiv.
by auto => /> &1 &2;rewrite /load_array128 /load_array960 /load_array1088;
rewrite tP => k kb; rewrite !initiE 1,2:/# /= initiE /#.
qed.
78 changes: 78 additions & 0 deletions proof/correctness/avx2/MLKEM_PolyPolyVec_stack_bridges.ec
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
require import AllCore IntDiv List.
require import Jkem_avx2 Jkem_avx2_stack MLKEM_InnerPKE_avx2 MLKEM.
from Jasmin require import JModel_x86.

require import Array32 Array33 Array64 Array148 Array256 Array384 Array768 Array960 Array128 Array1088 Array1152 Array1184 Array2304 Array2400.
require import Array8 WArray32 WArray33 Array300 WArray64 WArray1184 WArray2400.
import MLKEM InnerPKE.

require import MLKEMFCLib.

op load_array1184 (m : global_mem_t) (p : address) : W8.t Array1184.t =
(Array1184.init (fun (i : int) => m.[p + i])).

lemma poly_to_bytes_stack_equiv _mem _pos:
0 <= _pos <= 1184+2*384 =>
equiv [ Jkem_avx2_stack.M._i_poly_tobytes
~ Jkem_avx2.M(Syscall)._poly_tobytes :
Glob.mem{2} = _mem /\ arg{2}.`1 = W64.of_int _pos /\ arg{1}.`2 = arg{2}.`2 ==>
Glob.mem{2} =
stores _mem _pos
(to_list res{1}.`1)].
move => Hpos;proc => /=.
admitted.

lemma polyvec_to_bytes_stack_equiv _mem _pos:
0 <= _pos <= 1184 =>
equiv [ Jkem_avx2_stack.M.__i_polyvec_tobytes
~ Jkem_avx2.M(Syscall).__polyvec_tobytes :
Glob.mem{2} = _mem /\ arg{2}.`1 = W64.of_int _pos /\ arg{1}.`2 = arg{2}.`2 ==>
Glob.mem{2} =
stores _mem _pos
(take 1152
(to_list res{1}))].
move => Hpos;proc => /=.
seq 3 3 : (
Glob.mem{2} = (stores _mem _pos (take 384 (to_list r{1}))) /\ pp{2} = (of_int _pos)%W64 /\ ={a}
).
+ wp;call (poly_to_bytes_stack_equiv _mem _pos) => /=;1: by smt().
auto => />. admit.
seq 3 3 : (
Glob.mem{2} = (stores _mem _pos (take (2*384) (to_list r{1}))) /\ pp{2} = (of_int (_pos + 384))%W64 /\ ={a}
).
+ exlim Glob.mem{2} => _mem2.
wp;call (poly_to_bytes_stack_equiv _mem2 (_pos+384)) => /=; 1: by smt().
auto => />. admit.
seq 3 3 : (
Glob.mem{2} = (stores _mem _pos (take (3*384) (to_list r{1}))) /\ pp{2} = (of_int (_pos + 2*384))%W64 /\ ={a}
).
+ exlim Glob.mem{2} => _mem3.
wp;call (poly_to_bytes_stack_equiv _mem3 (_pos+2*384)) => /=; 1: by smt().
auto => />. admit.

by auto => />.
qed.

lemma polyvec_from_bytes_stack_equiv:
equiv [ Jkem_avx2_stack.M.__i_polyvec_frombytes
~ Jkem_avx2.M(Syscall).__polyvec_frombytes :
to_uint arg{2} = 0 /\
load_array1152 Glob.mem{2} 0 = arg{1} ==> ={res}].
admitted.

op load_array1088 (m : global_mem_t) (p : address) : W8.t Array1088.t = Array1088.init (fun (i : int) => m.[p + i]).

lemma polyvec_decompress_stack_equiv:
equiv [ Jkem_avx2_stack.M.__i_polyvec_decompress
~ Jkem_avx2.M(Syscall).__polyvec_decompress :
to_uint arg{2} = 1152 /\
load_array1088 Glob.mem{2} 1152 = arg{1} ==> ={res}].
admitted.

lemma poly_decompress_stack_equiv:
equiv [ Jkem_avx2_stack.M._i_poly_decompress
~ Jkem_avx2.M(Syscall)._poly_decompress :
to_uint arg{2}.`2 = 1152+960 /\
load_array128 Glob.mem{2} (1152+960) = arg{1}.`2 ==> ={res}].
admitted.

0 comments on commit 2f3508c

Please sign in to comment.