-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathds2_cont_filler_layer.py
32 lines (24 loc) · 976 Bytes
/
ds2_cont_filler_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# --------------------------------------------------------
# Deep Speech 2 Caffe Implementation
# Written by Tian, Feng <[email protected]>
# --------------------------------------------------------
"""The cont filler layer used while training a DS2 network.

DS2ContFillerLayer implements a Caffe Python layer.
"""
import caffe
import numpy as np
class DS2ContFillerLayer(caffe.Layer):
    """DeepSpeech2 cont filler layer used for training.

    Produces a single top blob shaped like the first two axes of bottom[0].
    It is filled with 1 everywhere except index 0 along axis 1, which is 0.
    The contents of bottom[0] are never read; only its shape is used.
    """

    def setup(self, bottom, top):
        """Set up the DS2ContFillerLayer.

        The top shape is derived from bottom[0] in ``reshape``. It is also
        applied here so top[0] is sized as soon as setup finishes.
        """
        self.reshape(bottom, top)

    def reshape(self, bottom, top):
        """Size top[0] to match the first two axes of bottom[0].

        Caffe invokes this before every forward pass. Doing the reshape here,
        rather than only once in setup, keeps the output in sync when the
        leading dimensions of bottom[0] change between iterations.
        """
        top[0].reshape(bottom[0].shape[0], bottom[0].shape[1])

    def forward(self, bottom, top):
        """Fill top[0] with 1s, then zero out index 0 along axis 1."""
        cont = top[0].data
        cont[...] = 1
        # NOTE(review): this marks a sequence start only if axis 1 of
        # bottom[0] is the time axis. Caffe's recurrent layers expect cont
        # shaped (T, N), which would instead call for cont[0, :] = 0.
        # Confirm the layout of bottom[0] against the network definition.
        cont[:, 0] = 0

    def backward(self, top, propagate_down, bottom):
        """No-op: the output is a constant mask, so no gradient flows back."""
        pass