diff --git a/src/constEval.ts b/src/constEval.ts index 5249fdede..612561ac8 100644 --- a/src/constEval.ts +++ b/src/constEval.ts @@ -11,21 +11,38 @@ import { isSelfId, eqNames, idText, + AstValue, + isValue, } from "./grammar/ast"; -import { idTextErr, throwConstEvalError } from "./errors"; +import { TactConstEvalError, idTextErr, throwConstEvalError } from "./errors"; import { CommentValue, showValue, StructValue, Value } from "./types/types"; import { sha256_sync } from "@ton/crypto"; +import { + extractValue, + makeValueExpression, + makeUnaryExpression, + makeBinaryExpression, + divFloor, + modFloor, +} from "./optimizer/util"; +import { ExpressionTransformer } from "./optimizer/types"; +import { StandardOptimizer } from "./optimizer/standardOptimizer"; import { getStaticConstant, getType, hasStaticConstant, } from "./types/resolveDescriptors"; import { getExpType } from "./types/resolveExpression"; +import { dummySrcInfo } from "./grammar/grammar"; // TVM integers are signed 257-bit integers const minTvmInt: bigint = -(2n ** 256n); const maxTvmInt: bigint = 2n ** 256n - 1n; +// The optimizer that applies the rewriting rules during partial evaluation. +// For the moment we use an optimizer that respects overflows. +const optimizer: ExpressionTransformer = new StandardOptimizer(); + // Throws a non-fatal const-eval error, in the sense that const-eval as a compiler // optimization cannot be applied, e.g. to `let`-statements. // Note that for const initializers this is a show-stopper. @@ -106,58 +123,72 @@ function ensureMethodArity( } } -function evalUnaryOp( +export function evalUnaryOp( op: AstUnaryOperation, - operand: AstExpression, - source: SrcInfo, - ctx: CompilerContext, + valOperand: Value, + operandLoc: SrcInfo = dummySrcInfo, + source: SrcInfo = dummySrcInfo, ): Value { - // Tact grammar does not have negative integer literals, - // so in order to avoid errors for `-115792089237316195423570985008687907853269984665640564039457584007913129639936` - // which is `-(2**256)` we need to have a special case for it - if (operand.kind === "number" && op === "-") { - // emulating negative integer literals - return ensureInt(-operand.value, source); - } - const valOperand = evalConstantExpression(operand, ctx); switch (op) { case "+": - return ensureInt(valOperand, operand.loc); + return ensureInt(valOperand, operandLoc); case "-": - return ensureInt(-ensureInt(valOperand, operand.loc), source); + return ensureInt(-ensureInt(valOperand, operandLoc), source); case "~": - return ~ensureInt(valOperand, operand.loc); + return ~ensureInt(valOperand, operandLoc); case "!": - return !ensureBoolean(valOperand, operand.loc); + return !ensureBoolean(valOperand, operandLoc); case "!!": if (valOperand === null) { throwErrorConstEval( "non-null value expected but got null", - operand.loc, + operandLoc, ); } return valOperand; } } -// precondition: the divisor is not zero -// rounds the division result towards negative infinity -function divFloor(a: bigint, b: bigint): bigint { - const almostSameSign = a > 0n === b > 0n; - if (almostSameSign) { - return a / b; +function fullyEvalUnaryOp( + op: AstUnaryOperation, + operand: AstExpression, + source: SrcInfo, + ctx: CompilerContext, +): Value { + // Tact grammar does not have negative integer literals, + // so in order to avoid errors for `-115792089237316195423570985008687907853269984665640564039457584007913129639936` + // which is `-(2**256)` we need to have a special case for it + + if (operand.kind === "number" && op === "-") { + // emulating negative integer literals + return ensureInt(-operand.value, source); } - return a / b + (a % b === 0n ? 0n : -1n); + + const valOperand = evalConstantExpression(operand, ctx); + + return evalUnaryOp(op, valOperand, operand.loc, source); } -// precondition: the divisor is not zero -// rounds the result towards negative infinity -// Uses the fact that a / b * b + a % b == a, for all b != 0. -function modFloor(a: bigint, b: bigint): bigint { - return a - divFloor(a, b) * b; +function partiallyEvalUnaryOp( + op: AstUnaryOperation, + operand: AstExpression, + source: SrcInfo, + ctx: CompilerContext, +): AstExpression { + const simplOperand = partiallyEvalExpression(operand, ctx); + + if (isValue(simplOperand)) { + const valueOperand = extractValue(simplOperand as AstValue); + const result = evalUnaryOp(op, valueOperand, simplOperand.loc, source); + // Wrap the value into a Tree to continue simplifications + return makeValueExpression(result); + } else { + const newAst = makeUnaryExpression(op, simplOperand); + return optimizer.applyRules(newAst); + } } -function evalBinaryOp( +function fullyEvalBinaryOp( op: AstBinaryOperation, left: AstExpression, right: AstExpression, @@ -166,20 +197,61 @@ function evalBinaryOp( ): Value { const valLeft = evalConstantExpression(left, ctx); const valRight = evalConstantExpression(right, ctx); + + return evalBinaryOp(op, valLeft, valRight, left.loc, right.loc, source); +} + +function partiallyEvalBinaryOp( + op: AstBinaryOperation, + left: AstExpression, + right: AstExpression, + source: SrcInfo, + ctx: CompilerContext, +): AstExpression { + const leftOperand = partiallyEvalExpression(left, ctx); + const rightOperand = partiallyEvalExpression(right, ctx); + + if (isValue(leftOperand) && isValue(rightOperand)) { + const valueLeftOperand = extractValue(leftOperand as AstValue); + const valueRightOperand = extractValue(rightOperand as AstValue); + const result = evalBinaryOp( + op, + valueLeftOperand, + valueRightOperand, + leftOperand.loc, + rightOperand.loc, + source, + ); + // Wrap the value into a Tree to continue simplifications + return makeValueExpression(result); + } else { + const newAst = makeBinaryExpression(op, leftOperand, rightOperand); + return optimizer.applyRules(newAst); + } +} + +export function evalBinaryOp( + op: AstBinaryOperation, + valLeft: Value, + valRight: Value, + locLeft: SrcInfo = dummySrcInfo, + locRight: SrcInfo = dummySrcInfo, + source: SrcInfo = dummySrcInfo, +): Value { switch (op) { case "+": return ensureInt( - ensureInt(valLeft, left.loc) + ensureInt(valRight, right.loc), + ensureInt(valLeft, locLeft) + ensureInt(valRight, locRight), source, ); case "-": return ensureInt( - ensureInt(valLeft, left.loc) - ensureInt(valRight, right.loc), + ensureInt(valLeft, locLeft) - ensureInt(valRight, locRight), source, ); case "*": return ensureInt( - ensureInt(valLeft, left.loc) * ensureInt(valRight, right.loc), + ensureInt(valLeft, locLeft) * ensureInt(valRight, locRight), source, ); case "/": { @@ -187,44 +259,38 @@ function evalBinaryOp( // is a non-conventional one: by default it rounds towards negative infinity, // meaning, for instance, -1 / 5 = -1 and not zero, as in many mainstream languages. // Still, the following holds: a / b * b + a % b == a, for all b != 0. - const r = ensureInt(valRight, right.loc); + const r = ensureInt(valRight, locRight); if (r === 0n) throwErrorConstEval( "divisor expression must be non-zero", - right.loc, + locRight, ); - return ensureInt(divFloor(ensureInt(valLeft, left.loc), r), source); + return ensureInt(divFloor(ensureInt(valLeft, locLeft), r), source); } case "%": { // Same as for division, see the comment above // Example: -1 % 5 = 4 - const r = ensureInt(valRight, right.loc); + const r = ensureInt(valRight, locRight); if (r === 0n) throwErrorConstEval( "divisor expression must be non-zero", - right.loc, + locRight, ); - return ensureInt(modFloor(ensureInt(valLeft, left.loc), r), source); + return ensureInt(modFloor(ensureInt(valLeft, locLeft), r), source); } case "&": - return ( - ensureInt(valLeft, left.loc) & ensureInt(valRight, right.loc) - ); + return ensureInt(valLeft, locLeft) & ensureInt(valRight, locRight); case "|": - return ( - ensureInt(valLeft, left.loc) | ensureInt(valRight, right.loc) - ); + return ensureInt(valLeft, locLeft) | ensureInt(valRight, locRight); case "^": - return ( - ensureInt(valLeft, left.loc) ^ ensureInt(valRight, right.loc) - ); + return ensureInt(valLeft, locLeft) ^ ensureInt(valRight, locRight); case "<<": { - const valNum = ensureInt(valLeft, left.loc); - const valBits = ensureInt(valRight, right.loc); + const valNum = ensureInt(valLeft, locLeft); + const valBits = ensureInt(valRight, locRight); if (0n > valBits || valBits > 256n) { throwErrorConstEval( `the number of bits shifted ('${valBits}') must be within [0..256] range`, - right.loc, + locRight, ); } try { @@ -240,12 +306,12 @@ function evalBinaryOp( } } case ">>": { - const valNum = ensureInt(valLeft, left.loc); - const valBits = ensureInt(valRight, right.loc); + const valNum = ensureInt(valLeft, locLeft); + const valBits = ensureInt(valRight, locRight); if (0n > valBits || valBits > 256n) { throwErrorConstEval( `the number of bits shifted ('${valBits}') must be within [0..256] range`, - right.loc, + locRight, ); } try { @@ -261,21 +327,13 @@ function evalBinaryOp( } } case ">": - return ( - ensureInt(valLeft, left.loc) > ensureInt(valRight, right.loc) - ); + return ensureInt(valLeft, locLeft) > ensureInt(valRight, locRight); case "<": - return ( - ensureInt(valLeft, left.loc) < ensureInt(valRight, right.loc) - ); + return ensureInt(valLeft, locLeft) < ensureInt(valRight, locRight); case ">=": - return ( - ensureInt(valLeft, left.loc) >= ensureInt(valRight, right.loc) - ); + return ensureInt(valLeft, locLeft) >= ensureInt(valRight, locRight); case "<=": - return ( - ensureInt(valLeft, left.loc) <= ensureInt(valRight, right.loc) - ); + return ensureInt(valLeft, locLeft) <= ensureInt(valRight, locRight); case "==": // the null comparisons account for optional types, e.g. // a const x: Int? = 42 can be compared to null @@ -300,18 +358,20 @@ function evalBinaryOp( return valLeft !== valRight; case "&&": return ( - ensureBoolean(valLeft, left.loc) && - ensureBoolean(valRight, right.loc) + ensureBoolean(valLeft, locLeft) && + ensureBoolean(valRight, locRight) ); case "||": return ( - ensureBoolean(valLeft, left.loc) || - ensureBoolean(valRight, right.loc) + ensureBoolean(valLeft, locLeft) || + ensureBoolean(valRight, locRight) ); } } -function evalConditional( +// In the process of writing a partiallyEval version of this +// function for the partial evaluator +function fullyEvalConditional( condition: AstExpression, thenBranch: AstExpression, elseBranch: AstExpression, @@ -329,7 +389,9 @@ function evalConditional( } } -function evalStructInstance( +// In the process of writing a partiallyEval version of this +// function for the partial evaluator +function fullyEvalStructInstance( structTypeId: AstId, structFields: AstStructFieldInitializer[], ctx: CompilerContext, @@ -346,7 +408,9 @@ function evalStructInstance( ); } -function evalFieldAccess( +// In the process of writing a partiallyEval version of this +// function for the partial evaluator +function fullyEvalFieldAccess( structExpr: AstExpression, fieldId: AstId, source: SrcInfo, @@ -399,7 +463,9 @@ function evalFieldAccess( } } -function evalMethod( +// In the process of writing a partiallyEval version of this +// function for the partial evaluator +function fullyEvalMethod( methodName: AstId, object: AstExpression, args: AstExpression[], @@ -423,7 +489,9 @@ function evalMethod( } } -function evalBuiltins( +// In the process of writing a partiallyEval version of this +// function for the partial evaluator +function fullyEvalBuiltins( builtinName: AstId, args: AstExpression[], source: SrcInfo, @@ -651,27 +719,36 @@ function interpretEscapeSequences(stringLiteral: string, source: SrcInfo) { ); } +function lookupName(ast: AstId, ctx: CompilerContext): Value { + if (hasStaticConstant(ctx, ast.text)) { + const constant = getStaticConstant(ctx, ast.text); + if (constant.value !== undefined) { + return constant.value; + } else { + throwErrorConstEval( + `cannot evaluate declared constant ${idTextErr(ast)} as it does not have a body`, + ast.loc, + ); + } + } + throwNonFatalErrorConstEval("cannot evaluate a variable", ast.loc); +} + export function evalConstantExpression( ast: AstExpression, ctx: CompilerContext, ): Value { switch (ast.kind) { case "id": - if (hasStaticConstant(ctx, ast.text)) { - const constant = getStaticConstant(ctx, ast.text); - if (constant.value !== undefined) { - return constant.value; - } else { - throwErrorConstEval( - `cannot evaluate declared constant ${idTextErr(ast)} as it does not have a body`, - ast.loc, - ); - } - } - throwNonFatalErrorConstEval("cannot evaluate a variable", ast.loc); - break; + return lookupName(ast, ctx); case "method_call": - return evalMethod(ast.method, ast.self, ast.args, ast.loc, ctx); + return fullyEvalMethod( + ast.method, + ast.self, + ast.args, + ast.loc, + ctx, + ); case "init_of": throwNonFatalErrorConstEval( "initOf is not supported at this moment", @@ -690,21 +767,100 @@ export function evalConstantExpression( ast.loc, ); case "op_unary": - return evalUnaryOp(ast.op, ast.operand, ast.loc, ctx); + return fullyEvalUnaryOp(ast.op, ast.operand, ast.loc, ctx); case "op_binary": - return evalBinaryOp(ast.op, ast.left, ast.right, ast.loc, ctx); + return fullyEvalBinaryOp(ast.op, ast.left, ast.right, ast.loc, ctx); case "conditional": - return evalConditional( + return fullyEvalConditional( ast.condition, ast.thenBranch, ast.elseBranch, ctx, ); case "struct_instance": - return evalStructInstance(ast.type, ast.args, ctx); + return fullyEvalStructInstance(ast.type, ast.args, ctx); case "field_access": - return evalFieldAccess(ast.aggregate, ast.field, ast.loc, ctx); + return fullyEvalFieldAccess(ast.aggregate, ast.field, ast.loc, ctx); case "static_call": - return evalBuiltins(ast.function, ast.args, ast.loc, ctx); + return fullyEvalBuiltins(ast.function, ast.args, ast.loc, ctx); + } +} + +export function partiallyEvalExpression( + ast: AstExpression, + ctx: CompilerContext, +): AstExpression { + switch (ast.kind) { + case "id": + try { + return makeValueExpression(lookupName(ast, ctx)); + } catch (e) { + if (e instanceof TactConstEvalError) { + if (!e.fatal) { + // If a non-fatal error occurs during lookup, just return the symbol + return ast; + } + } + throw e; + } + case "method_call": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return makeValueExpression( + fullyEvalMethod(ast.method, ast.self, ast.args, ast.loc, ctx), + ); + case "init_of": + throwNonFatalErrorConstEval( + "initOf is not supported at this moment", + ast.loc, + ); + break; + case "null": + return ast; + case "boolean": + return ast; + case "number": + return makeValueExpression(ensureInt(ast.value, ast.loc)); + case "string": + return makeValueExpression( + ensureString( + interpretEscapeSequences(ast.value, ast.loc), + ast.loc, + ), + ); + case "op_unary": + return partiallyEvalUnaryOp(ast.op, ast.operand, ast.loc, ctx); + case "op_binary": + return partiallyEvalBinaryOp( + ast.op, + ast.left, + ast.right, + ast.loc, + ctx, + ); + case "conditional": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return makeValueExpression( + fullyEvalConditional( + ast.condition, + ast.thenBranch, + ast.elseBranch, + ctx, + ), + ); + case "struct_instance": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return makeValueExpression( + fullyEvalStructInstance(ast.type, ast.args, ctx), + ); + case "field_access": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return makeValueExpression( + fullyEvalFieldAccess(ast.aggregate, ast.field, ast.loc, ctx), + ); + case "static_call": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return makeValueExpression( + fullyEvalBuiltins(ast.function, ast.args, ast.loc, ctx), + ); } } diff --git a/src/grammar/ast.ts b/src/grammar/ast.ts index 76311aa70..f107a7170 100644 --- a/src/grammar/ast.ts +++ b/src/grammar/ast.ts @@ -569,6 +569,8 @@ export type AstNull = { loc: SrcInfo; }; +export type AstValue = AstNumber | AstBoolean | AstNull | AstString; + export type AstConstantAttribute = | { type: "virtual"; loc: SrcInfo } | { type: "overrides"; loc: SrcInfo } @@ -679,4 +681,149 @@ export function __DANGER_resetNodeId() { nextId = 1; } +// Test equality of AstExpressions. +export function eqExpressions( + ast1: AstExpression, + ast2: AstExpression, +): boolean { + if (ast1.kind !== ast2.kind) { + return false; + } + + switch (ast1.kind) { + case "null": + return true; + case "boolean": + return ast1.value === (ast2 as AstBoolean).value; + case "number": + return ast1.value === (ast2 as AstNumber).value; + case "string": + return ast1.value === (ast2 as AstString).value; + case "id": + return eqNames(ast1, ast2 as AstId); + case "method_call": + return ( + eqNames(ast1.method, (ast2 as AstMethodCall).method) && + eqExpressions(ast1.self, (ast2 as AstMethodCall).self) && + eqExpressionArrays(ast1.args, (ast2 as AstMethodCall).args) + ); + case "init_of": + return ( + eqNames(ast1.contract, (ast2 as AstInitOf).contract) && + eqExpressionArrays(ast1.args, (ast2 as AstInitOf).args) + ); + case "op_unary": + return ( + ast1.op === (ast2 as AstOpUnary).op && + eqExpressions(ast1.operand, (ast2 as AstOpUnary).operand) + ); + case "op_binary": + return ( + ast1.op === (ast2 as AstOpBinary).op && + eqExpressions(ast1.left, (ast2 as AstOpBinary).left) && + eqExpressions(ast1.right, (ast2 as AstOpBinary).right) + ); + case "conditional": + return ( + eqExpressions( + ast1.condition, + (ast2 as AstConditional).condition, + ) && + eqExpressions( + ast1.thenBranch, + (ast2 as AstConditional).thenBranch, + ) && + eqExpressions( + ast1.elseBranch, + (ast2 as AstConditional).elseBranch, + ) + ); + case "struct_instance": + return ( + eqNames(ast1.type, (ast2 as AstStructInstance).type) && + eqParameterArrays(ast1.args, (ast2 as AstStructInstance).args) + ); + case "field_access": + return ( + eqNames(ast1.field, (ast2 as AstFieldAccess).field) && + eqExpressions( + ast1.aggregate, + (ast2 as AstFieldAccess).aggregate, + ) + ); + case "static_call": + return ( + eqNames(ast1.function, (ast2 as AstStaticCall).function) && + eqExpressionArrays(ast1.args, (ast2 as AstStaticCall).args) + ); + } +} + +function eqParameters( + arg1: AstStructFieldInitializer, + arg2: AstStructFieldInitializer, +): boolean { + return ( + eqNames(arg1.field, arg2.field) && + eqExpressions(arg1.initializer, arg2.initializer) + ); +} + +function eqParameterArrays( + arr1: AstStructFieldInitializer[], + arr2: AstStructFieldInitializer[], +): boolean { + if (arr1.length !== arr2.length) { + return false; + } + + for (let i = 0; i < arr1.length; i++) { + if (!eqParameters(arr1[i]!, arr2[i]!)) { + return false; + } + } + + return true; +} + +function eqExpressionArrays( + arr1: AstExpression[], + arr2: AstExpression[], +): boolean { + if (arr1.length !== arr2.length) { + return false; + } + + for (let i = 0; i < arr1.length; i++) { + if (!eqExpressions(arr1[i]!, arr2[i]!)) { + return false; + } + } + + return true; +} + +export function isValue(ast: AstExpression): boolean { + switch (ast.kind) { + case "null": + case "boolean": + case "number": + case "string": + return true; + + case "struct_instance": + return ast.args.every((arg) => isValue(arg.initializer)); + + case "id": + case "method_call": + case "init_of": + case "op_unary": + case "op_binary": + case "conditional": + case "field_access": + case "static_call": + return false; + } +} + export { SrcInfo }; diff --git a/src/grammar/test/expr-equality.spec.ts b/src/grammar/test/expr-equality.spec.ts new file mode 100644 index 000000000..ff10859c6 --- /dev/null +++ b/src/grammar/test/expr-equality.spec.ts @@ -0,0 +1,418 @@ +import { __DANGER_resetNodeId, eqExpressions } from "../ast"; +import { parseExpression } from "../grammar"; + +type Test = { expr1: string; expr2: string; equality: boolean }; + +const valueExpressions: Test[] = [ + { expr1: "1", expr2: "1", equality: true }, + { expr1: "1", expr2: "true", equality: false }, + { expr1: "1", expr2: '"one"', equality: false }, + { expr1: "1", expr2: "null", equality: false }, + { expr1: "1", expr2: "g", equality: false }, + { expr1: "false", expr2: "true", equality: false }, + { expr1: "false", expr2: '"false"', equality: false }, + { expr1: "false", expr2: "false", equality: true }, + { expr1: "false", expr2: "null", equality: false }, + { expr1: "false", expr2: "g", equality: false }, + { expr1: '"one"', expr2: '"one"', equality: true }, + { expr1: '"one"', expr2: '"onw"', equality: false }, + { expr1: '"one"', expr2: "null", equality: false }, + { expr1: '"one"', expr2: "g", equality: false }, + { expr1: "null", expr2: "null", equality: true }, + { expr1: "null", expr2: "g", equality: false }, +]; + +const functionCallExpressions: Test[] = [ + { expr1: "f(1,4)", expr2: "f(1)", equality: false }, + { expr1: "f(1,4)", expr2: "f(1,4)", equality: true }, + { expr1: "f(1,4)", expr2: "1", equality: false }, + { expr1: "f(1,4)", expr2: "g(1,4)", equality: false }, + { expr1: "f(1,4)", expr2: "true", equality: false }, + { expr1: "f(1,4)", expr2: "null", equality: false }, + { expr1: "f(1,4)", expr2: "f", equality: false }, + { expr1: 'f("a",0)', expr2: 'f("a",0)', equality: true }, + { expr1: 'f("a",0)', expr2: 'f("a",null)', equality: false }, + { expr1: "f(true,0)", expr2: "f(0,true)", equality: false }, + { expr1: "f(true,0)", expr2: "f(true,0)", equality: true }, + { expr1: "f(g(1))", expr2: "g(f(1))", equality: false }, + + { expr1: "s.f(1,4)", expr2: "s.f(1)", equality: false }, + { expr1: "s.f(1,4)", expr2: "s.f(1,4)", equality: true }, + { expr1: "s.f(1,4)", expr2: "1", equality: false }, + { expr1: "s.f(1,4)", expr2: "s.g(1,4)", equality: false }, + { expr1: "s.f(1,4)", expr2: "true", equality: false }, + { expr1: "s.f(1,4)", expr2: "null", equality: false }, + { expr1: 's.f("a",0)', expr2: 's.f("a",0)', equality: true }, + { expr1: 's.f("a",0)', expr2: 's.f("a",null)', equality: false }, + { expr1: "s.f(true,0)", expr2: "s.f(0,true)", equality: false }, + { expr1: "s.f(true,0)", expr2: "s.f(true,0)", equality: true }, + { expr1: "s.f(s.g(1))", expr2: "s.g(s.f(1))", equality: false }, + + { expr1: "s.f(0)", expr2: "f(0)", equality: false }, +]; + +const unaryOpExpressions: Test[] = [ + { expr1: "+4", expr2: "+4", equality: true }, + { expr1: "+4", expr2: "-4", equality: false }, + { expr1: "+4", expr2: "!g", equality: false }, + { expr1: "+4", expr2: "!!g", equality: false }, + { expr1: "+4", expr2: "g!!", equality: false }, + { expr1: "+4", expr2: "~g", equality: false }, + { expr1: "-4", expr2: "-4", equality: true }, + { expr1: "-4", expr2: "!g", equality: false }, + { expr1: "-4", expr2: "!!g", equality: false }, + { expr1: "-4", expr2: "g!!", equality: false }, + { expr1: "-4", expr2: "~g", equality: false }, + { expr1: "!g", expr2: "!g", equality: true }, + { expr1: "!g", expr2: "!!g", equality: false }, + { expr1: "!g", expr2: "g!!", equality: false }, + { expr1: "!g", expr2: "~g", equality: false }, + { expr1: "g!!", expr2: "g!!", equality: true }, + { expr1: "g!!", expr2: "~g", equality: false }, + { expr1: "~g", expr2: "~g", equality: true }, +]; + +const binaryOpExpressions: Test[] = [ + { expr1: "g + r", expr2: "g + r", equality: true }, + { expr1: "g + r", expr2: "r + g", equality: false }, + { expr1: "g + r", expr2: "+r", equality: false }, + { expr1: "g + r", expr2: "g - r", equality: false }, + { expr1: "g + r", expr2: "g * r", equality: false }, + { expr1: "g + r", expr2: "g / r", equality: false }, + { expr1: "g + r", expr2: "g % r", equality: false }, + { expr1: "g + r", expr2: "g >> r", equality: false }, + { expr1: "g + r", expr2: "g << r", equality: false }, + { expr1: "g + r", expr2: "g & r", equality: false }, + { expr1: "g + r", expr2: "g | r", equality: false }, + { expr1: "g + r", expr2: "g ^ r", equality: false }, + { expr1: "g + r", expr2: "g != r", equality: false }, + { expr1: "g + r", expr2: "g > r", equality: false }, + { expr1: "g + r", expr2: "g < r", equality: false }, + { expr1: "g + r", expr2: "g >= r", equality: false }, + { expr1: "g + r", expr2: "g <= r", equality: false }, + { expr1: "g + r", expr2: "g == r", equality: false }, + { expr1: "g + r", expr2: "g && r", equality: false }, + { expr1: "g + r", expr2: "g || r", equality: false }, + { expr1: "g - r", expr2: "g - r", equality: true }, + { expr1: "g - r", expr2: "-r", equality: false }, + { expr1: "g - r", expr2: "r - g", equality: false }, + { expr1: "g - r", expr2: "g * r", equality: false }, + { expr1: "g - r", expr2: "g / r", equality: false }, + { expr1: "g - r", expr2: "g % r", equality: false }, + { expr1: "g - r", expr2: "g >> r", equality: false }, + { expr1: "g - r", expr2: "g << r", equality: false }, + { expr1: "g - r", expr2: "g & r", equality: false }, + { expr1: "g - r", expr2: "g | r", equality: false }, + { expr1: "g - r", expr2: "g ^ r", equality: false }, + { expr1: "g - r", expr2: "g != r", equality: false }, + { expr1: "g - r", expr2: "g > r", equality: false }, + { expr1: "g - r", expr2: "g < r", equality: false }, + { expr1: "g - r", expr2: "g >= r", equality: false }, + { expr1: "g - r", expr2: "g <= r", equality: false }, + { expr1: "g - r", expr2: "g == r", equality: false }, + { expr1: "g - r", expr2: "g && r", equality: false }, + { expr1: "g - r", expr2: "g || r", equality: false }, + { expr1: "g * r", expr2: "g * r", equality: true }, + { expr1: "g * r", expr2: "r * g", equality: false }, + { expr1: "g * r", expr2: "g / r", equality: false }, + { expr1: "g * r", expr2: "g % r", equality: false }, + { expr1: "g * r", expr2: "g >> r", equality: false }, + { expr1: "g * r", expr2: "g << r", equality: false }, + { expr1: "g * r", expr2: "g & r", equality: false }, + { expr1: "g * r", expr2: "g | r", equality: false }, + { expr1: "g * r", expr2: "g ^ r", equality: false }, + { expr1: "g * r", expr2: "g != r", equality: false }, + { expr1: "g * r", expr2: "g > r", equality: false }, + { expr1: "g * r", expr2: "g < r", equality: false }, + { expr1: "g * r", expr2: "g >= r", equality: false }, + { expr1: "g * r", expr2: "g <= r", equality: false }, + { expr1: "g * r", expr2: "g == r", equality: false }, + { expr1: "g * r", expr2: "g && r", equality: false }, + { expr1: "g * r", expr2: "g || r", equality: false }, + { expr1: "g / r", expr2: "g / r", equality: true }, + { expr1: "g / r", expr2: "r / g", equality: false }, + { expr1: "g / r", expr2: "g % r", equality: false }, + { expr1: "g / r", expr2: "g >> r", equality: false }, + { expr1: "g / r", expr2: "g << r", equality: false }, + { expr1: "g / r", expr2: "g & r", equality: false }, + { expr1: "g / r", expr2: "g | r", equality: false }, + { expr1: "g / r", expr2: "g ^ r", equality: false }, + { expr1: "g / r", expr2: "g != r", equality: false }, + { expr1: "g / r", expr2: "g > r", equality: false }, + { expr1: "g / r", expr2: "g < r", equality: false }, + { expr1: "g / r", expr2: "g >= r", equality: false }, + { expr1: "g / r", expr2: "g <= r", equality: false }, + { expr1: "g / r", expr2: "g == r", equality: false }, + { expr1: "g / r", expr2: "g && r", equality: false }, + { expr1: "g / r", expr2: "g || r", equality: false }, + { expr1: "g % r", expr2: "g % r", equality: true }, + { expr1: "g % r", expr2: "r % g", equality: false }, + { expr1: "g % r", expr2: "g >> r", equality: false }, + { expr1: "g % r", expr2: "g << r", equality: false }, + { expr1: "g % r", expr2: "g & r", equality: false }, + { expr1: "g % r", expr2: "g | r", equality: false }, + { expr1: "g % r", expr2: "g ^ r", equality: false }, + { expr1: "g % r", expr2: "g != r", equality: false }, + { expr1: "g % r", expr2: "g > r", equality: false }, + { expr1: "g % r", expr2: "g < r", equality: false }, + { expr1: "g % r", expr2: "g >= r", equality: false }, + { expr1: "g % r", expr2: "g <= r", equality: false }, + { expr1: "g % r", expr2: "g == r", equality: false }, + { expr1: "g % r", expr2: "g && r", equality: false }, + { expr1: "g % r", expr2: "g || r", equality: false }, + { expr1: "g >> r", expr2: "g >> r", equality: true }, + { expr1: "g >> r", expr2: "r >> g", equality: false }, + { expr1: "g >> r", expr2: "g << r", equality: false }, + { expr1: "g >> r", expr2: "g & r", equality: false }, + { expr1: "g >> r", expr2: "g | r", equality: false }, + { expr1: "g >> r", expr2: "g ^ r", equality: false }, + { expr1: "g >> r", expr2: "g != r", equality: false }, + { expr1: "g >> r", expr2: "g > r", equality: false }, + { expr1: "g >> r", expr2: "g < r", equality: false }, + { expr1: "g >> r", expr2: "g >= r", equality: false }, + { expr1: "g >> r", expr2: "g <= r", equality: false }, + { expr1: "g >> r", expr2: "g == r", equality: false }, + { expr1: "g >> r", expr2: "g && r", equality: false }, + { expr1: "g >> r", expr2: "g || r", equality: false }, + { expr1: "g << r", expr2: "g << r", equality: true }, + { expr1: "g << r", expr2: "r << g", equality: false }, + { expr1: "g << r", expr2: "g & r", equality: false }, + { expr1: "g << r", expr2: "g | r", equality: false }, + { expr1: "g << r", expr2: "g ^ r", equality: false }, + { expr1: "g << r", expr2: "g != r", equality: false }, + { expr1: "g << r", expr2: "g > r", equality: false }, + { expr1: "g << r", expr2: "g < r", equality: false }, + { expr1: "g << r", expr2: "g >= r", equality: false }, + { expr1: "g << r", expr2: "g <= r", equality: false }, + { expr1: "g << r", expr2: "g == r", equality: false }, + { expr1: "g << r", expr2: "g && r", equality: false }, + { expr1: "g << r", expr2: "g || r", equality: false }, + { expr1: "g & r", expr2: "g & r", equality: true }, + { expr1: "g & r", expr2: "r & g", equality: false }, + { expr1: "g & r", expr2: "g | r", equality: false }, + { expr1: "g & r", expr2: "g ^ r", equality: false }, + { expr1: "g & r", expr2: "g != r", equality: false }, + { expr1: "g & r", expr2: "g > r", equality: false }, + { expr1: "g & r", expr2: "g < r", equality: false }, + { expr1: "g & r", expr2: "g >= r", equality: false }, + { expr1: "g & r", expr2: "g <= r", equality: false }, + { expr1: "g & r", expr2: "g == r", equality: false }, + { expr1: "g & r", expr2: "g && r", equality: false }, + { expr1: "g & r", expr2: "g || r", equality: false }, + { expr1: "g | r", expr2: "g | r", equality: true }, + { expr1: "g | r", expr2: "r | g", equality: false }, + { expr1: "g | r", expr2: "g ^ r", equality: false }, + { expr1: "g | r", expr2: "g != r", equality: false }, + { expr1: "g | r", expr2: "g > r", equality: false }, + { expr1: "g | r", expr2: "g < r", equality: false }, + { expr1: "g | r", expr2: "g >= r", equality: false }, + { expr1: "g | r", expr2: "g <= r", equality: false }, + { expr1: "g | r", expr2: "g == r", equality: false }, + { expr1: "g | r", expr2: "g && r", equality: false }, + { expr1: "g | r", expr2: "g || r", equality: false }, + { expr1: "g ^ r", expr2: "g ^ r", equality: true }, + { expr1: "g ^ r", expr2: "r ^ g", equality: false }, + { expr1: "g ^ r", expr2: "g != r", equality: false }, + { expr1: "g ^ r", expr2: "g > r", equality: false }, + { expr1: "g ^ r", expr2: "g < r", equality: false }, + { expr1: "g ^ r", expr2: "g >= r", equality: false }, + { expr1: "g ^ r", expr2: "g <= r", equality: false }, + { expr1: "g ^ r", expr2: "g == r", equality: false }, + { expr1: "g ^ r", expr2: "g && r", equality: false }, + { expr1: "g ^ r", expr2: "g || r", equality: false }, + { expr1: "g != r", expr2: "g != r", equality: true }, + { expr1: "g != r", expr2: "r != g", equality: false }, + { expr1: "g != r", expr2: "g > r", equality: false }, + { expr1: "g != r", expr2: "g < r", equality: false }, + { expr1: "g != r", expr2: "g >= r", equality: false }, + { expr1: "g != r", expr2: "g <= r", equality: false }, + { expr1: "g != r", expr2: "g == r", equality: false }, + { expr1: "g != r", expr2: "g && r", equality: false }, + { expr1: "g != r", expr2: "g || r", equality: false }, + { expr1: "g > r", expr2: "g > r", equality: true }, + { expr1: "g > r", expr2: "r > g", equality: false }, + { expr1: "g > r", expr2: "g < r", equality: false }, + { expr1: "g > r", expr2: "g >= r", equality: false }, + { expr1: "g > r", expr2: "g <= r", equality: false }, + { expr1: "g > r", expr2: "g == r", equality: false }, + { expr1: "g > r", expr2: "g && r", equality: false }, + { expr1: "g > r", expr2: "g || r", equality: false }, + { expr1: "g < r", expr2: "g < r", equality: true }, + { expr1: "g < r", expr2: "r < g", equality: false }, + { expr1: "g < r", expr2: "g >= r", equality: false }, + { expr1: "g < r", expr2: "g <= r", equality: false }, + { expr1: "g < r", expr2: "g == r", equality: false }, + { expr1: "g < r", expr2: "g && r", equality: false }, + { expr1: "g < r", expr2: "g || r", equality: false }, + { expr1: "g >= r", expr2: "g >= r", equality: true }, + { expr1: "g >= r", expr2: "r >= g", equality: false }, + { expr1: "g >= r", expr2: "g <= r", equality: false }, + { expr1: "g >= r", expr2: "g == r", equality: false }, + { expr1: "g >= r", expr2: "g && r", equality: false }, + { expr1: "g >= r", expr2: "g || r", equality: false }, + { expr1: "g <= r", expr2: "g <= r", equality: true }, + { expr1: "g <= r", expr2: "r <= g", equality: false }, + { expr1: "g <= r", expr2: "g == r", equality: false }, + { expr1: "g <= r", expr2: "g && r", equality: false }, + { expr1: "g <= r", expr2: "g || r", equality: false }, + { expr1: "g == r", expr2: "g == r", equality: true }, + { expr1: "g == r", expr2: "r == g", equality: false }, + { expr1: "g == r", expr2: "g && r", equality: false }, + { expr1: "g == r", expr2: "g || r", equality: false }, + { expr1: "g && r", expr2: "g && r", equality: true }, + { expr1: "g && r", expr2: "r && g", equality: false }, + { expr1: "g && r", expr2: "g || r", equality: false }, + { expr1: "g || r", expr2: "g || r", equality: true }, + { expr1: "g || r", expr2: "r || g", equality: false }, +]; + +const conditionalExpressions: Test[] = [ + { expr1: "g ? a : b", expr2: "g ? a : b", equality: true }, + { expr1: "g ? a : b", expr2: "g ? b : a", equality: false }, + { expr1: "g ? a : b", expr2: "b ? g : a", equality: false }, + { expr1: "g ? a : b", expr2: "b ? a : g", equality: false }, + { expr1: "g ? a : b", expr2: "a ? b : g", equality: false }, + { expr1: "g ? a : b", expr2: "a ? g : b", equality: false }, + { expr1: "g ? a : b", expr2: "g", equality: false }, + { expr1: "g ? a : b", expr2: "b", equality: false }, + { expr1: "g ? a : b", expr2: "a", equality: false }, +]; + +const structExpressions: Test[] = [ + { + expr1: "Test {f1: a, f2: b}", + expr2: "Test {f1: a, f2: b}", + equality: true, + }, + { + expr1: "Test {f1: a, f2: b}", + expr2: "Test2 {f1: a, f2: b}", + equality: false, + }, + { + expr1: "Test {f1: a, f2: b}", + expr2: "Test {f3: a, f2: b}", + equality: false, + }, + { + expr1: "Test {f1: a, f2: b}", + expr2: "Test {f1: a, f3: b}", + equality: false, + }, + { + expr1: "Test {f1: a, f2: b}", + expr2: "Test {f1: c, f2: b}", + equality: false, + }, + { + expr1: "Test {f1: a, f2: b}", + expr2: "Test {f1: a, f2: c}", + equality: false, + }, + { expr1: "Test {f1: a, f2: b}", expr2: "Test {f1: a}", equality: false }, + { + expr1: "Test {f1: a, f2: b}", + expr2: "Test {f1: a, f2: b, f3: c}", + equality: false, + }, + { expr1: "Test {f1: a, f2: b}", expr2: "Test", equality: false }, + { expr1: "Test {f1: a, f2: b}", expr2: "f1", equality: false }, + { expr1: "Test {f1: a, f2: b}", expr2: "f2", equality: false }, + { expr1: "Test {f1: a, f2: b}", expr2: "a", equality: false }, + { expr1: "Test {f1: a, f2: b}", expr2: "b", equality: false }, +]; + +const fieldAccessExpressions: Test[] = [ + { expr1: "s.a", expr2: "s.a", equality: true }, + { expr1: "s.a", expr2: "s.a(0)", equality: false }, + { expr1: "s.a", expr2: "a(0)", equality: false }, + { expr1: "s.a", expr2: "a", equality: false }, + { expr1: "s.a", expr2: "s.a.a", equality: false }, + { expr1: "s.a", expr2: "Test {a: e1, b: e2}.a", equality: false }, + { expr1: "s.a.a", expr2: "s.a.a", equality: true }, + { expr1: "s.a.a", expr2: "s.a.a(0)", equality: false }, + { expr1: "s.a.a", expr2: "s.a(0)", equality: false }, + { expr1: "s.a.a", expr2: "a(0)", equality: false }, + { expr1: "s.a.a", expr2: "a", equality: false }, + { expr1: "s.a.a", expr2: "Test {a: e1, b: e2}.a", equality: false }, + { + expr1: "Test {a: e1, b: e2}.a", + expr2: "Test {a: e1, b: e2}.a", + equality: true, + }, + { expr1: "Test {a: e1, b: e2}.a", expr2: "a", equality: false }, + { expr1: "Test {a: e1, b: e2}.a", expr2: "s.a", equality: false }, + { expr1: "Test {a: e1, b: e2}.a", expr2: "s.a(0)", equality: false }, + { expr1: "Test {a: e1, b: e2}.a", expr2: "a(0)", equality: false }, + { expr1: "Test {a: e1, b: e2}.a", expr2: "s.a.a", equality: false }, + { + expr1: "Test {a: e1, b: e2}.a", + expr2: "Test {a: e1, b: e2}.b", + equality: false, + }, +]; + +const initOfExpressions: Test[] = [ + { expr1: "initOf a(b,c,d)", expr2: "initOf a(b,c,d)", equality: true }, + { expr1: "initOf a(b,c,d)", expr2: "initOf g(b,c,d)", equality: false }, + { expr1: "initOf a(b,c,d)", expr2: "initOf a(f,c,d)", equality: false }, + { expr1: "initOf a(b,c,d)", expr2: "initOf a(b,f,d)", equality: false }, + { expr1: "initOf a(b,c,d)", expr2: "initOf a(b,c,f)", equality: false }, + { expr1: "initOf a(b,c,d)", expr2: "initOf a(b)", equality: false }, + { expr1: "initOf a(b,c,d)", expr2: "initOf a(b,c)", equality: false }, + { expr1: "initOf a(b,c,d)", expr2: "initOf a(b,c,d,e)", equality: false }, + { expr1: "initOf a(b,c,d)", expr2: "a(b,c,d)", equality: false }, + { expr1: "initOf a(b,c,d)", expr2: "s.a(b,c,d)", equality: false }, +]; + +function testEquality(expr1: string, expr2: string, equal: boolean) { + expect(eqExpressions(parseExpression(expr1), parseExpression(expr2))).toBe( + equal, + ); +} + +describe("expression-equality", () => { + beforeEach(() => { + __DANGER_resetNodeId(); + }); + it("should correctly determine if two expressions involving values are equal or not.", () => { + valueExpressions.forEach((test) => { + testEquality(test.expr1, test.expr2, test.equality); + }); + }); + it("should correctly determine if two expressions involving function calls are equal or not.", () => { + functionCallExpressions.forEach((test) => { + testEquality(test.expr1, test.expr2, test.equality); + }); + }); + it("should correctly determine if two expressions involving unary operators are equal or not.", () => { + unaryOpExpressions.forEach((test) => { + testEquality(test.expr1, test.expr2, test.equality); + }); + }); + it("should correctly determine if two expressions involving binary operators are equal or not.", () => { + binaryOpExpressions.forEach((test) => { + testEquality(test.expr1, test.expr2, test.equality); + }); + }); + it("should correctly determine if two expressions involving conditionals are equal or not.", () => { + conditionalExpressions.forEach((test) => { + testEquality(test.expr1, test.expr2, test.equality); + }); + }); + it("should correctly determine if two expressions involving structs are equal or not.", () => { + structExpressions.forEach((test) => { + testEquality(test.expr1, test.expr2, test.equality); + }); + }); + it("should correctly determine if two expressions involving field accesses are equal or not.", () => { + fieldAccessExpressions.forEach((test) => { + testEquality(test.expr1, test.expr2, test.equality); + }); + }); + it("should correctly determine if two expressions involving initOf are equal or not.", () => { + initOfExpressions.forEach((test) => { + testEquality(test.expr1, test.expr2, test.equality); + }); + }); +}); diff --git a/src/grammar/test/expr-is-value.spec.ts b/src/grammar/test/expr-is-value.spec.ts new file mode 100644 index 000000000..f26a23535 --- /dev/null +++ b/src/grammar/test/expr-is-value.spec.ts @@ -0,0 +1,72 @@ +//type Test = { expr: string; isValue: boolean }; + +import { __DANGER_resetNodeId, isValue } from "../ast"; +import { parseExpression } from "../grammar"; + +const valueExpressions: string[] = [ + "1", + "true", + "false", + '"one"', + "null", + "Test {f1: 0, f2: true}", + "Test {f1: 0, f2: true, f3: null}", + "Test {f1: Test2 {c:0}, f2: true}", +]; + +const notValueExpressions: string[] = [ + "g", + "Test {f1: 0, f2: b}", + "Test {f1: a, f2: true}", + "f(1)", + "f(1,4)", + "s.f(1,4)", + "+4", + "-4", + "!true", + "g!!", + "~6", + "0 + 1", + "0 - 1", + "0 * 2", + "1 / 3", + "2 % 4", + "10 >> 2", + "10 << 2", + "10 & 4", + "10 | 4", + "10 ^ 4", + "10 != 4", + "10 > 3", + "10 < 3", + "10 >= 5", + "10 <= 2", + "10 == 7", + "true && false", + "true || false", + "true ? 0 : 1", + "s.a", + "s.a.a", + "Test {a: 0, b: 1}.a", + "initOf a(0,1,null)", +]; + +function testIsValue(expr: string, testResult: boolean) { + expect(isValue(parseExpression(expr))).toBe(testResult); +} + +describe("expression-is-value", () => { + beforeEach(() => { + __DANGER_resetNodeId(); + }); + valueExpressions.forEach((test) => { + it(`should correctly determine that '${test}' is a value expression.`, () => { + testIsValue(test, true); + }); + }); + notValueExpressions.forEach((test) => { + it(`should correctly determine that '${test}' is NOT a value expression.`, () => { + testIsValue(test, false); + }); + }); +}); diff --git a/src/grammar/test/partial-eval.spec.ts b/src/grammar/test/partial-eval.spec.ts new file mode 100644 index 000000000..891ae1566 --- /dev/null +++ b/src/grammar/test/partial-eval.spec.ts @@ -0,0 +1,206 @@ +import { + AstExpression, + AstValue, + __DANGER_resetNodeId, + cloneAstNode, + eqExpressions, + isValue, +} from "../ast"; +import { parseExpression } from "../grammar"; +import { extractValue, makeValueExpression } from "../../optimizer/util"; +import { evalUnaryOp, partiallyEvalExpression } from "../../constEval"; +import { CompilerContext } from "../../context"; + +const additiveExpressions = [ + { original: "X + 3 + 1", simplified: "X + 4" }, + { original: "3 + X + 1", simplified: "X + 4" }, + { original: "1 + (X + 3)", simplified: "X + 4" }, + { original: "1 + (3 + X)", simplified: "4 + X" }, + + // Should NOT simplify to X + 2, because X could be MAX - 2, + // so that X + 3 causes an overflow, but X + 2 does not overflow + { original: "X + 3 - 1", simplified: "X + 3 - 1" }, + { original: "3 + X - 1", simplified: "3 + X - 1" }, + + // Should NOT simplify to X - 2, because X could be MIN + 2 + { original: "1 + (X - 3)", simplified: "1 + (X - 3)" }, + + { original: "1 + (3 - X)", simplified: "4 - X" }, + + { original: "X + 3 - (-1)", simplified: "X + 4" }, + { original: "3 + X - (-1)", simplified: "X + 4" }, + + // Should NOT simplify, because the current rules require that - commutes, + // which does not. This could be fixed in future rules. + { original: "-1 + (X - 3)", simplified: "-1 + (X - 3)" }, + + // Should NOT simplify to 2 - X, because X could be MIN + 3, + // so that 3 - X = -MIN = MAX + 1 causes an overflow, + // but 2 - X = -MIN - 1 = MAX does not + { original: "-1 + (3 - X)", simplified: "-1 + (3 - X)" }, + + // All the following cases should NOT simplify because - + // does not associate on the left with - or +. + // The following "associative rule" for - will be added in the future: + // (x - c1) op c2 -----> x + (-c1 op c2), where op \in {-,+} + { original: "1 - (X + 3)", simplified: "1 - (X + 3)" }, + { original: "1 - (3 + X)", simplified: "1 - (3 + X)" }, + { original: "1 - X + 3", simplified: "1 - X + 3" }, + { original: "X - 1 + 3", simplified: "X - 1 + 3" }, + { original: "1 - (X - 3)", simplified: "1 - (X - 3)" }, + { original: "1 - (3 - X)", simplified: "1 - (3 - X)" }, + { original: "1 - X - 3", simplified: "1 - X - 3" }, + { original: "X - 1 - 3", simplified: "X - 1 - 3" }, +]; + +const multiplicativeExpressions = [ + { original: "X * 3 * 2", simplified: "X * 6" }, + { original: "3 * X * 2", simplified: "X * 6" }, + { original: "2 * (X * 3)", simplified: "X * 6" }, + { original: "2 * (3 * X)", simplified: "6 * X" }, + + { original: "X * -3 * -2", simplified: "X * 6" }, + { original: "-3 * X * -2", simplified: "X * 6" }, + { original: "-2 * (X * -3)", simplified: "X * 6" }, + { original: "-2 * (-3 * X)", simplified: "6 * X" }, + + // The following 4 cases should NOT simplify to X * 0. + // the reason is that X could be MAX, so that X*3 causes + // an overflow, but X*0 does not. + { original: "X * 3 * 0", simplified: "X * 3 * 0" }, + { original: "3 * X * 0", simplified: "3 * X * 0" }, + { original: "0 * (X * 3)", simplified: "0 * (X * 3)" }, + { original: "0 * (3 * X)", simplified: "0 * (3 * X)" }, + + { original: "X * 0 * 3", simplified: "X * 0" }, + { original: "0 * X * 3", simplified: "X * 0" }, + { original: "3 * (X * 0)", simplified: "X * 0" }, + { original: "3 * (0 * X)", simplified: "0 * X" }, + + // This expression cannot be further simplified to X, + // because X could be MIN, so that X * -1 causes an overflow + { original: "X * -1 * 1 * -1", simplified: "X * -1 * -1" }, + + // This expression could be further simplified to X * -1 + // but, currently, there are no rules that reduce three multiplied -1 + // to a single -1. This should be fixed in the future. + { original: "X * -1 * 1 * -1 * -1", simplified: "X * -1 * -1 * -1" }, + + // Even though, X * -1 * 1 * -1 cannot be simplified to X, + // when we multiply with a number with absolute value bigger than 1, + // we ensure that the overflows are preserved, so that we can simplify + // the expression. + { original: "X * -1 * 1 * -1 * 2", simplified: "X * 2" }, + + // Should NOT simplify to X * 2, because X could be MIN/2 = -2^255, + // so that X * -2 = 2^256 = MAX + 1 causes an overflow, + // but X * 2 = -2^256 does not. + { original: "X * -2 * -1", simplified: "X * -2 * -1" }, + + // Note however that multiplying first by -1 allow us + // to simplify the expression, because if X * -1 overflows, + // X * 2 will also. + { original: "X * -1 * -2", simplified: "X * 2" }, +]; + +function testExpression(original: string, simplified: string) { + expect( + eqExpressions( + partiallyEvalExpression( + parseExpression(original), + new CompilerContext(), + ), + unaryNegNodesToNumbers(parseExpression(simplified)), + ), + ).toBe(true); +} + +// Evaluates UnaryOp nodes with operator - into a single a node having a value. +// The reason for doing this is that the partial evaluator will transform negative +// numbers in an expression, e.g., "-1" into a tree with a single node with value -1, so that +// when comparing the tree with those produced by the parser, the two trees +// do not match, because the parser will produce a UnaryOp node with a child node with value 1. +// This is so because Tact does not have a way to write negative literals, but indirectly trough +// the use of the unary - operator. +function unaryNegNodesToNumbers(ast: AstExpression): AstExpression { + let newNode: AstExpression; + switch (ast.kind) { + case "null": + return ast; + case "boolean": + return ast; + case "number": + return ast; + case "string": + return ast; + case "id": + return ast; + case "method_call": + newNode = cloneAstNode(ast); + newNode.args = ast.args.map(unaryNegNodesToNumbers); + newNode.self = unaryNegNodesToNumbers(ast.self); + return newNode; + case "init_of": + newNode = cloneAstNode(ast); + newNode.args = ast.args.map(unaryNegNodesToNumbers); + return newNode; + case "op_unary": + if (ast.op === "-") { + if (isValue(ast.operand)) { + return makeValueExpression( + evalUnaryOp( + ast.op, + extractValue(ast.operand as AstValue), + ), + ); + } + } + newNode = cloneAstNode(ast); + newNode.operand = unaryNegNodesToNumbers(ast.operand); + return newNode; + case "op_binary": + newNode = cloneAstNode(ast); + newNode.left = unaryNegNodesToNumbers(ast.left); + newNode.right = unaryNegNodesToNumbers(ast.right); + return newNode; + case "conditional": + newNode = cloneAstNode(ast); + newNode.thenBranch = unaryNegNodesToNumbers(ast.thenBranch); + newNode.elseBranch = unaryNegNodesToNumbers(ast.elseBranch); + return newNode; + case "struct_instance": + newNode = cloneAstNode(ast); + newNode.args = ast.args.map((param) => { + const newParam = cloneAstNode(param); + newParam.initializer = unaryNegNodesToNumbers( + param.initializer, + ); + return newParam; + }); + return newNode; + case "field_access": + newNode = cloneAstNode(ast); + newNode.aggregate = unaryNegNodesToNumbers(ast.aggregate); + return newNode; + case "static_call": + newNode = cloneAstNode(ast); + newNode.args = ast.args.map(unaryNegNodesToNumbers); + return newNode; + } +} + +describe("partial-evaluator", () => { + beforeEach(() => { + __DANGER_resetNodeId(); + }); + it("should correctly simplify partial expressions involving + and -", () => { + additiveExpressions.forEach((pair) => { + testExpression(pair.original, pair.simplified); + }); + }); + it("should correctly simplify partial expressions involving *", () => { + multiplicativeExpressions.forEach((pair) => { + testExpression(pair.original, pair.simplified); + }); + }); +}); diff --git a/src/optimizer/associative.ts b/src/optimizer/associative.ts new file mode 100644 index 000000000..f3fa9c578 --- /dev/null +++ b/src/optimizer/associative.ts @@ -0,0 +1,762 @@ +// This module includes rules involving associative rewrites of expressions + +import { evalBinaryOp } from "../constEval"; +import { + AstBinaryOperation, + AstExpression, + AstOpBinary, + AstValue, + isValue, +} from "../grammar/ast"; +import { Value } from "../types/types"; +import { ExpressionTransformer, Rule } from "./types"; +import { + abs, + checkIsBinaryOpNode, + checkIsBinaryOp_NonValue_Value, + checkIsBinaryOp_Value_NonValue, + extractValue, + makeBinaryExpression, + makeValueExpression, + sign, +} from "./util"; + +abstract class AssociativeRewriteRule extends Rule { + // An entry (op, S) in the map means "operator op associates with all operators in set S", + // mathematically: all op2 \in S. (a op b) op2 c = a op (b op2 c) + private associativeOps: Map>; + + // This set contains all operators that commute. + // Mathematically: all op \in commutativeOps. a op b = b op a + private commutativeOps: Set; + + constructor() { + super(); + + // + associates with these on the right: + // i.e., all op \in plusAssoc. (a + b) op c = a + (b op c) + const additiveAssoc: Set = new Set(["+", "-"]); + + // - does not associate with any operator on the right + + // * associates with these on the right: + const multiplicativeAssoc: Set = new Set([ + "*", + "<<", + ]); + + // Division / does not associate with any on the right + + // Modulus % does not associate with any on the right + + // TODO: shifts, bitwise integer operators, boolean operators + + this.associativeOps = new Map([ + ["+", additiveAssoc], + ["*", multiplicativeAssoc], + ]); + + this.commutativeOps = new Set( + ["+", "*", "!=", "==", "&&", "||"], // TODO: bitwise integer operators + ); + } + + public areAssociative( + op1: AstBinaryOperation, + op2: AstBinaryOperation, + ): boolean { + if (this.associativeOps.has(op1)) { + const rightOperators = this.associativeOps.get(op1)!; + return rightOperators.has(op2); + } else { + return false; + } + } + + public isCommutative(op: AstBinaryOperation): boolean { + return this.commutativeOps.has(op); + } +} + +abstract class AllowableOpRule extends AssociativeRewriteRule { + private allowedOps: Set; + + constructor() { + super(); + + this.allowedOps = new Set( + // Recall that integer operators +,-,*,/,% are not safe with this rule, because + // there is a risk that they will not preserve overflows in the unknown operands. + ["&&", "||"], // TODO: check bitwise integer operators + ); + } + + public isAllowedOp(op: AstBinaryOperation): boolean { + return this.allowedOps.has(op); + } + + public areAllowedOps(op: AstBinaryOperation[]): boolean { + return op.reduce( + (prev, curr) => prev && this.allowedOps.has(curr), + true, + ); + } +} + +export class AssociativeRule1 extends AllowableOpRule { + public applyRule( + ast: AstExpression, + optimizer: ExpressionTransformer, + ): AstExpression { + if (checkIsBinaryOpNode(ast)) { + const topLevelNode = ast as AstOpBinary; + if ( + checkIsBinaryOp_NonValue_Value(topLevelNode.left) && + checkIsBinaryOp_NonValue_Value(topLevelNode.right) + ) { + // The tree has this form: + // (x1 op1 c1) op (x2 op2 c2) + const leftTree = topLevelNode.left as AstOpBinary; + const rightTree = topLevelNode.right as AstOpBinary; + + const x1 = leftTree.left; + const c1 = leftTree.right as AstValue; + const op1 = leftTree.op; + + const x2 = rightTree.left; + const c2 = rightTree.right as AstValue; + const op2 = rightTree.op; + + const op = topLevelNode.op; + + // Check that: + // the operators are allowed + // op1 and op associate + // op and op2 associate + // op commutes + if ( + this.areAllowedOps([op1, op, op2]) && + this.areAssociative(op1, op) && + this.areAssociative(op, op2) && + this.isCommutative(op) + ) { + // Agglutinate the constants and compute their final value + try { + // If an error occurs, we abandon the simplification + const val = evalBinaryOp( + op2, + extractValue(c1), + extractValue(c2), + ); + + // The final expression is + // (x1 op1 x2) op val + + // Because we are joining x1 and x2, + // there is further opportunity of simplification, + // So, we ask the evaluator to apply all the rules in the subtree. + const newLeft = optimizer.applyRules( + makeBinaryExpression(op1, x1, x2), + ); + const newRight = makeValueExpression(val); + return makeBinaryExpression(op, newLeft, newRight); + } catch (e) { + // Do nothing: will exit rule without modifying tree + } + } + } else if ( + checkIsBinaryOp_NonValue_Value(topLevelNode.left) && + checkIsBinaryOp_Value_NonValue(topLevelNode.right) + ) { + // The tree has this form: + // (x1 op1 c1) op (c2 op2 x2) + const leftTree = topLevelNode.left as AstOpBinary; + const rightTree = topLevelNode.right as AstOpBinary; + + const x1 = leftTree.left; + const c1 = leftTree.right as AstValue; + const op1 = leftTree.op; + + const x2 = rightTree.right; + const c2 = rightTree.left as AstValue; + const op2 = rightTree.op; + + const op = topLevelNode.op; + + // Check that: + // the operators are allowed + // op1 and op associate + // op and op2 associate + if ( + this.areAllowedOps([op1, op, op2]) && + this.areAssociative(op1, op) && + this.areAssociative(op, op2) + ) { + // Agglutinate the constants and compute their final value + try { + // If an error occurs, we abandon the simplification + const val = evalBinaryOp( + op, + extractValue(c1), + extractValue(c2), + ); + + // The current expression could be either + // x1 op1 (val op2 x2) or + // (x1 op1 val) op2 x2 <--- we choose this form. + // Other rules will attempt to extract the constant outside the expression. + + // Because we are joining x1 and val, + // there is further opportunity of simplification, + // So, we ask the evaluator to apply all the rules in the subtree. + const newValNode = makeValueExpression(val); + const newLeft = optimizer.applyRules( + makeBinaryExpression(op1, x1, newValNode), + ); + return makeBinaryExpression(op2, newLeft, x2); + } catch (e) { + // Do nothing: will exit rule without modifying tree + } + } + } else if ( + checkIsBinaryOp_Value_NonValue(topLevelNode.left) && + checkIsBinaryOp_NonValue_Value(topLevelNode.right) + ) { + // The tree has this form: + // (c1 op1 x1) op (x2 op2 c2) + const leftTree = topLevelNode.left as AstOpBinary; + const rightTree = topLevelNode.right as AstOpBinary; + + const x1 = leftTree.right; + const c1 = leftTree.left as AstValue; + const op1 = leftTree.op; + + const x2 = rightTree.left; + const c2 = rightTree.right as AstValue; + const op2 = rightTree.op; + + const op = topLevelNode.op; + + // Check that: + // the operators are allowed + // op and op1 associate + // op2 and op associate + // op commutes + if ( + this.areAllowedOps([op1, op, op2]) && + this.areAssociative(op, op1) && + this.areAssociative(op2, op) && + this.isCommutative(op) + ) { + // Agglutinate the constants and compute their final value + try { + // If an error occurs, we abandon the simplification + const val = evalBinaryOp( + op, + extractValue(c2), + extractValue(c1), + ); + + // The current expression could be either + // x2 op2 (val op1 x1) or + // (x2 op2 val) op1 x1 <--- we choose this form. + // Other rules will attempt to extract the constant outside the expression. + + // Because we are joining x2 and val, + // there is further opportunity of simplification, + // So, we ask the evaluator to apply all the rules in the subtree. + const newValNode = makeValueExpression(val); + const newLeft = optimizer.applyRules( + makeBinaryExpression(op2, x2, newValNode), + ); + return makeBinaryExpression(op1, newLeft, x1); + } catch (e) { + // Do nothing: will exit rule without modifying tree + } + } + } else if ( + checkIsBinaryOp_Value_NonValue(topLevelNode.left) && + checkIsBinaryOp_Value_NonValue(topLevelNode.right) + ) { + // The tree has this form: + // (c1 op1 x1) op (c2 op2 x2) + const leftTree = topLevelNode.left as AstOpBinary; + const rightTree = topLevelNode.right as AstOpBinary; + + const x1 = leftTree.right; + const c1 = leftTree.left as AstValue; + const op1 = leftTree.op; + + const x2 = rightTree.right; + const c2 = rightTree.left as AstValue; + const op2 = rightTree.op; + + const op = topLevelNode.op; + + // Check that: + // the operators are allowed + // op1 and op associate + // op and op2 associate + // op commutes + if ( + this.areAllowedOps([op1, op, op2]) && + this.areAssociative(op1, op) && + this.areAssociative(op, op2) && + this.isCommutative(op) + ) { + // Agglutinate the constants and compute their final value + try { + // If an error occurs, we abandon the simplification + const val = evalBinaryOp( + op1, + extractValue(c1), + extractValue(c2), + ); + + // The final expression is + // val op (x1 op2 x2) + + // Because we are joining x1 and x2, + // there is further opportunity of simplification, + // So, we ask the evaluator to apply all the rules in the subtree. + const newRight = optimizer.applyRules( + makeBinaryExpression(op2, x1, x2), + ); + const newLeft = makeValueExpression(val); + return makeBinaryExpression(op, newLeft, newRight); + } catch (e) { + // Do nothing: will exit rule without modifying tree + } + } + } + } + + // If execution reaches here, it means that the rule could not be applied fully + // so, we return the original tree + return ast; + } +} + +export class AssociativeRule2 extends AllowableOpRule { + public applyRule( + ast: AstExpression, + optimizer: ExpressionTransformer, + ): AstExpression { + if (checkIsBinaryOpNode(ast)) { + const topLevelNode = ast as AstOpBinary; + if ( + checkIsBinaryOp_NonValue_Value(topLevelNode.left) && + !isValue(topLevelNode.right) + ) { + // The tree has this form: + // (x1 op1 c1) op x2 + const leftTree = topLevelNode.left as AstOpBinary; + const rightTree = topLevelNode.right; + + const x1 = leftTree.left; + const c1 = leftTree.right as AstValue; + const op1 = leftTree.op; + + const x2 = rightTree; + + const op = topLevelNode.op; + + // Check that: + // the operators are allowed + // op1 and op associate + // op commutes + if ( + this.areAllowedOps([op1, op]) && + this.areAssociative(op1, op) && + this.isCommutative(op) + ) { + // The final expression is + // (x1 op1 x2) op c1 + + // Because we are joining x1 and x2, + // there is further opportunity of simplification, + // So, we ask the evaluator to apply all the rules in the subtree. + const newLeft = optimizer.applyRules( + makeBinaryExpression(op1, x1, x2), + ); + return makeBinaryExpression(op, newLeft, c1); + } + } else if ( + checkIsBinaryOp_Value_NonValue(topLevelNode.left) && + !isValue(topLevelNode.right) + ) { + // The tree has this form: + // (c1 op1 x1) op x2 + const leftTree = topLevelNode.left as AstOpBinary; + const rightTree = topLevelNode.right; + + const x1 = leftTree.right; + const c1 = leftTree.left as AstValue; + const op1 = leftTree.op; + + const x2 = rightTree; + + const op = topLevelNode.op; + + // Check that: + // the operators are allowed + // op1 and op associate + if ( + this.areAllowedOps([op1, op]) && + this.areAssociative(op1, op) + ) { + // The final expression is + // c1 op1 (x1 op x2) + + // Because we are joining x1 and x2, + // there is further opportunity of simplification, + // So, we ask the evaluator to apply all the rules in the subtree. + const newRight = optimizer.applyRules( + makeBinaryExpression(op, x1, x2), + ); + return makeBinaryExpression(op1, c1, newRight); + } + } else if ( + !isValue(topLevelNode.left) && + checkIsBinaryOp_NonValue_Value(topLevelNode.right) + ) { + // The tree has this form: + // x2 op (x1 op1 c1) + const leftTree = topLevelNode.left; + const rightTree = topLevelNode.right as AstOpBinary; + + const x1 = rightTree.left; + const c1 = rightTree.right as AstValue; + const op1 = rightTree.op; + + const x2 = leftTree; + + const op = topLevelNode.op; + + // Check that: + // the operators are allowed + // op and op1 associate + if ( + this.areAllowedOps([op, op1]) && + this.areAssociative(op, op1) + ) { + // The final expression is + // (x2 op x1) op1 c1 + + // Because we are joining x1 and x2, + // there is further opportunity of simplification, + // So, we ask the evaluator to apply all the rules in the subtree. + const newLeft = optimizer.applyRules( + makeBinaryExpression(op, x2, x1), + ); + return makeBinaryExpression(op1, newLeft, c1); + } + } else if ( + !isValue(topLevelNode.left) && + checkIsBinaryOp_Value_NonValue(topLevelNode.right) + ) { + // The tree has this form: + // x2 op (c1 op1 x1) + const leftTree = topLevelNode.left; + const rightTree = topLevelNode.right as AstOpBinary; + + const x1 = rightTree.right; + const c1 = rightTree.left as AstValue; + const op1 = rightTree.op; + + const x2 = leftTree; + + const op = topLevelNode.op; + + // Check that: + // the operators are allowed + // op and op1 associate + // op is commutative + if ( + this.areAllowedOps([op, op1]) && + this.areAssociative(op, op1) && + this.isCommutative(op) + ) { + // The final expression is + // c1 op (x2 op1 x1) + + // Because we are joining x1 and x2, + // there is further opportunity of simplification, + // So, we ask the evaluator to apply all the rules in the subtree. + const newRight = optimizer.applyRules( + makeBinaryExpression(op1, x2, x1), + ); + return makeBinaryExpression(op, c1, newRight); + } + } + } + + // If execution reaches here, it means that the rule could not be applied fully + // so, we return the original tree + return ast; + } +} + +function ensureInt(val: Value): bigint { + if (typeof val !== "bigint") { + throw new Error(`integer expected`); + } + return val; +} + +export class AssociativeRule3 extends AssociativeRewriteRule { + private extraOpCondition: Map< + AstBinaryOperation, + (c1: Value, c2: Value, val: Value) => boolean + >; + + public constructor() { + super(); + + this.extraOpCondition = new Map([ + [ + "+", + (c1, c2, val) => { + const n1 = ensureInt(c1); + const res = ensureInt(val); + return sign(n1) === sign(res) && abs(n1) <= abs(res); + }, + ], + + [ + "-", + (c1, c2, val) => { + const n1 = ensureInt(c1); + const res = ensureInt(val); + return sign(n1) === sign(res) && abs(n1) <= abs(res); + }, + ], + + [ + "*", + (c1, c2, val) => { + const n1 = ensureInt(c1); + const res = ensureInt(val); + if (n1 < 0n) { + if (sign(n1) === sign(res)) { + return abs(n1) <= abs(res); + } else { + return abs(n1) < abs(res); + } + } else if (n1 === 0n) { + return true; + } else { + return abs(n1) <= abs(res); + } + }, + ], + ]); + } + + protected opSatisfiesConditions( + op: AstBinaryOperation, + c1: Value, + c2: Value, + res: Value, + ): boolean { + if (this.extraOpCondition.has(op)) { + return this.extraOpCondition.get(op)!(c1, c2, res); + } else { + return false; + } + } + + public applyRule( + ast: AstExpression, + optimizer: ExpressionTransformer, + ): AstExpression { + if (checkIsBinaryOpNode(ast)) { + const topLevelNode = ast as AstOpBinary; + if ( + checkIsBinaryOp_NonValue_Value(topLevelNode.left) && + isValue(topLevelNode.right) + ) { + // The tree has this form: + // (x1 op1 c1) op c2 + const leftTree = topLevelNode.left as AstOpBinary; + const rightTree = topLevelNode.right as AstValue; + + const x1 = leftTree.left; + const c1 = extractValue(leftTree.right as AstValue); + const op1 = leftTree.op; + + const c2 = extractValue(rightTree); + + const op = topLevelNode.op; + + // Agglutinate the constants and compute their final value + try { + // If an error occurs, we abandon the simplification + const val = evalBinaryOp(op, c1, c2); + + // Check that: + // op1 and op associate + // the extra conditions on op1 + + if ( + this.areAssociative(op1, op) && + this.opSatisfiesConditions(op1, c1, c2, val) + ) { + // The final expression is + // x1 op1 val + + const newConstant = makeValueExpression(val); + // Since the tree is simpler now, there is further + // opportunity for simplification that was missed + // previously + return optimizer.applyRules( + makeBinaryExpression(op1, x1, newConstant), + ); + } + } catch (e) { + // Do nothing: will exit rule without modifying tree + } + } else if ( + checkIsBinaryOp_Value_NonValue(topLevelNode.left) && + isValue(topLevelNode.right) + ) { + // The tree has this form: + // (c1 op1 x1) op c2 + const leftTree = topLevelNode.left as AstOpBinary; + const rightTree = topLevelNode.right as AstValue; + + const x1 = leftTree.right; + const c1 = extractValue(leftTree.left as AstValue); + const op1 = leftTree.op; + + const c2 = extractValue(rightTree); + + const op = topLevelNode.op; + + // Agglutinate the constants and compute their final value + try { + // If an error occurs, we abandon the simplification + const val = evalBinaryOp(op, c1, c2); + + // Check that: + // op1 and op associate + // op1 commutes + // the extra conditions on op1 + + if ( + this.areAssociative(op1, op) && + this.isCommutative(op1) && + this.opSatisfiesConditions(op1, c1, c2, val) + ) { + // The final expression is + // x1 op1 val + + const newConstant = makeValueExpression(val); + // Since the tree is simpler now, there is further + // opportunity for simplification that was missed + // previously + return optimizer.applyRules( + makeBinaryExpression(op1, x1, newConstant), + ); + } + } catch (e) { + // Do nothing: will exit rule without modifying tree + } + } else if ( + isValue(topLevelNode.left) && + checkIsBinaryOp_NonValue_Value(topLevelNode.right) + ) { + // The tree has this form: + // c2 op (x1 op1 c1) + const leftTree = topLevelNode.left as AstValue; + const rightTree = topLevelNode.right as AstOpBinary; + + const x1 = rightTree.left; + const c1 = extractValue(rightTree.right as AstValue); + const op1 = rightTree.op; + + const c2 = extractValue(leftTree); + + const op = topLevelNode.op; + + // Agglutinate the constants and compute their final value + try { + // If an error occurs, we abandon the simplification + const val = evalBinaryOp(op, c2, c1); + + // Check that: + // op and op1 associate + // op1 commutes + // the extra conditions on op1 + + if ( + this.areAssociative(op, op1) && + this.isCommutative(op1) && + this.opSatisfiesConditions(op1, c1, c2, val) + ) { + // The final expression is + // x1 op1 val + + const newConstant = makeValueExpression(val); + // Since the tree is simpler now, there is further + // opportunity for simplification that was missed + // previously + return optimizer.applyRules( + makeBinaryExpression(op1, x1, newConstant), + ); + } + } catch (e) { + // Do nothing: will exit rule without modifying tree + } + } else if ( + isValue(topLevelNode.left) && + checkIsBinaryOp_Value_NonValue(topLevelNode.right) + ) { + // The tree has this form: + // c2 op (c1 op1 x1) + const leftTree = topLevelNode.left as AstValue; + const rightTree = topLevelNode.right as AstOpBinary; + + const x1 = rightTree.right; + const c1 = extractValue(rightTree.left as AstValue); + const op1 = rightTree.op; + + const c2 = extractValue(leftTree); + + const op = topLevelNode.op; + + // Agglutinate the constants and compute their final value + try { + // If an error occurs, we abandon the simplification + const val = evalBinaryOp(op, c2, c1); + + // Check that: + // op and op1 associate + // the extra conditions on op1 + + if ( + this.areAssociative(op, op1) && + this.opSatisfiesConditions(op1, c1, c2, val) + ) { + // The final expression is + // val op1 x1 + + const newConstant = makeValueExpression(val); + // Since the tree is simpler now, there is further + // opportunity for simplification that was missed + // previously + return optimizer.applyRules( + makeBinaryExpression(op1, newConstant, x1), + ); + } + } catch (e) { + // Do nothing: will exit rule without modifying tree + } + } + } + + // If execution reaches here, it means that the rule could not be applied fully + // so, we return the original tree + return ast; + } +} diff --git a/src/optimizer/standardOptimizer.ts b/src/optimizer/standardOptimizer.ts new file mode 100644 index 000000000..a623679a0 --- /dev/null +++ b/src/optimizer/standardOptimizer.ts @@ -0,0 +1,37 @@ +import { AstExpression } from "../grammar/ast"; +import { + AssociativeRule1, + AssociativeRule2, + AssociativeRule3, +} from "./associative"; +import { Rule, ExpressionTransformer } from "./types"; + +type PrioritizedRule = { priority: number; rule: Rule }; + +// This optimizer uses rules that preserve overflows in integer expressions. +export class StandardOptimizer extends ExpressionTransformer { + private rules: PrioritizedRule[]; + + constructor() { + super(); + + this.rules = [ + { priority: 0, rule: new AssociativeRule1() }, + { priority: 1, rule: new AssociativeRule2() }, + { priority: 2, rule: new AssociativeRule3() }, + // TODO: add simpler algebraic rules that will be added to algebraic.ts + ]; + + // Sort according to the priorities: smaller number means greater priority. + // So, the rules will be sorted increasingly according to their priority number. + this.rules.sort((r1, r2) => r1.priority - r2.priority); + } + + public applyRules(ast: AstExpression): AstExpression { + return this.rules.reduce( + (prev, prioritizedRule) => + prioritizedRule.rule.applyRule(prev, this), + ast, + ); + } +} diff --git a/src/optimizer/types.ts b/src/optimizer/types.ts new file mode 100644 index 000000000..bbd871efb --- /dev/null +++ b/src/optimizer/types.ts @@ -0,0 +1,12 @@ +import { AstExpression } from "../grammar/ast"; + +export abstract class ExpressionTransformer { + public abstract applyRules(ast: AstExpression): AstExpression; +} + +export abstract class Rule { + public abstract applyRule( + ast: AstExpression, + optimizer: ExpressionTransformer, + ): AstExpression; +} diff --git a/src/optimizer/util.ts b/src/optimizer/util.ts new file mode 100644 index 000000000..6b1e49dde --- /dev/null +++ b/src/optimizer/util.ts @@ -0,0 +1,141 @@ +import { + AstExpression, + AstUnaryOperation, + AstBinaryOperation, + createAstNode, + AstValue, + isValue, +} from "../grammar/ast"; +import { dummySrcInfo } from "../grammar/grammar"; +import { Value } from "../types/types"; + +export function extractValue(ast: AstValue): Value { + switch ( + ast.kind // Missing structs + ) { + case "null": + return null; + case "boolean": + return ast.value; + case "number": + return ast.value; + case "string": + return ast.value; + } +} + +export function makeValueExpression(value: Value): AstValue { + if (value === null) { + const result = createAstNode({ + kind: "null", + loc: dummySrcInfo, + }); + return result as AstValue; + } + if (typeof value === "string") { + const result = createAstNode({ + kind: "string", + value: value, + loc: dummySrcInfo, + }); + return result as AstValue; + } + if (typeof value === "bigint") { + const result = createAstNode({ + kind: "number", + value: value, + loc: dummySrcInfo, + }); + return result as AstValue; + } + if (typeof value === "boolean") { + const result = createAstNode({ + kind: "boolean", + value: value, + loc: dummySrcInfo, + }); + return result as AstValue; + } + throw new Error( + `structs, addresses, cells, and comment values are not supported at the moment.`, + ); +} + +export function makeUnaryExpression( + op: AstUnaryOperation, + operand: AstExpression, +): AstExpression { + const result = createAstNode({ + kind: "op_unary", + op: op, + operand: operand, + loc: dummySrcInfo, + }); + return result as AstExpression; +} + +export function makeBinaryExpression( + op: AstBinaryOperation, + left: AstExpression, + right: AstExpression, +): AstExpression { + const result = createAstNode({ + kind: "op_binary", + op: op, + left: left, + right: right, + loc: dummySrcInfo, + }); + return result as AstExpression; +} + +// Checks if the top level node is a binary op node +export function checkIsBinaryOpNode(ast: AstExpression): boolean { + return ast.kind === "op_binary"; +} + +// Checks if top level node is a binary op node +// with a non-value node on the left and +// value node on the right +export function checkIsBinaryOp_NonValue_Value(ast: AstExpression): boolean { + return ast.kind === "op_binary" + ? !isValue(ast.left) && isValue(ast.right) + : false; +} + +// Checks if top level node is a binary op node +// with a value node on the left and +// non-value node on the right +export function checkIsBinaryOp_Value_NonValue(ast: AstExpression): boolean { + return ast.kind === "op_binary" + ? isValue(ast.left) && !isValue(ast.right) + : false; +} + +// bigint arithmetic + +// precondition: the divisor is not zero +// rounds the division result towards negative infinity +export function divFloor(a: bigint, b: bigint): bigint { + const almostSameSign = a > 0n === b > 0n; + if (almostSameSign) { + return a / b; + } + return a / b + (a % b === 0n ? 0n : -1n); +} + +export function abs(a: bigint): bigint { + return a < 0n ? -a : a; +} + +export function sign(a: bigint): bigint { + if (a === 0n) return 0n; + else return a < 0n ? -1n : 1n; +} + +// precondition: the divisor is not zero +// rounds the result towards negative infinity +// Uses the fact that a / b * b + a % b == a, for all b != 0. +export function modFloor(a: bigint, b: bigint): bigint { + return a - divFloor(a, b) * b; +}