Skip to content

Commit 3717f53

Browse files
committed
Fix bugs in implem of Bitmask and Pack
1 parent a125bc6 commit 3717f53

31 files changed

+234
-117
lines changed

arch/AVX.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@
6363
#define LIFT_32(x) _mm256_set1_epi32(x)
6464
#define LIFT_64(x) _mm256_set1_epi64x(x)
6565

66+
67+
/* BITMASK(x,n,c): for every c-bit lane, computes ZERO - ((x >> n) & 1), i.e.
 * an all-ones lane when bit n of x is set and a zero lane otherwise (assuming
 * ZERO is the all-zero vector -- defined outside this hunk).
 * The shift must be a logical RIGHT shift: shifting left and masking with 1
 * would always give 0 for n >= 1.
 * NOTE(review): the intrinsic names are built by token pasting, so only lane
 * widths with matching srli/sub/set1 intrinsics can work. There is no
 * _mm256_srli_epi8, and the 64-bit broadcast is spelled _mm256_set1_epi64x,
 * so c == 8 and c == 64 need dedicated handling -- confirm which widths the
 * code generator actually emits. */
#define BITMASK(x,n,c) _mm256_sub_epi##c(ZERO, _mm256_and_si256(_mm256_srli_epi##c((x),(n)), _mm256_set1_epi##c(1)))
68+
69+
#define PACK_8x2_to_16(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
70+
#define PACK_16x2_to_32(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
71+
#define PACK_32x2_to_64(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
72+
73+
6674
#define DATATYPE __m256i
6775

6876
#define SET_ALL_ONE() ONES

arch/AVX512.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@
7171
#define LIFT_32(x) _mm512_set1_epi32(x)
7272
#define LIFT_64(x) _mm512_set1_epi64(x)
7373

74+
75+
/* BITMASK(x,n,c): for every c-bit lane, computes ZERO - ((x >> n) & 1), i.e.
 * an all-ones lane when bit n of x is set and a zero lane otherwise (assuming
 * ZERO is the all-zero vector -- defined outside this hunk).
 * The shift must be a logical RIGHT shift: shifting left and masking with 1
 * would always give 0 for n >= 1.
 * NOTE(review): the intrinsic names are built by token pasting; there is no
 * _mm512_srli_epi8, so c == 8 needs dedicated handling -- confirm which
 * widths the code generator actually emits. */
#define BITMASK(x,n,c) _mm512_sub_epi##c(ZERO, _mm512_and_si512(_mm512_srli_epi##c((x),(n)), _mm512_set1_epi##c(1)))
76+
77+
#define PACK_8x2_to_16(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
78+
#define PACK_16x2_to_32(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
79+
#define PACK_32x2_to_64(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
80+
7481
#define DATATYPE __m512i
7582

7683
#define SET_ALL_ONE() ONES

arch/SSE.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@
6868
#define LIFT_32(x) _mm_set1_epi32(x)
6969
#define LIFT_64(x) _mm_set1_epi64x(x)
7070

71+
/* BITMASK(x,n,c): for every c-bit lane, computes ZERO - ((x >> n) & 1), i.e.
 * an all-ones lane when bit n of x is set and a zero lane otherwise (assuming
 * ZERO is the all-zero vector -- defined outside this hunk).
 * The shift must be a logical RIGHT shift: shifting left and masking with 1
 * would always give 0 for n >= 1.
 * NOTE(review): the intrinsic names are built by token pasting, so only lane
 * widths with matching srli/sub/set1 intrinsics can work. There is no
 * _mm_srli_epi8, and _mm_set1_epi64 takes an __m64 (the integer broadcast is
 * _mm_set1_epi64x), so c == 8 and c == 64 need dedicated handling -- confirm
 * which widths the code generator actually emits. */
#define BITMASK(x,n,c) _mm_sub_epi##c(ZERO, _mm_and_si128(_mm_srli_epi##c((x),(n)), _mm_set1_epi##c(1)))
72+
73+
#define PACK_8x2_to_16(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
74+
#define PACK_16x2_to_32(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
75+
#define PACK_32x2_to_64(a,b) /* TODO: implement with shuffles; until then this expands to nothing */
7176

7277
#define ORTHOGONALIZE(in,out) orthogonalize(in,out)
7378
#define UNORTHOGONALIZE(in,out) unorthogonalize(in,out)

arch/STD.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@
5656
#define LIFT_32(x) (x)
5757
#define LIFT_64(x) (x)
5858

59+
/* BITMASK(x,n,c): extracts bit n of x and negates it, giving 0 when the bit
 * is clear and the negation of 1 in x's promoted type when it is set (all
 * bits set for unsigned types). The lane width c is not used by this scalar
 * version. The whole expansion is parenthesised so it composes safely inside
 * larger expressions. */
