Skip to content

Commit

Permalink
test: add tests, correct return type to float
Browse files Browse the repository at this point in the history
  • Loading branch information
valentynbez committed Mar 29, 2024
1 parent 4db76c1 commit dcae059
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
10 changes: 7 additions & 3 deletions mDeepFRI/predict.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import onnxruntime as rt
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
cpdef np.ndarray[int, ndim=2] seq2onehot(str seq):
cpdef np.ndarray[float, ndim=2] seq2onehot(str seq):
"""
Converts a protein sequence to 26-dim one-hot encoding.
Expand All @@ -25,18 +25,22 @@ cpdef np.ndarray[int, ndim=2] seq2onehot(str seq):
Returns:
np.ndarray: one-hot encoding of the protein sequence
"""
cdef bytes seq_bytes = seq.encode()

cdef bytes chars = b"-DGULNTKHYWCPVSOIEFXQABZRM"
cdef int vocab_size = len(chars)
cdef int seq_len = len(seq)
cdef int i, j
cdef int[:, ::1] onehot_view = np.zeros((seq_len, vocab_size), dtype=np.int32)

for i in range(seq_len):
j = chars.find(seq[i].encode())
j = chars.find(seq_bytes[i])
if j != -1:
onehot_view[i, j] = 1
else:
raise ValueError(f"Invalid character in sequence: {seq[i]}")

return np.asarray(onehot_view, dtype=np.int32)
return np.asarray(onehot_view, dtype=np.float32)


cdef class Predictor(object):
Expand Down
Empty file.
34 changes: 34 additions & 0 deletions mDeepFRI/tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest

import numpy as np

from mDeepFRI.predict import seq2onehot


class TestSeq2OneHot(unittest.TestCase):
def test_empty_sequence(self):
seq = ""
result = seq2onehot(seq)
self.assertEqual(result.shape, (0, 26))
self.assertEqual(result.dtype, np.float32)
self.assertTrue(np.all(result == np.array([]).reshape(0, 26)))

def test_single_character_sequence(self):
seq = "D"
result = seq2onehot(seq)
self.assertEqual(result.shape, (1, 26))
self.assertTrue(np.all(result == np.array([[0, 1] + [0] * 24])))

def test_sequence(self):
seq = "-DGU"
result = seq2onehot(seq)
self.assertEqual(result.shape, (4, 26))
self.assertTrue(
np.all(result == np.array([[1, 0, 0, 0] + [0] * 22, [0, 1, 0, 0] +
[0] * 22, [0, 0, 1, 0] +
[0] * 22, [0, 0, 0, 1] + [0] * 22])))

def test_invalid_character(self):
seq = "J"
with self.assertRaises(ValueError):
seq2onehot(seq)

0 comments on commit dcae059

Please sign in to comment.