Skip to content

Commit d1f3257

Browse files
committed
Fixed an issue with the omp swap gate implementation introduced last time and more loop unrolling
1 parent a484e17 commit d1f3257

File tree

1 file changed

+166
-36
lines changed

1 file changed

+166
-36
lines changed

QCSim/QubitRegisterCalculator.h

Lines changed: 166 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -333,13 +333,31 @@ namespace QC {
333333

334334
swapStorage = false;
335335

336-
for (size_t state = std::min(qubitBit, ctrlQubitBit); state < NrBasisStates; ++state)
336+
for (size_t state = 0; state < NrBasisStates; ++state)
337337
{
338338
if ((state & qubitBit) == 0 && (state & ctrlQubitBit) != 0)
339339
{
340340
const size_t swapstate = state ^ orqubits;
341341
std::swap(registerStorage(state), registerStorage(swapstate));
342342
}
343+
++state;
344+
if ((state & qubitBit) == 0 && (state & ctrlQubitBit) != 0)
345+
{
346+
const size_t swapstate = state ^ orqubits;
347+
std::swap(registerStorage(state), registerStorage(swapstate));
348+
}
349+
++state;
350+
if ((state & qubitBit) == 0 && (state & ctrlQubitBit) != 0)
351+
{
352+
const size_t swapstate = state ^ orqubits;
353+
std::swap(registerStorage(state), registerStorage(swapstate));
354+
}
355+
++state;
356+
if ((state & qubitBit) == 0 && (state & ctrlQubitBit) != 0)
357+
{
358+
const size_t swapstate = state ^ orqubits;
359+
std::swap(registerStorage(state), registerStorage(swapstate));
360+
}
343361
}
344362
}
345363

@@ -356,40 +374,79 @@ namespace QC {
356374
// TODO: is it worth parallelizing the swap gate?
357375
#pragma omp parallel for
358376
//num_threads(processor_count) schedule(static, blockSize)
359-
for (long long int state = std::min(qubitBit, ctrlQubitBit); state < static_cast<long long int>(NrBasisStates); state += 2)
377+
for (long long int state = 0; state < static_cast<long long int>(NrBasisStates); state += 4)
360378
{
361379
if ((state & qubitBit) == 0 && (state & ctrlQubitBit) != 0)
362380
{
363381
const size_t swapstate = state ^ orqubits;
364382
std::swap(registerStorage(state), registerStorage(swapstate));
365383
}
384+
long long int nextState = state + 1;
385+
if ((nextState & qubitBit) == 0 && (nextState & ctrlQubitBit) != 0)
386+
{
387+
const size_t swapstate = nextState ^ orqubits;
388+
std::swap(registerStorage(nextState), registerStorage(swapstate));
389+
}
390+
++nextState;
391+
if ((nextState & qubitBit) == 0 && (nextState & ctrlQubitBit) != 0)
392+
{
393+
const size_t swapstate = nextState ^ orqubits;
394+
std::swap(registerStorage(nextState), registerStorage(swapstate));
395+
}
396+
++nextState;
397+
if ((nextState & qubitBit) == 0 && (nextState & ctrlQubitBit) != 0)
398+
{
399+
const size_t swapstate = nextState ^ orqubits;
400+
std::swap(registerStorage(nextState), registerStorage(swapstate));
401+
}
366402
}
367403
}
368404

369405
static inline void ApplyDiagonalControlGate(VectorClass& registerStorage, const MatrixClass& gateMatrix, const size_t qubitBit, const size_t ctrlQubitBit, const size_t NrBasisStates, bool& swapStorage)
370406
{
371407
swapStorage = false;
372408

373-
for (size_t state = ctrlQubitBit; state < NrBasisStates; ++state)
409+
for (size_t state = 0; state < NrBasisStates; ++state)
374410
{
375411
if ((state & ctrlQubitBit) != 0)
376412
registerStorage(state) *= state & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
413+
++state;
414+
if ((state & ctrlQubitBit) != 0)
415+
registerStorage(state) *= state & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
416+
++state;
417+
if ((state & ctrlQubitBit) != 0)
418+
registerStorage(state) *= state & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
419+
++state;
420+
if ((state & ctrlQubitBit) != 0)
421+
registerStorage(state) *= state & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
377422
}
378423
}
379424

