Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to apply softmax to a batch of data with variable length #489

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions lib/THCUNN/LenSoftMax.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#include "THCUNN.h"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"

#define LENSOFTMAX_THREADS 128

template <typename T, typename AccumT, typename IndexT>
__global__ void cunn_LenSoftMax_updateOutput_kernel(
T *output, T *input, int nframe, int dim, IndexT *len)
{
__shared__ AccumT buffer[LENSOFTMAX_THREADS+1];
T *input_k = input + blockIdx.x*dim + blockIdx.y + blockIdx.z;
T *output_k = output + blockIdx.x*dim + blockIdx.y + blockIdx.z;

int i_start = threadIdx.x;
int i_end = ScalarConvert<IndexT, int>::to(len[blockIdx.x]);
int i_step = blockDim.x;

// max?
buffer[threadIdx.x] = -THCNumerics<AccumT>::max();
for (int i=i_start; i<i_end; i+=i_step)
{
T z = input_k[i];
AccumT zAcc = ScalarConvert<T, AccumT>::to(z);
if (buffer[threadIdx.x] < zAcc)
buffer[threadIdx.x] = zAcc;
}


__syncthreads();

// reduce
if (threadIdx.x == 0)
{
AccumT max_k = -THCNumerics<AccumT>::max();
for (int i=0; i<blockDim.x; i++)
{
if (max_k < buffer[i])
max_k = buffer[i];
}
buffer[LENSOFTMAX_THREADS] = max_k;
}

__syncthreads();

// sum?
T max_k = ScalarConvert<AccumT, T>::to(buffer[LENSOFTMAX_THREADS]);
buffer[threadIdx.x] = ScalarConvert<int, AccumT>::to(0);
for (int i=i_start; i<i_end; i+=i_step) {
T z = THCNumerics<T>::exp(input_k[i]-max_k);
buffer[threadIdx.x] += ScalarConvert<T, AccumT>::to(z);
output_k[i] = z;
}

__syncthreads();

// reduce
if (threadIdx.x == 0)
{
AccumT sum_k = ScalarConvert<int, AccumT>::to(0);
for (int i=0; i<blockDim.x; i++)
sum_k += buffer[i];
buffer[LENSOFTMAX_THREADS] = sum_k;
}

__syncthreads();

// softmax
T sum_k = ScalarConvert<AccumT, T>::to(buffer[LENSOFTMAX_THREADS]);
for (int i=i_start; i<i_end; i+=i_step)
output_k[i] = output_k[i] / sum_k;
}

template <typename T, typename AccumT, typename IndexT>
__global__ void cunn_LenSoftMax_updateGradInput_kernel(
T *gradInput, T *output, T *gradOutput, int nframe, int dim, IndexT *len)
{
__shared__ AccumT buffer[LENSOFTMAX_THREADS];
T *gradInput_k = gradInput + blockIdx.x*dim + blockIdx.y + blockIdx.z;
T *output_k = output + blockIdx.x*dim + blockIdx.y + blockIdx.z;
T *gradOutput_k = gradOutput + blockIdx.x*dim + blockIdx.y + blockIdx.z;

int i_start = threadIdx.x;
int i_end = ScalarConvert<IndexT, int>::to(len[blockIdx.x]);
int i_step = blockDim.x;

// sum?
buffer[threadIdx.x] = ScalarConvert<int, AccumT>::to(0);
for (int i=i_start; i<i_end; i+=i_step)
buffer[threadIdx.x] += ScalarConvert<T, AccumT>::to(gradOutput_k[i] * output_k[i]);

__syncthreads();

// reduce
if (threadIdx.x == 0)
{
AccumT sum_k = ScalarConvert<int, AccumT>::to(0);
for (int i=0; i<blockDim.x; i++)
sum_k += buffer[i];
buffer[0] = sum_k;
}

__syncthreads();

T sum_k = ScalarConvert<AccumT, T>::to(buffer[0]);
for (int i=i_start; i<i_end; i+=i_step)
gradInput_k[i] = output_k[i] * (gradOutput_k[i] - sum_k);
}

#include "generic/LenSoftMax.cu"
#include "THCGenerateFloatTypes.h"

#undef LENSOFTMAX_THREADS
77 changes: 77 additions & 0 deletions lib/THCUNN/generic/LenSoftMax.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/LenSoftMax.cu"
#else

#include "../common.h"

void THNN_(LenSoftMax_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
THCIndexTensor *len)
{
THCUNN_assertSameGPU(state, 2, input, output);

if ((input->nDimension != 2) && (len->nDimension != 1))
{
THError("2D tensor expected for input, 1D tensor expected for len");
}

input = THCTensor_(newContiguous)(state, input);
THCTensor_(resizeAs)(state, output, input);
THCTensor_(zero)(state, output);
long batchSize = input->size[0], dim = input->size[1];
long blocksY = 1, blocksZ = 1;

dim3 blocks(batchSize, blocksY, blocksZ);
dim3 threads(LENSOFTMAX_THREADS);
cunn_LenSoftMax_updateOutput_kernel<real, accreal, THCIndex_t><<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, output),
THCTensor_(data)(state, input),
batchSize, dim, THCIndexTensor_(data)(state, len)
);
THCudaCheck(cudaGetLastError());

THCTensor_(free)(state, input);
}

void THNN_(LenSoftMax_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
THCIndexTensor *len)
{
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);

if ((gradInput->nDimension != 2) && (len->nDimension != 1))
{
THError("2D tensor expected for input, 1D tensor expected for len");
}


output = THCTensor_(newContiguous)(state, output);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);

THCTensor_(resizeAs)(state, gradInput, output);
THCTensor_(zero)(state, gradInput);
long batchSize = gradInput->size[0], dim = gradInput->size[1];
long blocksY = 1, blocksZ = 1;

dim3 blocks(batchSize, blocksY, blocksZ);
dim3 threads(LENSOFTMAX_THREADS);
cunn_LenSoftMax_updateGradInput_kernel<real, accreal, THCIndex_t><<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, gradInput),
THCTensor_(data)(state, output),
THCTensor_(data)(state, gradOutput),
batchSize, dim, THCIndexTensor_(data)(state, len)
);
THCudaCheck(cudaGetLastError());

THCTensor_(free)(state, gradOutput);
THCTensor_(free)(state, output);
}

#endif
14 changes: 14 additions & 0 deletions lib/THCUNN/generic/THCUNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,20 @@ TH_API void THNN_(SoftMax_updateGradInput)(
THCTensor *gradInput,
THCTensor *output);

TH_API void THNN_(LenSoftMax_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
THCIndexTensor *len);

TH_API void THNN_(LenSoftMax_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
THCIndexTensor *len);

TH_API void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
Expand Down