From d1b48b4a04ca6b9bd1e05cf411dc1cdc2fbb1939 Mon Sep 17 00:00:00 2001 From: Lynette Mwangi <69351261+Lynette7@users.noreply.github.com> Date: Sun, 8 Dec 2024 20:34:00 +0300 Subject: [PATCH] feat: support left shift or right shift operator --- src/circuit_writer/ir.rs | 10 ++++++- src/circuit_writer/writer.rs | 2 ++ src/constraints/field.rs | 53 +++++++++++++++++++++++++++++++++++- src/lexer/mod.rs | 20 ++++++++++++-- src/mast/mod.rs | 4 ++- src/parser/expr.rs | 8 +++++- src/type_checker/checker.rs | 4 ++- 7 files changed, 94 insertions(+), 7 deletions(-) diff --git a/src/circuit_writer/ir.rs b/src/circuit_writer/ir.rs index 0561701e5..0a34886b7 100644 --- a/src/circuit_writer/ir.rs +++ b/src/circuit_writer/ir.rs @@ -1,6 +1,6 @@ use ark_ff::Zero; use circ::{ - ir::term::{leaf_term, term, BoolNaryOp, Op, PfNaryOp, PfUnOp, Sort, Term, Value}, + ir::term::{leaf_term, term, BoolNaryOp, BvBinOp, Op, PfNaryOp, PfUnOp, Sort, Term, Value}, term, }; use num_bigint::BigUint; @@ -895,6 +895,14 @@ impl IRWriter { let t: Term = term![Op::BoolNaryOp(BoolNaryOp::Or); lhs.cvars[0].clone(), rhs.cvars[0].clone()]; Var::new_cvar(t, expr.span) } + Op2::RightShift => { + let t: Term = term![Op::BvBinOp(BvBinOp::Lshr); lhs.cvars[0].clone(), rhs.cvars[0].clone()]; + Var::new_cvar(t, expr.span) + } + Op2::LeftShift => { + let t: Term = term![Op::BvBinOp(BvBinOp::Shl); lhs.cvars[0].clone(), rhs.cvars[0].clone()]; + Var::new_cvar(t, expr.span) + } Op2::Division => { let t: Term = term![ Op::PfNaryOp(PfNaryOp::Mul); lhs.cvars[0].clone(), diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index a2a56ca9b..859960af2 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -654,6 +654,8 @@ impl CircuitWriter { Op2::Inequality => field::not_equal(self, &lhs, &rhs, expr.span), Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span), Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span), + Op2::LeftShift => field::shift_left(self, &lhs[0], &rhs[0], expr.span), + Op2::RightShift => field::shift_right(self, &lhs[0], &rhs[0], expr.span), Op2::Division => todo!(), }; diff --git a/src/constraints/field.rs b/src/constraints/field.rs index 38e120117..fc759c3b7 100644 --- a/src/constraints/field.rs +++ b/src/constraints/field.rs @@ -7,7 +7,7 @@ use crate::{ use super::boolean; -use ark_ff::{One, Zero}; +use ark_ff::{One, Field, PrimeField, Zero}; use std::ops::Neg; @@ -345,3 +345,54 @@ pub fn if_else_inner( let temp = mul(compiler, &one_minus_cond[0], else_, span); add(compiler, &cond_then[0], &temp[0], span) } + +/// Performs a left shift (multiplication by 2^n) on a field element +pub fn shift_left( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + shift: &ConstOrCell, + span: Span, +) -> Var { + match (lhs, shift) { + // Constant value and constant shift + (ConstOrCell::Const(val), ConstOrCell::Const(shift_amount)) => { + let two = B::Field::from(2u64); + let shift_value = two.pow([shift_amount.into_repr().as_ref()[0]]); + Var::new_constant(*val * shift_value, span) + } + // Constant shift and variable value + (ConstOrCell::Cell(var), ConstOrCell::Const(shift_amount)) => { + let two = B::Field::from(2u64); + let shift_value = two.pow([shift_amount.into_repr().as_ref()[0]]); + let res = compiler.backend.mul_const(var, &shift_value, span); + Var::new_var(res, span) + } + // variable shift + _ => unimplemented!("Variable shift amounts are not yet supported."), + } +} + +/// Performs a right shift (division by 2^n) on a field element +pub fn shift_right( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + shift: &ConstOrCell, + span: Span, +) -> Var { + match (lhs, shift) { + // Constant value and constant shift + (ConstOrCell::Const(val), ConstOrCell::Const(shift_amount)) => { + let shift_value = B::Field::from(2u64).pow(shift_amount.into_repr().as_ref()); + Var::new_constant(*val / shift_value, span) + } + // Constant shift and variable value + (ConstOrCell::Cell(var), ConstOrCell::Const(shift_amount)) => { + let shift_value = B::Field::from(2u64).pow(shift_amount.into_repr().as_ref()); + let shift_inverse = shift_value.inverse().expect("Division by zero"); + let res = compiler.backend.mul_const(var, &shift_inverse, span); + Var::new_var(res, span) + } + // Variable shift + _ => unimplemented!("Variable shift amounts are not yet supported."), + } +} diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 1a88221be..8f94fd499 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -157,6 +157,8 @@ pub enum TokenKind { Pipe, // | DoublePipe, // || Exclamation, // ! + DoubleGreater, // >> + DoubleLess, // << Question, // ? PlusEqual, // += MinusEqual, // -= @@ -201,6 +203,8 @@ impl Display for TokenKind { Pipe => "`|`", DoublePipe => "`||`", Exclamation => "`!`", + DoubleGreater => "`>>`", + DoubleLess => "`<<`", Question => "`?`", PlusEqual => "`+=`", MinusEqual => "`-=`", @@ -378,10 +382,22 @@ impl Token { } } '>' => { - tokens.push(TokenKind::Greater.new_token(ctx, 1)); + let next_c = chars.peek(); + if matches!(next_c, Some(&'>')) { + tokens.push(TokenKind::DoubleGreater.new_token(ctx, 2)); + chars.next(); + } else { + tokens.push(TokenKind::Greater.new_token(ctx, 1)); + } } '<' => { - tokens.push(TokenKind::Less.new_token(ctx, 1)); + let next_c = chars.peek(); + if matches!(next_c, Some(&'<')) { + tokens.push(TokenKind::DoubleLess.new_token(ctx, 2)); + chars.next(); + } else { + tokens.push(TokenKind::Less.new_token(ctx, 1)); + } } '=' => { let next_c = chars.peek(); diff --git a/src/mast/mod.rs b/src/mast/mod.rs index d955bc485..96891df3e 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -801,7 +801,9 @@ fn monomorphize_expr( | Op2::Multiplication | Op2::Division | Op2::BoolAnd - | Op2::BoolOr => lhs_mono.typ, + | Op2::BoolOr + | Op2::RightShift + | Op2::LeftShift => lhs_mono.typ, }; let ExprMonoInfo { expr: lhs_expr, .. } = lhs_mono; diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 27fd375ef..2fa1ed0a6 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -134,6 +134,8 @@ pub enum Op2 { Inequality, BoolAnd, BoolOr, + RightShift, + LeftShift, } impl Expr { @@ -512,7 +514,9 @@ impl Expr { | TokenKind::NotEqual | TokenKind::DoubleAmpersand | TokenKind::DoublePipe - | TokenKind::Exclamation, + | TokenKind::Exclamation + | TokenKind::DoubleGreater + | TokenKind::DoubleLess, .. }) => { // lhs + rhs @@ -526,6 +530,8 @@ impl Expr { TokenKind::NotEqual => Op2::Inequality, TokenKind::DoubleAmpersand => Op2::BoolAnd, TokenKind::DoublePipe => Op2::BoolOr, + TokenKind::DoubleGreater => Op2::RightShift, + TokenKind::DoubleLess => Op2::LeftShift, _ => unreachable!(), }; diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index cf8dda009..242cb5b4b 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -388,7 +388,9 @@ impl TypeChecker { | Op2::Multiplication | Op2::Division | Op2::BoolAnd - | Op2::BoolOr => lhs_node.typ, + | Op2::BoolOr + | Op2::RightShift + | Op2::LeftShift => lhs_node.typ, }; Some(ExprTyInfo::new_anon(typ))