Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
This repository contains code for a language transpiler of smart contracts. It is therefore extremely important to write code that generalizes well and to *not* hardcode specific edge cases.

### Testing
- First write tests that reflect diversely how the language is supposed to work and how not.
- Check out other tests for language features (esp. test_misc.py) to see how tests are written. They define a program, compile it and then test its behavior against expectations, usually using property based tests. It is fine to start with a fixed example which gives faster feedback (PBT tend to time out when the behavior is violated, feel free to abort test generation)
- The entire test suite takes > 20 minutes to execute. Only run it at the end, prefer individual relevant tests before.

### Adding features
For language features
- Make sure to familiarize yourself with the way the visitor pattern is used for basically everything in the repo.
- There are three main steps for the compiler: normalization and optimization (in many individual steps in rewrite/ and optimize/), type inference (in type_inference.py) and transpilation (in compiler.py)
- If you rewrite a feature into existing features, make sure that the existing features are actually supported already. For example, lambda expressions are not supported yet.

For standard library tooling
- All code in std/ and ledger/ is meant to be interpreted by the opshin compiler. Add code there if it is supposed to extend the standard library of the language

Documentation is autogenerated from function comments.
-
7 changes: 5 additions & 2 deletions opshin/fun_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,11 @@ def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
)
elif isinstance(target, RecordType):
# default: all fields in union are records, so we can safely access CONSTR_ID
attr_source: ClassType = (
target if isinstance(instance.typ, AnyType) else instance.typ
)
node = plt.EqualsInteger(
plt.Apply(instance.typ.attribute("CONSTR_ID"), OVar("x")),
plt.Apply(attr_source.attribute("CONSTR_ID"), OVar("x")),
plt.Integer(target.record.constructor),
)

Expand Down Expand Up @@ -202,7 +205,7 @@ def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
)
else:
raise NotImplementedError(
f"Only isinstance for byte, int, Plutus Dataclass types are supported"
f"Only isinstance for byte, int, Plutus Dataclass, List and Dict types are supported"
)


Expand Down
25 changes: 12 additions & 13 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,12 @@ def does_literally_reference_self(arg):
stmt.test
)
# for the time after this assert, the variable has the specialized type
prevtyps.update(self.implement_typechecks(typchecks))
self.implement_typechecks(prevtyps)
wrapped = self.implement_typechecks(typchecks)
prevtyps.update(wrapped)
self.wrapped.extend(wrapped.keys())
if prevtyps:
self.wrapped = [x for x in self.wrapped if x not in prevtyps.keys()]
self.implement_typechecks(prevtyps)
return stmts

def visit_ClassDef(self, node: ClassDef) -> TypedClassDef:
Expand Down Expand Up @@ -1273,22 +1277,17 @@ def visit_Call(self, node: Call) -> TypedCall:

# might be isinstance
# Subscripts are not allowed in isinstance calls
if (
isinstance(tc.func, Name)
and tc.func.orig_id == "isinstance"
and isinstance(tc.args[1], Subscript)
):
is_isinstance_call = (
isinstance(tc.func, Name) and tc.func.orig_id == "isinstance"
)
if is_isinstance_call and isinstance(tc.args[1], Subscript):
raise TypeError(
"Subscripted generics cannot be used with class and instance checks"
)

