Open
Description
When running the notebook "Parametric t-SNE (Keras).ipynb", I followed its steps and built the neural network with the customized loss function. However, Keras raised a ValueError when I tried to train the network.
- Customized loss function:
# P is the joint probabilities for this batch (Keras loss functions call this y_true)
# activations is the low-dimensional output (Keras loss functions call this y_pred)
def tsne(P, activations):
    d = 2 # TODO: should set this automatically, but the above is very slow for some reason
    n = batch_size # TODO: should set this automatically
    v = d - 1.
    eps = K.variable(10e-15) # needs to be at least 10e-8 to get anything after Q /= K.sum(Q)
    sum_act = K.sum(K.square(activations), axis=1)
    Q = K.reshape(sum_act, [-1, 1]) + -2 * K.dot(activations, K.transpose(activations))
    Q = (sum_act + Q) / v
    Q = K.pow(1 + Q, -(v + 1) / 2)
    Q *= K.variable(1 - np.eye(n))
    Q /= K.sum(Q)
    Q = K.maximum(Q, eps)
    C = K.log((P + eps) / (Q + eps))
    C = K.sum(P * C)
    return C
- Construct model:
model = Sequential()
model.add(Dense(500, activation='relu', input_shape=(X_train.shape[1],)))
model.add(Dense(500, activation='relu'))
model.add(Dense(2000, activation='relu'))
model.add(Dense(2))
sgd = SGD(lr=0.1)
%time model.compile(loss=tsne, optimizer=sgd)
- Preparing Y_train:
Y_train = P.reshape(X_train.shape[0], -1)
- I received a ValueError from Keras when trying to execute the following line:
%time model.fit(X_train, Y_train, batch_size=batch_size, shuffle=False, nb_epoch=100)
ValueError: Error when checking target: expected dense_4 to have shape (None, 2) but got array with shape (60000, 5000)
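The (60000, 5000) shape in the error message is the shape of Y_train after the reshape. A minimal NumPy sketch of that shape arithmetic follows, using small stand-in sizes and an assumed layout for P (one batch_size x batch_size matrix per batch; the random values are placeholders, not the notebook's real P). It only illustrates that, with shuffle=False, each consecutive slice of batch_size rows is one square matrix.

```python
import numpy as np

# Small stand-ins for the notebook's sizes (60000 samples, batch_size = 5000).
n_samples, batch_size, n_features = 12, 4, 3
n_batches = n_samples // batch_size

# Assumption: P holds one (batch_size x batch_size) matrix per batch.
# Random values are placeholders for the real joint probabilities.
rng = np.random.default_rng(0)
P = rng.random((n_batches, batch_size, batch_size))
X_train = rng.random((n_samples, n_features))

# Same reshape as in the issue: one row of length batch_size per sample.
Y_train = P.reshape(X_train.shape[0], -1)
print(Y_train.shape)

# Without shuffling, the first batch of targets is exactly the first matrix in P.
first_batch = Y_train[:batch_size]
print(first_batch.shape)
print(np.array_equal(first_batch, P[0]))
```

With the real sizes, the same arithmetic gives a target array of 60000 rows and 5000 columns, while the last Dense layer outputs 2 columns per row.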
Since the loss function is customized to receive a y_true of shape (batch_size x batch_size) and a y_pred of shape (batch_size x 2), I am not sure why I received the error above. Note that batch_size here is 5000, the same value the notebook uses.
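To check the shape claim independently of Keras, here is a minimal NumPy sketch that mirrors the loss line by line. The name tsne_numpy, the small batch size, and the uniform P are hypothetical, chosen only for illustration. If this runs, the math itself accepts a square y_true and a two-column y_pred, which suggests the ValueError comes from a shape check on the target array rather than from the loss function.

```python
import numpy as np

def tsne_numpy(P, activations, d=2):
    """NumPy mirror of the Keras loss in this issue, for shape checking only."""
    n = activations.shape[0]          # batch size, taken from the input here
    v = d - 1.0                       # degrees of freedom
    eps = 10e-15
    sum_act = np.sum(np.square(activations), axis=1)
    # Pairwise squared distances: |a_i|^2 + |a_j|^2 - 2 * a_i . a_j
    Q = sum_act.reshape(-1, 1) + -2 * activations @ activations.T
    Q = (sum_act + Q) / v
    Q = np.power(1 + Q, -(v + 1) / 2)
    Q *= 1 - np.eye(n)                # zero out the diagonal
    Q /= np.sum(Q)
    Q = np.maximum(Q, eps)
    C = np.log((P + eps) / (Q + eps))
    return np.sum(P * C)

# Hypothetical small batch: 5 points embedded in 2 dimensions.
batch_size = 5
rng = np.random.default_rng(0)
activations = rng.normal(size=(batch_size, 2))       # plays the role of y_pred
P = (1 - np.eye(batch_size)) / (batch_size * (batch_size - 1))  # plays the role of y_true

loss = tsne_numpy(P, activations)
print(P.shape, activations.shape, np.ndim(loss))
```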