File tree Expand file tree Collapse file tree 3 files changed +33
-10
lines changed
Expand file tree Collapse file tree 3 files changed +33
-10
lines changed Original file line number Diff line number Diff line change @@ -9,7 +9,7 @@ Development Lead
99
1010Contributors
1111------------
12- - None yet
12+ * ` @tdrobbins < https://github.com/tdrobbins >`_
1313
1414Citations
1515---------
Original file line number Diff line number Diff line change @@ -43,28 +43,30 @@ def crop(image, label):
4343
4444def to_rgb (img : np .array ):
4545 """
46- Converts the given array into a RGB image. If the number of channels is less
47- than 3, the array is tiled such that it has 3 channels. If the number of
48- channels is greater than 3, only the first 3 channels are used. Finally, the
49- values are rescaled to [0,255)
46+ Converts the given array into a RGB image and normalizes the values to [0, 1).
47+ If the number of channels is less than 3, the array is tiled such that it has 3 channels.
48+ If the number of channels is greater than 3, only the first 3 channels are used
5049
51- :param img: the array to convert [nx, ny, channels]
50+ :param img: the array to convert [bs, nx, ny, channels]
5251
53- :returns img: the rgb image [nx, ny, 3]
52+ :returns img: the rgb image [bs, nx, ny, 3]
5453 """
5554 img = img .astype (np .float32 )
5655 img = np .atleast_3d (img )
5756
5857 channels = img .shape [- 1 ]
59- if channels < 3 :
58+ if channels == 1 :
6059 img = np .tile (img , 3 )
60+
61+ elif channels == 2 :
62+ img = np .concatenate ((img , img [..., :1 ]), axis = - 1 )
63+
6164 elif channels > 3 :
62- img = img [:,:, :3 ]
65+ img = img [..., :3 ]
6366
6467 img [np .isnan (img )] = 0
6568 img -= np .amin (img )
6669 if np .amax (img ) != 0 :
6770 img /= np .amax (img )
6871
69- # img *= 255
7072 return img
Original file line number Diff line number Diff line change 1+ import numpy as np
2+ import pytest
3+
4+ from unet import utils
5+
6+
7+ @pytest .mark .parametrize ("channels" , [
8+ 1 ,2 ,3 ,4
9+ ])
10+ def test_to_rgb (channels ):
11+ tensor = np .random .normal (size = (5 , 12 , 12 , channels ))
12+
13+ tensor [1 , 5 , 5 , 0 ] = np .nan
14+
15+ rgb_img = utils .to_rgb (tensor )
16+
17+ assert rgb_img .shape [:2 ] == tensor .shape [:2 ]
18+ assert rgb_img .shape [3 ] == 3
19+
20+ assert rgb_img .min () == 0
21+ assert rgb_img .max () == 1
You can’t perform that action at this time.
0 commit comments