# Need to handle the presence of PlutusData classes
if (
isinstance(tc.func, Name)
and tc.func.orig_id == "isinstance"
and not isinstance(
tc.args[1].typ, (ByteStringType, IntegerType, ListType, DictType)
)
if is_isinstance_call and not isinstance(
tc.args[1].typ, (ByteStringType, IntegerType, ListType, DictType)
):
if (
isinstance(tc.args[0].typ, InstanceType)
Expand Down
143 changes: 143 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,74 @@ def fib(n):

some_output = st.sampled_from([SomeOutputDatum(b"0"), SomeOutputDatumHash(b"1")])

ASSERT_ANYTHING_CONFIG = DEFAULT_TEST_CONFIG.update(allow_isinstance_anything=True)


def _assert_isinstance_anything_int_program():
return """
def validator(x: Anything) -> int:
y = x
assert isinstance(y, int), "Wrong type"
return y + 1
"""


def _assert_isinstance_anything_int_program_cast_if():
return """
def validator(x: Anything) -> int:
y = x
if isinstance(y, int):
return y + 1
return 0
"""


def _assert_isinstance_anything_bytes_program():
return """
def validator(x: Anything) -> int:
y = x
assert isinstance(y, bytes), "Wrong type"
return len(y)
"""


def _assert_isinstance_anything_list_program():
return """
from typing import List

def validator(x: Anything) -> int:
y = x
assert isinstance(y, List), "Wrong type"
return len(y)
"""


def _assert_isinstance_anything_dict_program():
return """
from typing import Dict

def validator(x: Anything) -> int:
y = x
assert isinstance(y, Dict), "Wrong type"
return len(y)
"""


def _assert_isinstance_anything_user_defined_program():
return """
from opshin.prelude import *

@dataclass()
class Foo(PlutusData):
CONSTR_ID = 0
value: int

def validator(x: Anything) -> int:
y = x
assert isinstance(y, Foo), "Wrong type"
return y.value + 2
"""


class MiscTest(unittest.TestCase):
def test_assert_sum_contract_succeed(self):
Expand Down Expand Up @@ -816,6 +884,81 @@ def validator(x: Anything) -> int:
ret = eval_uplc_value(source_code, 0)
self.assertEqual(ret, 0)

def test_assert_isinstance_anything_int(self):
source_code = _assert_isinstance_anything_int_program()
ret = eval_uplc_value(source_code, 1, config=ASSERT_ANYTHING_CONFIG)
self.assertEqual(ret, 2)

def test_assert_isinstance_anything_int_cast_if(self):
source_code = _assert_isinstance_anything_int_program_cast_if()
ret = eval_uplc_value(source_code, 1, config=ASSERT_ANYTHING_CONFIG)
self.assertEqual(ret, 2)

def test_assert_isinstance_anything_int_cast_if_wrong_config(self):
source_code = _assert_isinstance_anything_int_program_cast_if()
with self.assertRaises(CompilerError):
ret = eval_uplc_value(source_code, 1, config=DEFAULT_TEST_CONFIG)

def test_assert_isinstance_anything_int_wrong_config(self):
source_code = _assert_isinstance_anything_int_program()
with self.assertRaises(CompilerError):
ret = eval_uplc_value(source_code, 1, config=DEFAULT_TEST_CONFIG)

def test_assert_isinstance_anything_int_illegal(self):
source_code = _assert_isinstance_anything_int_program()
with self.assertRaises(RuntimeError):
eval_uplc_value(source_code, b"\x01", config=ASSERT_ANYTHING_CONFIG)

def test_assert_isinstance_anything_bytes(self):
source_code = _assert_isinstance_anything_bytes_program()
ret = eval_uplc_value(source_code, b"abc", config=ASSERT_ANYTHING_CONFIG)
self.assertEqual(ret, 3)

def test_assert_isinstance_anything_bytes_illegal(self):
source_code = _assert_isinstance_anything_bytes_program()
with self.assertRaises(RuntimeError):
eval_uplc_value(source_code, 1, config=ASSERT_ANYTHING_CONFIG)

def test_assert_isinstance_anything_list(self):
source_code = _assert_isinstance_anything_list_program()
ret = eval_uplc_value(source_code, [1, 2, 3], config=ASSERT_ANYTHING_CONFIG)
self.assertEqual(
ret,
3,
)

def test_assert_isinstance_anything_list_illegal(self):
source_code = _assert_isinstance_anything_list_program()
with self.assertRaises(RuntimeError):
eval_uplc_value(source_code, {1: 2}, config=ASSERT_ANYTHING_CONFIG)

def test_assert_isinstance_anything_dict(self):
source_code = _assert_isinstance_anything_dict_program()
ret = eval_uplc_value(source_code, {1: 2, 3: 4}, config=ASSERT_ANYTHING_CONFIG)
self.assertEqual(ret, 2)

def test_assert_isinstance_anything_dict_illegal(self):
source_code = _assert_isinstance_anything_dict_program()
with self.assertRaises(RuntimeError):
eval_uplc_value(source_code, [1, 2], config=ASSERT_ANYTHING_CONFIG)

def test_assert_isinstance_anything_user_defined_type(self):
source_code = _assert_isinstance_anything_user_defined_program()
datum = uplc.PlutusConstr(0, [uplc.PlutusInteger(5)])
ret = eval_uplc_value(source_code, datum, config=ASSERT_ANYTHING_CONFIG)
self.assertEqual(ret, 7)

def test_assert_isinstance_anything_user_defined_type_wrong_config(self):
source_code = _assert_isinstance_anything_user_defined_program()
datum = uplc.PlutusConstr(0, [uplc.PlutusInteger(5)])
with self.assertRaises(CompilerError):
eval_uplc_value(source_code, datum, config=DEFAULT_TEST_CONFIG)

def test_assert_isinstance_anything_user_defined_type_illegal(self):
source_code = _assert_isinstance_anything_user_defined_program()
with self.assertRaises(RuntimeError):
eval_uplc_value(source_code, 1, config=ASSERT_ANYTHING_CONFIG)

def test_typecast_int_anything(self):
# this should compile, it happens implicitly anyways when calling a function with Any parameters
source_code = """
Expand Down