forked from lloda/ra-ra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbench-gemv.C
140 lines (118 loc) · 4.88 KB
/
bench-gemv.C
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// (c) Daniel Llorens - 2017
// This library is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.
/// @file bench-gemv.C
/// @brief Benchmark for BLAS-2 type ops
// These operations aren't really part of the ET framework, just standalone
// functions.
// Cf bench-gemm.C for BLAS-3 type ops.
#include <iostream>
#include <iomanip>
#include "ra/test.H"
#include "ra/complex.H"
#include "ra/format.H"
#include "ra/big.H"
#include "ra/operators.H"
#include "ra/io.H"
#include "ra/bench.H"
using std::cout, std::endl, std::setw, std::setprecision;
using ra::Small, ra::View, ra::Unique, ra::ra_traits;
using real = double;
// -------------------
// Hand-written variants of the default gemv/gevm; they should be slower than the defaults if the defaults are well chosen.
// TODO compare with external GEMV/GEVM
// -------------------
// Selects whether a benchmark uses its matrix as stored (NOTRANS) or through transpose<1, 0> (TRANS).
enum trans_t { NOTRANS = 0, TRANS = 1 };
int main()
{
TestRecorder tr(std::cout);
auto gemv_i = [&](auto const & a, auto const & b)
{
int const M = a.size(0);
ra::Big<decltype(a(0, 0)*b(0)), 1> c({M}, ra::unspecified);
for (int i=0; i<M; ++i) {
c(i) = dot(a(i), b);
}
return c;
};
auto gemv_j = [&](auto const & a, auto const & b)
{
int const M = a.size(0);
int const N = a.size(1);
ra::Big<decltype(a(0, 0)*b(0)), 1> c({M}, 0.);
for (int j=0; j<N; ++j) {
c += a(ra::all, j)*b(j);
}
return c;
};
auto gevm_j = [&](auto const & b, auto const & a)
{
int const N = a.size(1);
ra::Big<decltype(b(0)*a(0, 0)), 1> c({N}, ra::unspecified);
for (int j=0; j<N; ++j) {
c(j) = dot(b, a(ra::all, j));
}
return c;
};
auto gevm_i = [&](auto const & b, auto const & a)
{
int const M = a.size(0);
int const N = a.size(1);
ra::Big<decltype(b(0)*a(0, 0)), 1> c({N}, 0.);
for (int i=0; i<M; ++i) {
c += b(i)*a(i);
}
return c;
};
auto bench_all = [&](int k, int m, int n, int reps)
{
auto bench_mv = [&tr, &m, &n, &reps](auto && f, char const * tag, trans_t t)
{
ra::Big<real, 2> aa({m, n}, ra::_0-ra::_1);
auto a = t==TRANS ? transpose<1, 0>(aa) : aa();
ra::Big<real, 1> b({a.size(1)}, 1-2*ra::_0);
ra::Big<real, 1> ref = gemv(a, b);
ra::Big<real, 1> c;
auto bv = Benchmark().repeats(reps).runs(3).run([&]() { c = f(a, b); });
tr.info(std::setw(5), std::fixed, Benchmark::avg(bv)/(m*n)/1e-9, " ns [",
Benchmark::stddev(bv)/(m*n)/1e-9 ,"] ", tag, t==TRANS ? " [T]" : " [N]").test_eq(ref, c);
};
auto bench_vm = [&tr, &m, &n, &reps](auto && f, char const * tag, trans_t t)
{
ra::Big<real, 2> aa({m, n}, ra::_0-ra::_1);
auto a = t==TRANS ? transpose<1, 0>(aa) : aa();
ra::Big<real, 1> b({a.size(0)}, 1-2*ra::_0);
ra::Big<real, 1> ref = gevm(b, a);
ra::Big<real, 1> c;
auto bv = Benchmark().repeats(reps).runs(4).run([&]() { c = f(b, a); });
tr.info(std::setw(5), std::fixed, Benchmark::avg(bv)/(m*n)/1e-9, " ns [",
Benchmark::stddev(bv)/(m*n)/1e-9 ,"] ", tag, t==TRANS ? " [T]" : " [N]").test_eq(ref, c);
};
tr.section(m, " x ", n, " times ", reps);
// some variants are way too slow to check with larger arrays.
if (k>0) {
bench_mv(gemv_i, "mv i", NOTRANS);
bench_mv(gemv_i, "mv i", TRANS);
bench_mv(gemv_j, "mv j", NOTRANS);
bench_mv(gemv_j, "mv j", TRANS);
bench_mv([&](auto const & a, auto const & b) { return gemv(a, b); }, "mv default", NOTRANS);
bench_mv([&](auto const & a, auto const & b) { return gemv(a, b); }, "mv default", TRANS);
bench_vm(gevm_i, "vm i", NOTRANS);
bench_vm(gevm_i, "vm i", TRANS);
bench_vm(gevm_j, "vm j", NOTRANS);
bench_vm(gevm_j, "vm j", TRANS);
bench_vm([&](auto const & a, auto const & b) { return gevm(a, b); }, "vm default", NOTRANS);
bench_vm([&](auto const & a, auto const & b) { return gevm(a, b); }, "vm default", TRANS);
}
};
bench_all(3, 10, 10, 10000);
bench_all(3, 100, 100, 100);
bench_all(3, 500, 500, 1);
bench_all(3, 10000, 1000, 1);
bench_all(3, 1000, 10000, 1);
bench_all(3, 100000, 100, 1);
bench_all(3, 100, 100000, 1);
return tr.summary();
}