380425
static inline void ApplyAntidiagonalControlGate(const VectorClass& registerStorage, VectorClass& resultsStorage, const MatrixClass& gateMatrix, const size_t qubitBit, const size_t ctrlQubitBit, const size_t NrBasisStates)
381426
{
382427
const size_t notQubitBit = ~qubitBit;
383428

384-
for (size_t state = 0; state < ctrlQubitBit; ++state)
385-
resultsStorage(state) = registerStorage(state);
386-
387-
for (size_t state = ctrlQubitBit; state < NrBasisStates; ++state)
429+
for (size_t state = 0; state < NrBasisStates; ++state)
388430
{
389431
if ((state & ctrlQubitBit) == 0)
390432
resultsStorage(state) = registerStorage(state);
391433
else
392434
resultsStorage(state) = state & qubitBit ? gateMatrix(3, 2) * registerStorage(state & notQubitBit) : gateMatrix(2, 3) * registerStorage(state | qubitBit);
435+
++state;
436+
if ((state & ctrlQubitBit) == 0)
437+
resultsStorage(state) = registerStorage(state);
438+
else
439+
resultsStorage(state) = state & qubitBit ? gateMatrix(3, 2) * registerStorage(state & notQubitBit) : gateMatrix(2, 3) * registerStorage(state | qubitBit);
440+
++state;
441+
if ((state & ctrlQubitBit) == 0)
442+
resultsStorage(state) = registerStorage(state);
443+
else
444+
resultsStorage(state) = state & qubitBit ? gateMatrix(3, 2) * registerStorage(state & notQubitBit) : gateMatrix(2, 3) * registerStorage(state | qubitBit);
445+
++state;
446+
if ((state & ctrlQubitBit) == 0)
447+
resultsStorage(state) = registerStorage(state);
448+
else
449+
resultsStorage(state) = state & qubitBit ? gateMatrix(3, 2) * registerStorage(state & notQubitBit) : gateMatrix(2, 3) * registerStorage(state | qubitBit);
393450
}
394451
}
395452

@@ -398,10 +455,7 @@ namespace QC {
398455
const size_t notQubitBit = ~qubitBit;
399456
const size_t orqubits = qubitBit | ctrlQubitBit;
400457

401-
for (size_t state = 0; state < ctrlQubitBit; ++state)
402-
resultsStorage(state) = registerStorage(state);
403-
404-
for (size_t state = ctrlQubitBit; state < NrBasisStates; ++state)
458+
for (size_t state = 0; state < NrBasisStates; ++state)
405459
{
406460
if ((state & ctrlQubitBit) == 0)
407461
resultsStorage(state) = registerStorage(state);
@@ -413,6 +467,39 @@ namespace QC {
413467
resultsStorage(state) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
414468
gateMatrix(row, 3) * registerStorage(state | orqubits); // state | ctrlQubitBit | qubitBit : 11
415469
}
470+
++state;
471+
if ((state & ctrlQubitBit) == 0)
472+
resultsStorage(state) = registerStorage(state);
473+
else
474+
{
475+
const size_t row = 2 | (state & qubitBit ? 1 : 0);
476+
const size_t m = state & notQubitBit;
477+
478+
resultsStorage(state) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
479+
gateMatrix(row, 3) * registerStorage(state | orqubits); // state | ctrlQubitBit | qubitBit : 11
480+
}
481+
++state;
482+
if ((state & ctrlQubitBit) == 0)
483+
resultsStorage(state) = registerStorage(state);
484+
else
485+
{
486+
const size_t row = 2 | (state & qubitBit ? 1 : 0);
487+
const size_t m = state & notQubitBit;
488+
489+
resultsStorage(state) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
490+
gateMatrix(row, 3) * registerStorage(state | orqubits); // state | ctrlQubitBit | qubitBit : 11
491+
}
492+
++state;
493+
if ((state & ctrlQubitBit) == 0)
494+
resultsStorage(state) = registerStorage(state);
495+
else
496+
{
497+
const size_t row = 2 | (state & qubitBit ? 1 : 0);
498+
const size_t m = state & notQubitBit;
499+
500+
resultsStorage(state) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
501+
gateMatrix(row, 3) * registerStorage(state | orqubits); // state | ctrlQubitBit | qubitBit : 11
502+
}
416503
}
417504
}
418505

@@ -427,12 +514,19 @@ namespace QC {
427514

428515
#pragma omp parallel for
429516
//num_threads(processor_count) schedule(static, blockSize)
430-
for (long long int state = ctrlQubitBit; state < static_cast<long long int>(NrBasisStates); ++state)
517+
for (long long int state = 0; state < static_cast<long long int>(NrBasisStates); state += 4)
431518
{
432-
if ((state & ctrlQubitBit) == 0)
433-
continue;
434-
435-
registerStorage(state) *= state & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
519+
if ((state & ctrlQubitBit) != 0)
520+
registerStorage(state) *= state & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
521+
long long int nextState = state + 1;
522+
if ((nextState & ctrlQubitBit) != 0)
523+
registerStorage(nextState) *= nextState & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
524+
++nextState;
525+
if ((nextState & ctrlQubitBit) != 0)
526+
registerStorage(nextState) *= nextState & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
527+
++nextState;
528+
if ((nextState & ctrlQubitBit) != 0)
529+
registerStorage(nextState) *= nextState & qubitBit ? gateMatrix(3, 3) : gateMatrix(2, 2);
436530
}
437531
}
438532

