Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hiennguyen9874 committed Nov 30, 2022
1 parent 51880a8 commit b5f99aa
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
16 changes: 8 additions & 8 deletions plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cu
Original file line number Diff line number Diff line change
Expand Up @@ -590,11 +590,11 @@ pluginStatus_t EfficientNMSCustomDispatch(EfficientNMSCustomParameters param, co
void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream)
{
// Clear Outputs (not all elements will get overwritten by the kernels, so safer to clear everything out)
CSC(cudaMemsetAsync(numDetectionsOutput, 0x00, param.batchSize * sizeof(int), stream));
CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream));
CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream));
CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream));
CSC(cudaMemsetAsync(nmsIndicesOutput, 0xFF, param.batchSize * param.numOutputBoxes * sizeof(int), stream));
CSC(cudaMemsetAsync(numDetectionsOutput, 0x00, param.batchSize * sizeof(int), stream), STATUS_FAILURE);
CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream), STATUS_FAILURE);
CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream), STATUS_FAILURE);
CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE);
CSC(cudaMemsetAsync(nmsIndicesOutput, 0xFF, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE);

// Empty Inputs
if (param.numScoreElements < 1)
Expand All @@ -610,7 +610,7 @@ pluginStatus_t EfficientNMSCustomDispatch(EfficientNMSCustomParameters param, co
int* topOffsetsEndData = topNumData + 2 * param.batchSize;
int* outputIndexData = topNumData + 3 * param.batchSize;
int* outputClassData = topNumData + 4 * param.batchSize;
CSC(cudaMemsetAsync(topNumData, 0x00, countersTotalSize * sizeof(int), stream));
CSC(cudaMemsetAsync(topNumData, 0x00, countersTotalSize * sizeof(int), stream), STATUS_FAILURE);
cudaError_t status = cudaGetLastError();
CSC(status, STATUS_FAILURE);

Expand All @@ -633,9 +633,9 @@ pluginStatus_t EfficientNMSCustomDispatch(EfficientNMSCustomParameters param, co

// Device Specific Properties
int device;
CSC(cudaGetDevice(&device));
CSC(cudaGetDevice(&device), STATUS_FAILURE);
struct cudaDeviceProp properties;
CSC(cudaGetDeviceProperties(&properties, device));
CSC(cudaGetDeviceProperties(&properties, device), STATUS_FAILURE);
if (properties.regsPerBlock >= 65536)
{
// Most Devices
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -598,11 +598,11 @@ pluginStatus_t EfficientNMSLandmarkDispatch(EfficientNMSLandmarkParameters param
cudaStream_t stream)
{
// Clear Outputs (not all elements will get overwritten by the kernels, so safer to clear everything out)
CSC(cudaMemsetAsync(numDetectionsOutput, 0x00, param.batchSize * sizeof(int), stream));
CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream));
CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream));
CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream));
CSC(cudaMemsetAsync(nmsLandmarksOutput, 0x00, param.batchSize * param.numOutputBoxes * 10 * sizeof(T), stream));
CSC(cudaMemsetAsync(numDetectionsOutput, 0x00, param.batchSize * sizeof(int), stream), STATUS_FAILURE);
CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream), STATUS_FAILURE);
CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream), STATUS_FAILURE);
CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE);
CSC(cudaMemsetAsync(nmsLandmarksOutput, 0x00, param.batchSize * param.numOutputBoxes * 10 * sizeof(T), stream), STATUS_FAILURE);

// Empty Inputs
if (param.numScoreElements < 1)
Expand All @@ -618,7 +618,7 @@ pluginStatus_t EfficientNMSLandmarkDispatch(EfficientNMSLandmarkParameters param
int* topOffsetsEndData = topNumData + 2 * param.batchSize;
int* outputIndexData = topNumData + 3 * param.batchSize;
int* outputClassData = topNumData + 4 * param.batchSize;
CSC(cudaMemsetAsync(topNumData, 0x00, countersTotalSize * sizeof(int), stream));
CSC(cudaMemsetAsync(topNumData, 0x00, countersTotalSize * sizeof(int), stream), STATUS_FAILURE);
cudaError_t status = cudaGetLastError();
CSC(status, STATUS_FAILURE);

Expand All @@ -642,9 +642,9 @@ pluginStatus_t EfficientNMSLandmarkDispatch(EfficientNMSLandmarkParameters param

// Device Specific Properties
int device;
CSC(cudaGetDevice(&device));
CSC(cudaGetDevice(&device), STATUS_FAILURE);
struct cudaDeviceProp properties;
CSC(cudaGetDeviceProperties(&properties, device));
CSC(cudaGetDeviceProperties(&properties, device), STATUS_FAILURE);
if (properties.regsPerBlock >= 65536)
{
// Most Devices
Expand Down
2 changes: 1 addition & 1 deletion plugin/roIAlignPlugin/roIAlignPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class RoIAlignDynamicPlugin : public IPluginV2DynamicExt
bool mAligned;
};

class RoIAlignBasePluginCreator : public BaseCreator
class RoIAlignBasePluginCreator : public nvinfer1::pluginInternal::BaseCreator
{
public:
RoIAlignBasePluginCreator() noexcept;
Expand Down

0 comments on commit b5f99aa

Please sign in to comment.