From dcae0598a655e69f378ad3c7e04b61b41fcb712c Mon Sep 17 00:00:00 2001 From: valentynbez Date: Fri, 29 Mar 2024 18:34:31 +0100 Subject: [PATCH] test: add tests, correct return type to float --- mDeepFRI/predict.pyx | 10 +++++--- mDeepFRI/tests/test_build_database.py | 0 mDeepFRI/tests/test_predict.py | 34 +++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 3 deletions(-) delete mode 100644 mDeepFRI/tests/test_build_database.py create mode 100644 mDeepFRI/tests/test_predict.py diff --git a/mDeepFRI/predict.pyx b/mDeepFRI/predict.pyx index 740e621..5473af7 100644 --- a/mDeepFRI/predict.pyx +++ b/mDeepFRI/predict.pyx @@ -15,7 +15,7 @@ import onnxruntime as rt @cython.wraparound(False) @cython.cdivision(True) @cython.initializedcheck(False) -cpdef np.ndarray[int, ndim=2] seq2onehot(str seq): +cpdef np.ndarray[float, ndim=2] seq2onehot(str seq): """ Converts a protein sequence to 26-dim one-hot encoding. @@ -25,6 +25,8 @@ cpdef np.ndarray[int, ndim=2] seq2onehot(str seq): Returns: np.ndarray: one-hot encoding of the protein sequence """ + cdef bytes seq_bytes = seq.encode() + cdef bytes chars = b"-DGULNTKHYWCPVSOIEFXQABZRM" cdef int vocab_size = len(chars) cdef int seq_len = len(seq) @@ -32,11 +34,13 @@ cpdef np.ndarray[int, ndim=2] seq2onehot(str seq): cdef int[:, ::1] onehot_view = np.zeros((seq_len, vocab_size), dtype=np.int32) for i in range(seq_len): - j = chars.find(seq[i].encode()) + j = chars.find(seq_bytes[i]) if j != -1: onehot_view[i, j] = 1 + else: + raise ValueError(f"Invalid character in sequence: {seq[i]}") - return np.asarray(onehot_view, dtype=np.int32) + return np.asarray(onehot_view, dtype=np.float32) cdef class Predictor(object): diff --git a/mDeepFRI/tests/test_build_database.py b/mDeepFRI/tests/test_build_database.py deleted file mode 100644 index e69de29..0000000 diff --git a/mDeepFRI/tests/test_predict.py b/mDeepFRI/tests/test_predict.py new file mode 100644 index 0000000..54b6c8a --- /dev/null +++ b/mDeepFRI/tests/test_predict.py @@ 
-0,0 +1,45 @@
+import unittest
+
+import numpy as np
+
+from mDeepFRI.predict import seq2onehot
+
+# Alphabet used by seq2onehot: column j of the encoding marks VOCAB[j].
+VOCAB = "-DGULNTKHYWCPVSOIEFXQABZRM"
+
+
+class TestSeq2OneHot(unittest.TestCase):
+    """Tests for seq2onehot: shape, float32 dtype, values and errors."""
+
+    def assert_encoding(self, seq, expected):
+        """Encode seq and compare shape, dtype and values to expected."""
+        result = seq2onehot(seq)
+        self.assertEqual(result.shape, (len(seq), len(VOCAB)))
+        self.assertEqual(result.dtype, np.float32)
+        # Unlike assertTrue(np.all(a == b)), this reports the mismatching
+        # elements when the comparison fails.
+        np.testing.assert_array_equal(result, expected)
+
+    def test_empty_sequence(self):
+        # Comparing two empty arrays element-wise is vacuously true, so the
+        # shape and dtype checks in the helper are what pin this case.
+        self.assert_encoding("", np.zeros((0, len(VOCAB))))
+
+    def test_single_character_sequence(self):
+        # "D" is the second symbol of the alphabet -> column 1.
+        self.assert_encoding("D", np.eye(len(VOCAB))[[1]])
+
+    def test_sequence(self):
+        # "-DGU" are the first four symbols -> first four identity rows.
+        self.assert_encoding("-DGU", np.eye(len(VOCAB))[:4])
+
+    def test_full_vocabulary(self):
+        # Encoding the alphabet in order must give the identity matrix.
+        self.assert_encoding(VOCAB, np.eye(len(VOCAB)))
+
+    def test_invalid_character(self):
+        # "J" is not in the alphabet, alone or inside a valid sequence.
+        for seq in ("J", "DJG"):
+            with self.subTest(seq=seq):
+                with self.assertRaises(ValueError):
+                    seq2onehot(seq)