Skip to content

Commit 6930508

Browse files
committed
Build augmented CSR directly instead of building CSC and transposing it
1 parent 6528706 commit 6930508

File tree

1 file changed

+38
-47
lines changed

1 file changed

+38
-47
lines changed

cpp/src/dual_simplex/barrier.cu

Lines changed: 38 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -425,76 +425,67 @@ class iteration_data_t {
425425
const f_t primal_perturb = 1e-6;
426426
if (first_call) {
427427
i_t new_nnz = 2 * nnzA + n + m + nnzQ;
428-
csc_matrix_t<i_t, f_t> augmented(n + m, n + m, new_nnz);
428+
csr_matrix_t<i_t, f_t> augmented_CSR(n + m, n + m, new_nnz);
429+
std::vector<i_t> augmented_diagonal_indices(n + m, -1);
429430
i_t q = 0;
430431
i_t off_diag_Qnz = 0;
431-
for (i_t j = 0; j < n; j++) {
432-
cuopt_assert(std::isfinite(diag[j]), "diag[j] is not finite");
433-
augmented.col_start[j] = q;
432+
433+
for (i_t i = 0; i < n; i++) {
434+
augmented_CSR.row_start[i] = q;
434435
if (nnzQ == 0) {
435-
// augmented_diagonal_indices[j] = q;
436-
augmented.i[q] = j;
437-
augmented.x[q++] = -diag[j] - dual_perturb;
436+
augmented_diagonal_indices[i] = q;
437+
augmented_CSR.j[q] = i;
438+
augmented_CSR.x[q++] = -diag[i] - dual_perturb;
438439
} else {
439-
const i_t q_col_beg = Q.col_start[j];
440-
const i_t q_col_end = Q.col_start[j + 1];
440+
// Q is symmetric
441+
const i_t q_col_beg = Q.col_start[i];
442+
const i_t q_col_end = Q.col_start[i + 1];
441443
bool has_diagonal = false;
442444
for (i_t p = q_col_beg; p < q_col_end; ++p) {
443-
augmented.i[q] = Q.i[p];
444-
if (Q.i[p] == j) {
445-
has_diagonal = true;
446-
// augmented_diagonal_indices[j] = q;
447-
augmented.x[q++] = -Q.x[p] - diag[j] - dual_perturb;
445+
augmented_CSR.j[q] = Q.i[p];
446+
if (Q.i[p] == i) {
447+
has_diagonal = true;
448+
augmented_diagonal_indices[i] = q;
449+
augmented_CSR.x[q++] = -Q.x[p] - diag[i] - dual_perturb;
448450
} else {
449451
off_diag_Qnz++;
450-
augmented.x[q++] = -Q.x[p];
452+
augmented_CSR.x[q++] = -Q.x[p];
451453
}
452454
}
453455
if (!has_diagonal) {
454-
// augmented_diagonal_indices[j] = q;
455-
augmented.i[q] = j;
456-
augmented.x[q++] = -diag[j] - dual_perturb;
456+
augmented_diagonal_indices[i] = q;
457+
augmented_CSR.j[q] = i;
458+
augmented_CSR.x[q++] = -diag[i] - dual_perturb;
457459
}
458460
}
459-
const i_t col_beg = A.col_start[j];
460-
const i_t col_end = A.col_start[j + 1];
461+
// AT block, we can use A in csc directly
462+
const i_t col_beg = A.col_start[i];
463+
const i_t col_end = A.col_start[i + 1];
461464
for (i_t p = col_beg; p < col_end; ++p) {
462-
augmented.i[q] = n + A.i[p];
463-
augmented.x[q++] = A.x[p];
465+
augmented_CSR.j[q] = A.i[p] + n;
466+
augmented_CSR.x[q++] = A.x[p];
464467
}
465468
}
466-
settings_.log.debug("augmented nz %d predicted %d\n", q, off_diag_Qnz + nnzA + n);
469+
467470
for (i_t k = n; k < n + m; ++k) {
468-
augmented.col_start[k] = q;
469-
const i_t l = k - n;
470-
const i_t col_beg = AT.col_start[l];
471-
const i_t col_end = AT.col_start[l + 1];
471+
// A block, we can use AT in csc directly
472+
augmented_CSR.row_start[k] = q;
473+
const i_t l = k - n;
474+
const i_t col_beg = AT.col_start[l];
475+
const i_t col_end = AT.col_start[l + 1];
472476
for (i_t p = col_beg; p < col_end; ++p) {
473-
augmented.i[q] = AT.i[p];
474-
augmented.x[q++] = AT.x[p];
477+
augmented_CSR.j[q] = AT.i[p];
478+
augmented_CSR.x[q++] = AT.x[p];
475479
}
476-
// augmented_diagonal_indices[k] = q;
477-
augmented.i[q] = k;
478-
augmented.x[q++] = primal_perturb;
480+
augmented_diagonal_indices[k] = q;
481+
augmented_CSR.j[q] = k;
482+
augmented_CSR.x[q++] = primal_perturb;
479483
}
480-
augmented.col_start[n + m] = q;
484+
augmented_CSR.row_start[n + m] = q;
485+
settings_.log.debug("augmented nz %d predicted %d\n", q, off_diag_Qnz + nnzA + n);
481486
cuopt_assert(q == 2 * nnzA + n + m + off_diag_Qnz, "augmented nnz != predicted");
482487
cuopt_assert(A.col_start[n] == AT.col_start[m], "A nz != AT nz");
483488

484-
csr_matrix_t<i_t, f_t> augmented_CSR(n + m, n + m, augmented.col_start[n + m]);
485-
augmented.to_compressed_row(augmented_CSR);
486-
487-
std::vector<i_t> augmented_diagonal_indices(augmented_CSR.n, -1);
488-
// Extract the diagonal indices from augmented_CSR
489-
for (i_t row = 0; row < augmented_CSR.n; ++row) {
490-
for (i_t k = augmented_CSR.row_start[row]; k < augmented_CSR.row_start[row + 1]; ++k) {
491-
if (augmented_CSR.j[k] == row) {
492-
augmented_diagonal_indices[row] = k;
493-
break;
494-
}
495-
}
496-
}
497-
498489
device_augmented.copy(augmented_CSR, handle_ptr->get_stream());
499490
d_augmented_diagonal_indices_.resize(augmented_diagonal_indices.size(),
500491
handle_ptr->get_stream());

0 commit comments

Comments
 (0)