-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsteerable_pyramid.py
119 lines (84 loc) · 3.2 KB
/
steerable_pyramid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from util import max_pyramid_height
from util import correlate_and_downsample
from util import upsample_and_correlate
from util import imshow
from sp3filters import SPFilterSet3
from numpy import array, empty, ones
from scipy.signal import resample
import scipy.ndimage as ndi
class SteerablePyramid:
    """Build steerable-pyramid decompositions of an image.

    The filters come from ``filter_set`` (``SPFilterSet3()`` by default). That
    object must provide ``hi0_filt``, ``lo0_filt`` and ``lo_filt``. It must
    also provide either ``band_fb`` (a stacked band filter bank) or
    ``band_filts`` (individual band filters), depending on ``use_band_fb``.
    """

    def __init__(self, filter_set=None):
        """Create a pyramid builder.

        :param filter_set: filter set to use; ``SPFilterSet3()`` when None.
        """
        if filter_set is None:
            filter_set = SPFilterSet3()
        self.filter_set = filter_set

    def process_image(self, im, pyramid_height=None, upsample=True):
        """Decompose ``im`` into a steerable pyramid.

        Side effect: the high-pass residual is stored in
        ``self.residual_hipass``. ``upsample_pyramid`` later uses it as the
        target shape.

        :param im: input image array.
        :param pyramid_height: number of band-pass levels. When None, it is
            derived with ``max_pyramid_height`` from the image shape and the
            size of ``lo_filt``, and then printed.
        :param upsample: when True, resample every band to the shape of the
            high-pass residual (see ``upsample_pyramid``).
        :returns: list of levels, finest first. Each level is a list of band
            arrays. The last level is ``[low-pass residual]``.
        """
        if pyramid_height is None:
            fs = self.filter_set.lo_filt.shape[0]
            pyramid_height = max_pyramid_height(im.shape, fs)
            print('pyramid height: %s' % (pyramid_height,))
        # NOTE(review): no step argument is passed here, unlike the lo_filt
        # call in build_sp_levels. Presumably level 0 is not subsampled;
        # confirm against util.correlate_and_downsample.
        hi0 = correlate_and_downsample(im, self.filter_set.hi0_filt)
        lo0 = correlate_and_downsample(im, self.filter_set.lo0_filt)
        self.residual_hipass = hi0
        pyramid = self.build_sp_levels(lo0, pyramid_height)
        if upsample:
            return self.upsample_pyramid(pyramid)
        return pyramid

    def build_sp_levels(self, im, height, use_band_fb=True):
        """Recursively build the levels of a steerable pyramid.

        :param im: image to decompose at this level.
        :param height: number of band-pass levels still to build. When it is
            <= 0, ``im`` itself is returned as the final low-pass level.
        :param use_band_fb: if True, filter with the stacked
            ``filter_set.band_fb`` bank in a single call. Otherwise, filter
            with each entry of ``filter_set.band_filts`` separately. The
            choice applies to every level of the recursion.
        :returns: list of levels, this level first. Each level is a list of
            band arrays, and the last entry is ``[lowpass]``.
        """
        if height <= 0:
            return [[im]]
        if use_band_fb:
            # One correlation with the whole bank. The individual band
            # responses are stacked along the third axis of the result.
            stacked = correlate_and_downsample(im, self.filter_set.band_fb)
            bands = [stacked[:, :, i] for i in range(stacked.shape[2])]
        else:
            bands = [correlate_and_downsample(im, filt)
                     for filt in self.filter_set.band_filts]
        # Low-pass output that feeds the next (coarser) level. The trailing 2
        # is presumably the subsampling step; confirm in util.
        lo = correlate_and_downsample(im, self.filter_set.lo_filt, 2)
        # Forward use_band_fb so every level uses the same filtering mode.
        return [bands] + self.build_sp_levels(lo, height - 1, use_band_fb)

    def upsample_pyramid(self, pyramid):
        """Resample every band of ``pyramid`` to the high-pass residual's shape.

        Requires ``self.residual_hipass`` (set by ``process_image``).
        - A band whose shape already matches is returned unchanged, possibly
          as a reshaped view.
        - Any other band is resampled with linear interpolation
          (``scipy.ndimage.zoom`` with ``order=1``).
        The arrays in ``pyramid`` are not modified.

        :param pyramid: list of levels, each a list of band arrays.
        :returns: a new list of levels with the same nesting as ``pyramid``.
        """
        target_shape = self.residual_hipass.shape
        # Float array so the zoom factors below use true division. With the
        # integer division of Python 2, a ratio such as 513/257 truncates
        # to 1, and that band is silently left at the wrong size.
        target = array(target_shape, dtype=float)
        result = []
        for level in pyramid:
            new_level = []
            for band in level:
                band_shape = band.shape
                if len(target_shape) > len(band_shape):
                    # 2-D band vs. 3-D target: add a trailing singleton axis.
                    band_shape = (band_shape[0], band_shape[1], 1)
                # reshape() returns a new view; assigning band.shape would
                # mutate the caller's array.
                band = band.reshape(band_shape)
                zf = target / array(band_shape)
                if (zf != 1).any():
                    # Passing an explicit output array forces the result to
                    # have exactly target_shape. zoom overwrites every entry.
                    upsamped = empty(target_shape)
                    ndi.zoom(band, zf, output=upsamped, order=1)
                else:
                    upsamped = band
                new_level.append(upsamped)
            result.append(new_level)
        return result
if __name__ == "__main__":
    # Demo: decompose a test image and display the first band of (at most)
    # the first four pyramid levels.
    import time

    import matplotlib.pylab as plt
    import numpy as np

    try:
        # scipy.misc.lena was removed in SciPy 0.17.
        from scipy.misc import lena as load_test_image
    except ImportError:
        try:
            # Replacement test image; scipy.misc.ascent was later removed too.
            from scipy.misc import ascent as load_test_image
        except ImportError:
            # Newer SciPy: scipy.datasets (needs the optional 'pooch' package).
            from scipy.datasets import ascent as load_test_image

    im = load_test_image().astype(np.float32)
    imshow(im, 'lena orig')
    tic = time.time()
    sp = SteerablePyramid()
    pyramid = sp.process_image(im, upsample=False)
    print("run time: %f" % (time.time() - tic))
    # Slicing avoids an IndexError when the pyramid has fewer than 4 levels.
    for level in pyramid[:4]:
        band = level[0].copy()
        # Drop a trailing singleton axis, if any, so plt.imshow gets a 2-D array.
        band.shape = (band.shape[0], band.shape[1])
        plt.imshow(band, cmap='gray')
        plt.show()