Skip to content

Commit 631d4b7

Browse files
committed
fixing & testing rgb convertion
1 parent 3a603db commit 631d4b7

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

AUTHORS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Development Lead
99

1010
Contributors
1111
------------
12-
- None yet
12+
* `@tdrobbins <https://github.com/tdrobbins>`_
1313

1414
Citations
1515
---------

src/unet/utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,28 +43,30 @@ def crop(image, label):
4343

4444
def to_rgb(img: np.array):
4545
"""
46-
Converts the given array into a RGB image. If the number of channels is less
47-
than 3, the array is tiled such that it has 3 channels. If the number of
48-
channels is greater than 3, only the first 3 channels are used. Finally, the
49-
values are rescaled to [0,255)
46+
Converts the given array into a RGB image and normalizes the values to [0, 1).
47+
If the number of channels is less than 3, the array is tiled such that it has 3 channels.
48+
If the number of channels is greater than 3, only the first 3 channels are used
5049
51-
:param img: the array to convert [nx, ny, channels]
50+
:param img: the array to convert [bs, nx, ny, channels]
5251
53-
:returns img: the rgb image [nx, ny, 3]
52+
:returns img: the rgb image [bs, nx, ny, 3]
5453
"""
5554
img = img.astype(np.float32)
5655
img = np.atleast_3d(img)
5756

5857
channels = img.shape[-1]
59-
if channels < 3:
58+
if channels == 1:
6059
img = np.tile(img, 3)
60+
61+
elif channels == 2:
62+
img = np.concatenate((img, img[..., :1]), axis=-1)
63+
6164
elif channels > 3:
62-
img = img[:,:,:3]
65+
img = img[..., :3]
6366

6467
img[np.isnan(img)] = 0
6568
img -= np.amin(img)
6669
if np.amax(img) != 0:
6770
img /= np.amax(img)
6871

69-
# img *= 255
7072
return img

tests/test_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
import pytest
3+
4+
from unet import utils
5+
6+
7+
@pytest.mark.parametrize("channels", [
8+
1,2,3,4
9+
])
10+
def test_to_rgb(channels):
11+
tensor = np.random.normal(size=(5, 12, 12, channels))
12+
13+
tensor[1, 5, 5, 0] = np.nan
14+
15+
rgb_img = utils.to_rgb(tensor)
16+
17+
assert rgb_img.shape[:2] == tensor.shape[:2]
18+
assert rgb_img.shape[3] == 3
19+
20+
assert rgb_img.min() == 0
21+
assert rgb_img.max() == 1

0 commit comments

Comments
 (0)