[BUG] Can't infer ImmutableOrigin from MutableOrigin in function signatures #3845

Open
martinvuyk opened this issue Dec 7, 2024 · 1 comment
Labels: bug (Something isn't working), mojo-repo (Tag all issues with this label)

martinvuyk (Contributor) commented Dec 7, 2024

Bug description

This is the shortest repro I could manage. Here the compiler fails to infer the type correctly even when every parameter is specified explicitly:

_merge[
    span_origin = ImmutableOrigin.cast_from[span_life].result,
    cmp_fn=cmp_fn,
](span1.get_immut(), span2.get_immut(), temp_buff)

Output:

invalid call to '_merge': argument #0 cannot be converted from 'Span2[type, Origin2(span_life._mlir_origin)]' to 'Span2[type, Origin2((muttoimm span_life._mlir_origin))]'

On the nightly branch, _merge[cmp_fn](span1, span2, temp_buff) also fails to compile. However, if I explicitly cast the first argument, as in _merge[cmp_fn](span1.get_immut(), span2, temp_buff), it compiles (see PR #3823).

In both cases I think the compiler should infer the full type. Otherwise the implicit MutableOrigin -> ImmutableOrigin conversion loses its purpose.

CC: @lattner, @ConnorGray

Steps to reproduce

from memory import UnsafePointer


alias ImmutableOrigin = Origin2[False]
alias MutableOrigin = Origin2[True]


@value
@register_passable("trivial")
struct Span2[
    is_mutable: Bool, //,
    T: CollectionElement,
    origin: Origin2[is_mutable],
](CollectionElementNew):

    alias mut = Span2[T, ImmutableOrigin.cast_from[origin].result]
    """The mutable version of the Span."""
    alias immut = Span2[T, MutableOrigin.cast_from[origin].result]
    """The immutable version of the Span."""

    # Fields
    var _data: UnsafePointer[T]
    var _len: Int

    @doc_private
    @implicit
    @always_inline("nodebug")
    fn __init__(out self: Self.immut, ref other: Self.mut):
        """Implicitly cast the mutable origin of self to an immutable one.

        Args:
            other: The Span to cast.
        """
        self = rebind[Self.immut](other)

    @always_inline
    fn get_immut(self) -> Self.immut:
        """Return an immutable version of this Span.

        Returns:
            An immutable version of the same Span.
        """
        return rebind[Self.immut](self)

    @always_inline
    fn __init__(out self, *, ptr: UnsafePointer[T], length: Int):
        self._data = ptr
        self._len = length

    @always_inline
    fn __init__(out self, *, other: Self):
        self._data = other._data
        self._len = other._len

    fn unsafe_ptr(self) -> UnsafePointer[T]:
        return self._data

    fn __len__(self) -> Int:
        return self._len

    @always_inline
    fn __getitem__(self, idx: Int) -> ref [origin] T:
        # TODO: Simplify this with a UInt type.
        debug_assert(
            -self._len <= int(idx) < self._len, "index must be within bounds"
        )

        var offset = idx
        if offset < 0:
            offset += len(self)
        return self._data[offset]

    @always_inline
    fn __getitem__(self, slc: Slice) -> Self:
        var start: Int
        var end: Int
        var step: Int
        start, end, step = slc.indices(len(self))

        debug_assert(
            step == 1, "Slice must be within bounds and step must be 1"
        )

        var res = Self(
            ptr=(self._data + start), length=len(range(start, end, step))
        )

        return res


@value
@register_passable("trivial")
struct Origin2[is_mutable: Bool]:
    """This represents an origin reference for a memory value.

    Parameters:
        is_mutable: Whether the origin is mutable.
    """

    alias _mlir_type = __mlir_type[
        `!lit.origin<`,
        is_mutable.value,
        `>`,
    ]
    alias cast_from = _lit_mut_cast[result_mutable=is_mutable]

    var _mlir_origin: Self._mlir_type

    @doc_private
    @implicit
    @always_inline("nodebug")
    fn __init__(out self, mlir_origin: Self._mlir_type):
        self._mlir_origin = mlir_origin

    @doc_private
    @implicit
    @always_inline("nodebug")
    fn __init__(out self: ImmutableOrigin, origin: MutableOrigin):
        """Initialize an ImmutableOrigin from a MutableOrigin.

        Args:
            origin: The origin value."""
        self = rebind[ImmutableOrigin](origin)


struct _lit_mut_cast[
    is_mutable: Bool, //,
    result_mutable: Bool,
    operand: Origin2[is_mutable],
]:
    alias result = __mlir_attr[
        `#lit.origin.mutcast<`,
        operand._mlir_origin,
        `> : !lit.origin<`,
        result_mutable.value,
        `>`,
    ]


@value
struct _SortWrapper[type: CollectionElement](CollectionElement):
    var data: type

    @implicit
    fn __init__(out self, data: type):
        self.data = data

    fn __init__(out self, *, other: Self):
        self.data = other.data


@always_inline
fn _insertion_sort[
    type: CollectionElement,
    origin: MutableOrigin, //,
    cmp_fn: fn (_SortWrapper[type], _SortWrapper[type]) capturing [_] -> Bool,
](span: Span2[type, origin]):
    var array = span.unsafe_ptr()
    var size = len(span)

    for i in range(1, size):
        var value = array[i]
        var j = i

        # Find the placement of the value in the array, shifting elements
        # right as we search for its position. Throughout, we assume
        # array[0:i] has already been sorted.
        while j > 0 and cmp_fn(value, array[j - 1]):
            array[j] = array[j - 1]
            j -= 1

        array[j] = value


fn _merge[
    type: CollectionElement,
    span_origin: ImmutableOrigin,
    result_origin: MutableOrigin, //,
    cmp_fn: fn (_SortWrapper[type], _SortWrapper[type]) capturing [_] -> Bool,
](
    span1: Span2[type, span_origin],
    span2: Span2[type, span_origin],
    result: Span2[type, result_origin],
):
    var span1_size = len(span1)
    var span2_size = len(span2)
    var res_ptr = result.unsafe_ptr()

    debug_assert(
        span1_size + span2_size <= len(result),
        "The merge result does not fit in the span provided",
    )
    var i = 0
    var j = 0
    var k = 0
    while i < span1_size:
        if j == span2_size:
            while i < span1_size:
                (res_ptr + k).init_pointee_copy(span1[i])
                k += 1
                i += 1
            return
        if cmp_fn(span2[j], span1[i]):
            (res_ptr + k).init_pointee_copy(span2[j])
            j += 1
        else:
            (res_ptr + k).init_pointee_copy(span1[i])
            i += 1
        k += 1

    while j < span2_size:
        (res_ptr + k).init_pointee_copy(span2[j])
        k += 1
        j += 1


alias insertion_sort_threshold = 32


fn _stable_sort_impl[
    type: CollectionElement,
    span_life: MutableOrigin,
    tmp_life: MutableOrigin, //,
    cmp_fn: fn (_SortWrapper[type], _SortWrapper[type]) capturing [_] -> Bool,
](span: Span2[type, span_life], temp_buff: Span2[type, tmp_life]):
    var size = len(span)
    if size <= 1:
        return
    var i = 0
    while i < size:
        _insertion_sort[cmp_fn](
            span[i : min(i + insertion_sort_threshold, size)]
        )
        i += insertion_sort_threshold
    var merge_size = insertion_sort_threshold
    while merge_size < size:
        var j = 0
        while j + merge_size < size:
            var span1 = span[j : j + merge_size]
            var span2 = span[j + merge_size : min(size, j + 2 * merge_size)]
            _merge[cmp_fn](span1, span2, temp_buff)
            for i in range(merge_size + len(span2)):
                span[j + i] = temp_buff[i]
            j += 2 * merge_size
        merge_size *= 2


fn _stable_sort[
    type: CollectionElement,
    origin: MutableOrigin, //,
    cmp_fn: fn (_SortWrapper[type], _SortWrapper[type]) capturing [_] -> Bool,
](span: Span2[type, origin]):
    var temp_buff = UnsafePointer[type].alloc(len(span))
    var temp_buff_span = Span2[type, __origin_of(temp_buff)](
        ptr=temp_buff, length=len(span)
    )
    _stable_sort_impl[cmp_fn](span, temp_buff_span)
    temp_buff.free()


fn _sort[
    type: CollectionElement,
    origin: MutableOrigin, //,
    cmp_fn: fn (_SortWrapper[type], _SortWrapper[type]) capturing [_] -> Bool,
    *,
](span: Span2[type, origin]):
    _stable_sort[cmp_fn](span)


fn sort[
    type: DType,
    origin: MutableOrigin, //,
    cmp_fn: fn (Scalar[type], Scalar[type]) capturing [_] -> Bool,
    *,
](span: Span2[Scalar[type], origin]):
    @parameter
    fn _cmp_fn(
        lhs: _SortWrapper[Scalar[type]], rhs: _SortWrapper[Scalar[type]]
    ) -> Bool:
        return cmp_fn(lhs.data, rhs.data)

    _sort[_cmp_fn](span)


fn sort[
    origin: MutableOrigin, //,
    *,
](span: Span2[Int64, origin]):
    @parameter
    fn _cmp_fn(lhs: Int64, rhs: Int64) -> Bool:
        return lhs < rhs

    sort[_cmp_fn](span)


fn main():
    items = List[Int64](3, 4, 5, 1, 2)
    span = Span2[Int64, origin = __origin_of(items)](
        ptr=items.unsafe_ptr(), length=len(items)
    )
    sort(span)
    print(items.__str__())

System information

- Mojo version (output of `mojo -v`): mojo 24.6.0.dev2024120705
martinvuyk added the bug and mojo-repo labels on Dec 7, 2024.

martinvuyk (Contributor, Author) commented:

I think fixing this would also remove the need for the @__unsafe_disable_nested_origin_exclusivity decorator in many functions. Read-only functions could simply be annotated with ImmutableOrigin, and the implicit constructors would then handle any value with a MutableOrigin that is passed in.