@@ -444,20 +538,30 @@ namespace QC {
444538

445539
const size_t notQubitBit = ~qubitBit;
446540

447-
for (size_t state = 0; state < ctrlQubitBit; ++state)
448-
resultsStorage(state) = registerStorage(state);
449-
450541
#pragma omp parallel for
451542
//num_threads(processor_count) schedule(static, blockSize)
452-
for (long long int state = ctrlQubitBit; state < static_cast<long long int>(NrBasisStates); ++state)
543+
for (long long int state = 0; state < static_cast<long long int>(NrBasisStates); state += 4)
453544
{
454545
if ((state & ctrlQubitBit) == 0)
455-
{
456546
resultsStorage(state) = registerStorage(state);
457-
continue;
458-
}
547+
else
548+
resultsStorage(state) = state & qubitBit ? gateMatrix(3, 2) * registerStorage(state & notQubitBit) : gateMatrix(2, 3) * registerStorage(state | qubitBit);
459549

460-
resultsStorage(state) = state & qubitBit ? gateMatrix(3, 2) * registerStorage(state & notQubitBit) : gateMatrix(2, 3) * registerStorage(state | qubitBit);
550+
long long int nextState = state + 1;
551+
if ((nextState & ctrlQubitBit) == 0)
552+
resultsStorage(nextState) = registerStorage(nextState);
553+
else
554+
resultsStorage(nextState) = nextState & qubitBit ? gateMatrix(3, 2) * registerStorage(nextState & notQubitBit) : gateMatrix(2, 3) * registerStorage(nextState | qubitBit);
555+
++nextState;
556+
if ((nextState & ctrlQubitBit) == 0)
557+
resultsStorage(nextState) = registerStorage(nextState);
558+
else
559+
resultsStorage(nextState) = nextState & qubitBit ? gateMatrix(3, 2) * registerStorage(nextState & notQubitBit) : gateMatrix(2, 3) * registerStorage(nextState | qubitBit);
560+
++nextState;
561+
if ((nextState & ctrlQubitBit) == 0)
562+
resultsStorage(nextState) = registerStorage(nextState);
563+
else
564+
resultsStorage(nextState) = nextState & qubitBit ? gateMatrix(3, 2) * registerStorage(nextState & notQubitBit) : gateMatrix(2, 3) * registerStorage(nextState | qubitBit);
461565
}
462566
}
463567

@@ -470,24 +574,50 @@ namespace QC {
470574
const size_t notQubitBit = ~qubitBit;
471575
const size_t orqubits = qubitBit | ctrlQubitBit;
472576

473-
for (size_t state = 0; state < ctrlQubitBit; ++state)
474-
resultsStorage(state) = registerStorage(state);
475-
476577
#pragma omp parallel for
477578
//num_threads(processor_count) schedule(static, blockSize)
478-
for (long long int state = ctrlQubitBit; state < static_cast<long long int>(NrBasisStates); ++state)
579+
for (long long int state = 0; state < static_cast<long long int>(NrBasisStates); state += 4)
479580
{
480581
if ((state & ctrlQubitBit) == 0)
481-
{
482582
resultsStorage(state) = registerStorage(state);
483-
continue;
484-
}
485-
486-
const size_t row = 2 | (state & qubitBit ? 1 : 0);
487-
const size_t m = state & notQubitBit;
583+
else
584+
{
585+
const size_t row = 2 | (state & qubitBit ? 1 : 0);
586+
const size_t m = state & notQubitBit;
488587

489-
resultsStorage(state) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
490-
gateMatrix(row, 3) * registerStorage(state | orqubits); // state | ctrlQubitBit | qubitBit : 11
588+
resultsStorage(state) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
589+
gateMatrix(row, 3) * registerStorage(state | orqubits); // state | ctrlQubitBit | qubitBit : 11
590+
}
591+
long long int nextState = state + 1;
592+
if ((nextState & ctrlQubitBit) == 0)
593+
resultsStorage(nextState) = registerStorage(nextState);
594+
else
595+
{
596+
const size_t row = 2 | (nextState & qubitBit ? 1 : 0);
597+
const size_t m = nextState & notQubitBit;
598+
resultsStorage(nextState) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
599+
gateMatrix(row, 3) * registerStorage(nextState | orqubits); // state | ctrlQubitBit | qubitBit : 11
600+
}
601+
++nextState;
602+
if ((nextState & ctrlQubitBit) == 0)
603+
resultsStorage(nextState) = registerStorage(nextState);
604+
else
605+
{
606+
const size_t row = 2 | (nextState & qubitBit ? 1 : 0);
607+
const size_t m = nextState & notQubitBit;
608+
resultsStorage(nextState) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
609+
gateMatrix(row, 3) * registerStorage(nextState | orqubits); // state | ctrlQubitBit | qubitBit : 11
610+
}
611+
++nextState;
612+
if ((nextState & ctrlQubitBit) == 0)
613+
resultsStorage(nextState) = registerStorage(nextState);
614+
else
615+
{
616+
const size_t row = 2 | (nextState & qubitBit ? 1 : 0);
617+
const size_t m = nextState & notQubitBit;
618+
resultsStorage(nextState) = gateMatrix(row, 2) * registerStorage(m | ctrlQubitBit) + // state & ~qubitBit | ctrlQubitBit : 10
619+
gateMatrix(row, 3) * registerStorage(nextState | orqubits); // state | ctrlQubitBit | qubitBit : 11
620+
}
491621
}
492622
}
493623

0 commit comments

Comments
 (0)