#define BITMASK(x,n,c) (0 - (((x) >> (n)) & 1))
60+
61+
/* PACK_8x2_to_16(a,b): concatenates two values into one, with a shifted into
 * bits 8 and up and b OR-ed into the low bits. Both operands are only
 * converted to uint16_t (not masked to 8 bits), so they are expected to fit
 * in 8 bits already. */
#define PACK_8x2_to_16(a,b) (((uint16_t)(b)) | (((uint16_t)(a)) << 8))
62+
/* PACK_16x2_to_32(a,b): concatenates two values into a uint32_t, with a in
 * the upper 16 bits and b OR-ed into the low bits. b is only converted to
 * uint32_t (not masked to 16 bits), so it is expected to fit in 16 bits
 * already. */
#define PACK_16x2_to_32(a,b) (((uint32_t)(b)) | (((uint32_t)(a)) << 16))
63+
/* PACK_32x2_to_64(a,b): concatenates two values into a uint64_t, with a in
 * the upper 32 bits and b OR-ed into the low bits. b is only converted to
 * uint64_t (not masked to 32 bits), so it is expected to fit in 32 bits
 * already. */
#define PACK_32x2_to_64(a,b) (((uint64_t)(b)) | (((uint64_t)(a)) << 32))
64+
65+
5966
#define refresh(x,y) *(y) = x
6067

6168
#ifndef DATATYPE

