Skip to content

Commit

Permalink
reimplement setitem_long_long
Browse files Browse the repository at this point in the history
  • Loading branch information
jsn1993 committed Jun 10, 2020
1 parent dcee323 commit f8a02e2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 45 deletions.
95 changes: 50 additions & 45 deletions pypy/module/mamba/helper_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,72 +381,77 @@ def setitem_long_long_helper( value, other, start, stop ):
return setitem_long_int_helper( value, other.digit(0), start, stop )

if other.sign < 0:
slice_nbits = stop - start
other = other.and_( get_long_mask(slice_nbits) )
other = other.and_( get_long_mask(stop-start) )
if other.numdigits() == 1:
return setitem_long_int_helper( value, other.digit(0), start, stop )

# After the two above checks, we have made sure other has more than one
# digit and wordstart must < wordstop
vsize = value.numdigits()
other = other.lshift( start ) # lshift first to align two rbigints
osize = other.numdigits()

# Now other must be long, wordstart must < wordstop
# After the two above checks, we have made sure other has more than one digit
# assert osize >= 2
# assert wordstart < wordstop

# Also, the caller must have already checked if bitwidth exceeds the slice
# assert wordstart <= osize - 1 <= wordstop

wordstart = start / SHIFT

# 1. vsize <= wordstart < wordstop, concatenate
if vsize <= wordstart:
return rbigint(value._digits[:vsize] + other._digits[vsize:], 1, osize )

wordstop = stop / SHIFT
bitstart = start - wordstart*SHIFT

# 2. wordstart < wordstop < vsize
if wordstop < vsize:
ret = rbigint( value._digits[:vsize], 1, vsize )
wordstop = stop / SHIFT

# do start
bitstart = start - wordstart*SHIFT
tmpstart = other.digit( wordstart ) | (ret.digit(wordstart) & get_int_mask(bitstart))
# if bitstart:
# tmpstart |= ret.digit(wordstart) & get_int_mask(bitstart) # lo
ret.setdigit( wordstart, tmpstart )

i = wordstart+1

# wordstart < osize <= wordstop < vsize
if osize <= wordstop:
while i < osize:
ret.setdigit( i, other.digit(i) )
i += 1
while i < wordstop:
ret._digits[i] = NULLDIGIT
i += 1
# wordstart < wordstop < osize < vsize
else:
while i < wordstop:
ret.setdigit( i, other.digit(i) )
i += 1
# 2. wordstart < vsize <= wordstop, merge wordstart and concatenate
if vsize <= wordstop:
assert wordstart >= 0
ret = rbigint( value._digits[:wordstart] + \
other._digits[wordstart:osize], 1, osize )

# do stop
bitstop = stop - wordstop*SHIFT
if bitstop:
masked_val = ret.digit(wordstop) & ~get_int_mask(bitstop) #hi
ret.setdigit( wordstop, other.digit(wordstop) | masked_val ) # lo|hi
# union start, there is no value in lower bits of other.digit(wordstart)
if bitstart:
value_lo = value.digit(wordstart) & get_int_mask(bitstart) # lo
ret.setdigit( wordstart, value_lo | ret.digit(wordstart) ) # lo | hi

return ret

assert wordstart >= 0
# wordstart < vsize <= wordstop
ret = rbigint( value._digits[:wordstart] + \
other._digits[wordstart:osize], 1, osize )
# 3. wordstart < wordstop < vsize, handle both sides
ret = rbigint( value._digits[:], 1, vsize )

# do start
bitstart = start - wordstart*SHIFT
if bitstart:
masked_val = value.digit(wordstart) & get_int_mask(bitstart) # lo
ret.setdigit( wordstart, masked_val | ret.digit(wordstart) ) # lo | hi
# union start, there is no value in lower bits of other.digit(wordstart)
value_lo = ret.digit(wordstart) & get_int_mask(bitstart) # lo
ret.setdigit( wordstart, value_lo | other.digit( wordstart ) ) # lo | hi

# put other into
i = wordstart + 1

inv_maskstop = ~get_int_mask( stop - wordstop*SHIFT )
# wordstop == osize - 1 means other's last word is wordstop
if wordstop == osize - 1:
while i < wordstop:
ret.setdigit(i, other.digit(i) )
i += 1
# union stop
value_hi = ret.digit(wordstop) & inv_maskstop # hi
ret.setdigit( wordstop, other.digit(wordstop) | value_hi ) # lo|hi

# wordstop > osize - 1, other is shorter
else:
while i < osize:
ret.setdigit(i, other.digit(i) )
i += 1
while i < wordstop:
ret._digits[i] = NULLDIGIT
i += 1

# clear stop
ret.setdigit( wordstop, ret.digit(wordstop) & inv_maskstop )

ret._normalize()
return ret

@jit.elidable
Expand Down
5 changes: 5 additions & 0 deletions pypy/module/mamba/test/test_bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def make_long(x): return x + 2 ** 100 - 2 ** 100
with raises(ValueError):
b[0:80] = mamba.Bits(100, 0)

a = mamba.Bits(146, 0x00000c26283b3f1402373002002b700000293)
a[0:128] = mamba.Bits(128, 0x00001317134282930000129740710133)
print(a, a == mamba.Bits(146, 0x0000000001317134282930000129740710133), repr(a) == repr(mamba.Bits(146, 0x0000000001317134282930000129740710133)) )
assert a == mamba.Bits(146, 0x0000000001317134282930000129740710133)

def test_setitem_crash(self):
from mamba import Bits
input = Bits(465, 0x00095700000000000000003f950000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f5d )
Expand Down

0 comments on commit f8a02e2

Please sign in to comment.