forked from rioyokotalab/nbd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
profile.hxx
59 lines (46 loc) · 1.45 KB
/
profile.hxx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#pragma once
#include <cstdint>
#include <cstddef>
/// Accumulates estimated floating-point operation counts and memory traffic
/// (in bytes) reported by record_factor(), split by kernel kind (GEMM, POTRF,
/// TRSM) and by data kind (matrix, basis, vector).
///
/// All counters are plain (non-atomic) int64_t members; no synchronization is
/// performed, so concurrent use of one instance needs external locking.
struct Profile {
  int64_t gemm_flops = 0;    // flops charged to GEMM-type work
  int64_t potrf_flops = 0;   // flops charged to Cholesky (POTRF) work
  int64_t trsm_flops = 0;    // flops charged to triangular-solve (TRSM) work
  int64_t bytes_matrix = 0;  // bytes of matrix blocks (doubles)
  int64_t bytes_basis = 0;   // bytes of basis blocks (doubles)
  int64_t bytes_vector = 0;  // bytes of vectors (doubles)

  /// Record the estimated cost of one factorization call.
  ///
  /// Dense case (dimr == 0 && nnz == 1): charges dimn^3/3 POTRF flops, one
  /// dimn x dimn matrix of doubles and one length-dimn vector of doubles.
  ///
  /// Otherwise: charges GEMM flops for the terms named fgemm, fsplit and fschur
  /// below, POTRF flops for dimr^3/3 per ndiag, TRSM flops for dimn*dimr^2 per
  /// ndiag, and bytes for nnz dimn x dimn matrix blocks, nrows dimn x dimn
  /// basis blocks and nrows length-dimn vectors (all doubles).
  ///
  /// Flop terms use integer division by 3, so each call truncates
  /// independently before accumulation.
  ///
  /// @param dimr  presumably the eliminated/reduced size per block; dimn - dimr
  ///              is used as a size, so it is assumed dimr <= dimn — TODO confirm
  /// @param dimn  block dimension; matrix and basis blocks are dimn x dimn
  /// @param nnz   number of dimn x dimn matrix blocks to charge
  /// @param ndiag multiplier for the split/Cholesky/TRSM/Schur terms
  ///              (presumably the number of diagonal blocks — verify with caller)
  /// @param nrows multiplier for basis and vector bytes; unused in the dense case
  void record_factor(int64_t dimr, int64_t dimn, int64_t nnz, int64_t ndiag, int64_t nrows) {
    // Keep the element size signed so every byte count is computed entirely in
    // int64_t instead of being promoted to size_t and narrowed back.
    const int64_t word = static_cast<int64_t>(sizeof(double));

    if (dimr == 0 && nnz == 1) {
      potrf_flops += dimn * dimn * dimn / 3;
      bytes_matrix += dimn * dimn * word;
      bytes_vector += dimn * word;
    }
    else {
      const int64_t dims = dimn - dimr;
      const int64_t fgemm = 4 * dimn * dimn * dimn * nnz;
      const int64_t fsplit = 2 * dimn * dimr * (dimn + dimr) * ndiag;
      const int64_t fchol = dimr * dimr * dimr * ndiag / 3;
      const int64_t ftrsm = dimn * dimr * dimr * ndiag;
      const int64_t fschur = 2 * dims * dims * dimr * ndiag;

      gemm_flops += fgemm + fsplit + fschur;
      potrf_flops += fchol;
      trsm_flops += ftrsm;
      bytes_matrix += dimn * dimn * nnz * word;
      bytes_basis += dimn * dimn * nrows * word;
      bytes_vector += dimn * nrows * word;
    }
  }

  /// Copy the accumulated totals out, then reset every counter to zero.
  ///
  /// @param flops receives {gemm_flops, potrf_flops, trsm_flops}; must have
  ///              room for 3 elements
  /// @param bytes receives {bytes_matrix, bytes_basis, bytes_vector}; must have
  ///              room for 3 elements
  void get_profile(int64_t flops[3], int64_t bytes[3]) {
    flops[0] = gemm_flops;
    flops[1] = potrf_flops;
    flops[2] = trsm_flops;
    bytes[0] = bytes_matrix;
    bytes[1] = bytes_basis;
    bytes[2] = bytes_vector;
    // Value-initialize to reapply every default member initializer (all
    // zeros); stays correct if new counters are added later.
    *this = Profile();
  }
};