@@ -425,76 +425,67 @@ class iteration_data_t {
425425 const f_t primal_perturb = 1e-6 ;
426426 if (first_call) {
427427 i_t new_nnz = 2 * nnzA + n + m + nnzQ;
428- csc_matrix_t <i_t , f_t > augmented (n + m, n + m, new_nnz);
428+ csr_matrix_t <i_t , f_t > augmented_CSR (n + m, n + m, new_nnz);
429+ std::vector<i_t > augmented_diagonal_indices (n + m, -1 );
429430 i_t q = 0 ;
430431 i_t off_diag_Qnz = 0 ;
431- for ( i_t j = 0 ; j < n; j++) {
432- cuopt_assert ( std::isfinite (diag[j]), " diag[j] is not finite " );
433- augmented. col_start [j ] = q;
432+
433+ for ( i_t i = 0 ; i < n; i++) {
434+ augmented_CSR. row_start [i ] = q;
434435 if (nnzQ == 0 ) {
435- // augmented_diagonal_indices[j ] = q;
436- augmented. i [q] = j ;
437- augmented .x [q++] = -diag[j ] - dual_perturb;
436+ augmented_diagonal_indices[i ] = q;
437+ augmented_CSR. j [q] = i ;
438+ augmented_CSR .x [q++] = -diag[i ] - dual_perturb;
438439 } else {
439- const i_t q_col_beg = Q.col_start [j];
440- const i_t q_col_end = Q.col_start [j + 1 ];
440+ // Q is symmetric
441+ const i_t q_col_beg = Q.col_start [i];
442+ const i_t q_col_end = Q.col_start [i + 1 ];
441443 bool has_diagonal = false ;
442444 for (i_t p = q_col_beg; p < q_col_end; ++p) {
443- augmented. i [q] = Q.i [p];
444- if (Q.i [p] == j ) {
445- has_diagonal = true ;
446- // augmented_diagonal_indices[j ] = q;
447- augmented .x [q++] = -Q.x [p] - diag[j ] - dual_perturb;
445+ augmented_CSR. j [q] = Q.i [p];
446+ if (Q.i [p] == i ) {
447+ has_diagonal = true ;
448+ augmented_diagonal_indices[i ] = q;
449+ augmented_CSR .x [q++] = -Q.x [p] - diag[i ] - dual_perturb;
448450 } else {
449451 off_diag_Qnz++;
450- augmented .x [q++] = -Q.x [p];
452+ augmented_CSR .x [q++] = -Q.x [p];
451453 }
452454 }
453455 if (!has_diagonal) {
454- // augmented_diagonal_indices[j ] = q;
455- augmented. i [q] = j ;
456- augmented .x [q++] = -diag[j ] - dual_perturb;
456+ augmented_diagonal_indices[i ] = q;
457+ augmented_CSR. j [q] = i ;
458+ augmented_CSR .x [q++] = -diag[i ] - dual_perturb;
457459 }
458460 }
459- const i_t col_beg = A.col_start [j];
460- const i_t col_end = A.col_start [j + 1 ];
461+ // AT block, we can use A in csc directly
462+ const i_t col_beg = A.col_start [i];
463+ const i_t col_end = A.col_start [i + 1 ];
461464 for (i_t p = col_beg; p < col_end; ++p) {
462- augmented. i [q] = n + A.i [p];
463- augmented .x [q++] = A.x [p];
465+ augmented_CSR. j [q] = A.i [p] + n ;
466+ augmented_CSR .x [q++] = A.x [p];
464467 }
465468 }
466- settings_. log . debug ( " augmented nz %d predicted %d \n " , q, off_diag_Qnz + nnzA + n);
469+
467470 for (i_t k = n; k < n + m; ++k) {
468- augmented.col_start [k] = q;
469- const i_t l = k - n;
470- const i_t col_beg = AT.col_start [l];
471- const i_t col_end = AT.col_start [l + 1 ];
471+ // A block, we can use AT in csc directly
472+ augmented_CSR.row_start [k] = q;
473+ const i_t l = k - n;
474+ const i_t col_beg = AT.col_start [l];
475+ const i_t col_end = AT.col_start [l + 1 ];
472476 for (i_t p = col_beg; p < col_end; ++p) {
473- augmented. i [q] = AT.i [p];
474- augmented .x [q++] = AT.x [p];
477+ augmented_CSR. j [q] = AT.i [p];
478+ augmented_CSR .x [q++] = AT.x [p];
475479 }
476- // augmented_diagonal_indices[k] = q;
477- augmented. i [q] = k;
478- augmented .x [q++] = primal_perturb;
480+ augmented_diagonal_indices[k] = q;
481+ augmented_CSR. j [q] = k;
482+ augmented_CSR .x [q++] = primal_perturb;
479483 }
480- augmented.col_start [n + m] = q;
484+ augmented_CSR.row_start [n + m] = q;
485+ settings_.log .debug (" augmented nz %d predicted %d\n " , q, off_diag_Qnz + nnzA + n);
481486 cuopt_assert (q == 2 * nnzA + n + m + off_diag_Qnz, " augmented nnz != predicted" );
482487 cuopt_assert (A.col_start [n] == AT.col_start [m], " A nz != AT nz" );
483488
484- csr_matrix_t <i_t , f_t > augmented_CSR (n + m, n + m, augmented.col_start [n + m]);
485- augmented.to_compressed_row (augmented_CSR);
486-
487- std::vector<i_t > augmented_diagonal_indices (augmented_CSR.n , -1 );
488- // Extract the diagonal indices from augmented_CSR
489- for (i_t row = 0 ; row < augmented_CSR.n ; ++row) {
490- for (i_t k = augmented_CSR.row_start [row]; k < augmented_CSR.row_start [row + 1 ]; ++k) {
491- if (augmented_CSR.j [k] == row) {
492- augmented_diagonal_indices[row] = k;
493- break ;
494- }
495- }
496- }
497-
498489 device_augmented.copy (augmented_CSR, handle_ptr->get_stream ());
499490 d_augmented_diagonal_indices_.resize (augmented_diagonal_indices.size (),
500491 handle_ptr->get_stream ());
0 commit comments