src/c_gen/nodes_to_c.ml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,23 @@ let rec expr_to_c (m:mtyp)
118118
(List.length l)
119119
(var_to_c lift_env env v)
120120
(join "," (List.map string_of_int l))
121+
| Bitmask(e',ae) -> sprintf "BITMASK(%s, %s, %d)"
122+
(expr_to_c m lift_env env env_var e')
123+
(aexpr_to_c ae)
124+
(match m with
125+
| Mint m_val -> m_val
126+
| _ -> assert false)
127+
| Pack(e1,e2,Some typ) ->
128+
let args_m = get_type_m (get_normed_expr_type env_var e1) in
129+
sprintf "PACK_%dx2_to_%d(%s,%s)"
130+
(match args_m with
131+
| Mint m_val -> m_val
132+
| _ -> assert false)
133+
(match m with
134+
| Mint m_val -> m_val
135+
| _ -> assert false)
136+
(expr_to_c args_m lift_env env env_var e1)
137+
(expr_to_c args_m lift_env env env_var e2)
121138
| Shift(op,e,ae) ->
122139
sprintf "%s(%s,%s,%d)"
123140
(shift_op_to_c op)

src/normalization/clean.ml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@ let rec collect_var env (var:var) : unit =
1212

1313
let rec collect_expr env (e:expr) : unit =
1414
match e with
15-
| Const _ -> ()
16-
| ExpVar(v) -> collect_var env v
17-
| Tuple l -> List.iter (collect_expr env) l
18-
| Not e -> collect_expr env e
19-
| Log(_,x,y) -> collect_expr env x; collect_expr env y
20-
| Arith(_,x,y) -> collect_expr env x; collect_expr env y
21-
| Shift(_,x,_) -> collect_expr env x
22-
| Shuffle(v,_) -> collect_var env v
23-
| Mask(e,_) -> collect_expr env e
24-
| Pack(l,_) -> List.iter (collect_expr env) l
25-
| Fun(_,l) -> List.iter (collect_expr env) l
26-
| Fun_v(_,_,l) -> List.iter (collect_expr env) l
15+
| Const _ -> ()
16+
| ExpVar(v) -> collect_var env v
17+
| Tuple l -> List.iter (collect_expr env) l
18+
| Not e -> collect_expr env e
19+
| Log(_,x,y) -> collect_expr env x; collect_expr env y
20+
| Arith(_,x,y) -> collect_expr env x; collect_expr env y
21+
| Shift(_,x,_) -> collect_expr env x
22+
| Shuffle(v,_) -> collect_var env v
23+
| Bitmask(e,_) -> collect_expr env e
24+
| Pack(e1,e2,_) -> collect_expr env e1; collect_expr env e2
25+
| Fun(_,l) -> List.iter (collect_expr env) l
26+
| Fun_v(_,_,l) -> List.iter (collect_expr env) l
2727

2828
let clean_in_deqs (vars:p) (deqs:deq list) : p =
2929
let env = Hashtbl.create 100 in

src/normalization/expand_array.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ let rec expand_expr env_var env_keep env env_it force (e:expr) : expr =
144144
Shift(op,e1',simpl_arith env ae)
145145
| Shuffle(v,pat) -> Tuple(List.map (fun x -> Shuffle(x,pat))
146146
(expand_var env_var env_keep env force v))
147-
| Mask(e,i) -> Mask(rec_call e,i)
148-
| Pack(l,t) -> Pack(List.map rec_call l,t)
147+
| Bitmask(e,ae) -> Bitmask(rec_call e,ae)
148+
| Pack(e1,e2,t) -> Pack(rec_call e1, rec_call e2, t)
149149
| Fun(f,el) ->
150150
if f.name = "refresh" then
151151
Fun(f,List.map (expand_expr env_var env_keep env env_it

src/normalization/expand_parameters.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ module Unroll = struct
117117
| Arith(op,x,y) -> Arith(op,unroll_expr env_it x,unroll_expr env_it y)
118118
| Shift(op,e,ae) -> Shift(op,unroll_expr env_it e,simpl_arith env_it ae)
119119
| Shuffle(v,l) -> Shuffle(unroll_var env_it v,l)
120-
| Mask(e,i) -> Mask(unroll_expr env_it e,i)
121-
| Pack(l,t) -> Pack(List.map (unroll_expr env_it) l,t)
120+
| Bitmask(e,ae) -> Bitmask(unroll_expr env_it e,simpl_arith env_it ae)
121+
| Pack(e1,e2,t) -> Pack(unroll_expr env_it e1, unroll_expr env_it e2, t)
122122
| Fun(f,l) -> Fun(f,List.map (unroll_expr env_it) l)
123123
| Fun_v _ -> assert false
124124

@@ -205,10 +205,10 @@ let rec propagate_expr (expand_env:(var,var list) Hashtbl.t) (e:expr) : expr =
205205
Shift(op,propagate_expr expand_env e,ae)
206206
| Shuffle(v,pat) ->
207207
Shuffle(List.hd (propagate_var expand_env v),pat)
208-
| Mask(e,i) ->
209-
Mask(propagate_expr expand_env e, i)
210-
| Pack(l,t) ->
211-
Pack (List.map (propagate_expr expand_env) l, t)
208+
| Bitmask(e,ae) ->
209+
Bitmask(propagate_expr expand_env e, ae)
210+
| Pack(e1,e2,t) ->
211+
Pack (propagate_expr expand_env e1, propagate_expr expand_env e2, t)
212212
| Fun(x,es) ->
213213
let l = List.map (propagate_expr expand_env) es in
214214
(match l with

src/normalization/expand_permut.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ let rec apply_perm_e env_fun env_var (e:expr) : expr =
2727
| Log(op,x,y) -> Log(op,apply_perm_e env_fun env_var x,apply_perm_e env_fun env_var y)
2828
| Arith(op,x,y) -> Arith(op,apply_perm_e env_fun env_var x,apply_perm_e env_fun env_var y)
2929
| Shift(op,e,n) -> Shift(op,apply_perm_e env_fun env_var e,n)
30-
| Mask(e,i) -> Mask(apply_perm_e env_fun env_var e,i)
31-
| Pack(l,t) -> Pack(List.map (apply_perm_e env_fun env_var) l,t)
30+
| Bitmask(e,ae) -> Bitmask(apply_perm_e env_fun env_var e,ae)
31+
| Pack(e1,e2,t) -> Pack(apply_perm_e env_fun env_var e1,apply_perm_e env_fun env_var e2,t)
3232
| Fun(f,l) -> let l' = List.map (apply_perm_e env_fun env_var) l in
3333
(match env_fetch env_fun f with
3434
| Some perm -> Tuple (list_from_perm env_var perm l')

src/normalization/fix_dim.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ let rec dim_expr (v_tgt:var) (dim:int) (size:int) (e:expr) : expr =
117117
| Arith(op,x,y) -> Arith(op,dim_expr v_tgt dim size x,dim_expr v_tgt dim size y)
118118
| Shift(op,e,ae) -> Shift(op,dim_expr v_tgt dim size e,ae)
119119
| Shuffle(v,l) -> Shuffle(dim_var v_tgt dim size v,l)
120-
| Mask(e,i) -> Mask(dim_expr v_tgt dim size e,i)
121-
| Pack(l,t) -> Pack(List.map (dim_expr v_tgt dim size) l,t)
120+
| Bitmask(e,ae) -> Bitmask(dim_expr v_tgt dim size e,ae)
121+
| Pack(e1,e2,t) -> Pack(dim_expr v_tgt dim size e1,dim_expr v_tgt dim size e2, t)
122122
| Fun(f,l) -> Fun(f,List.map (dim_expr v_tgt dim size) l)
123123
| Fun_v _ -> assert false
124124

0 commit comments

Comments
 (0)