-
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcca.go
67 lines (60 loc) · 1.78 KB
/
cca.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
package cca
import (
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
)
// CCA performs a canonical correlation analysis of the input data x
// and y, columns of which should be interpretable as two sets of
// measurements on the same observations (rows). These observations
// are optionally weighted by weights.
//
// CCA will return an error if the inputs x and y do not have the same
// number of rows.
//
// The vector weights is used to weight the observations. If weights is nil,
// each weight is considered to have a value of one, otherwise the length of
// weights must match the number of observations (rows of both x and y) or
// CCA will return an error.
//
// On success, ccors holds the canonical correlations reported by
// stat.CC.CorrsTo. pVecs and qVecs are the left and right vectors returned by
// stat.CC.LeftTo/RightTo with spheredSpace set to true, and phiVs and psiVs
// are the same vectors requested with spheredSpace set to false (per the
// gonum documentation, back-transformed to the original data space). Every
// returned matrix is a freshly allocated column-major copy and does not alias
// x or y. On error, ccors is nil and every matrix result is the zero value.
func CCA(x, y blas64.GeneralCols, weights []float64) (ccors []float64, pVecs, qVecs, phiVs, psiVs blas64.GeneralCols, err error) {
	// stat.CC consumes row-major mat.Dense values, so copy the column-major
	// inputs into row-major storage first.
	var xdata, ydata mat.Dense
	xdata.SetRawMatrix(rowMajor(x))
	ydata.SetRawMatrix(rowMajor(y))

	var cc stat.CC
	err = cc.CanonicalCorrelations(&xdata, &ydata, weights)
	if err != nil {
		return nil, blas64.GeneralCols{}, blas64.GeneralCols{}, blas64.GeneralCols{}, blas64.GeneralCols{}, err
	}

	ccors = cc.CorrsTo(nil)

	// The boolean argument selects the sphered space (true) or the original
	// data space (false).
	var pDense, qDense, phiDense, psiDense mat.Dense
	cc.LeftTo(&pDense, true)
	cc.RightTo(&qDense, true)
	cc.LeftTo(&phiDense, false)
	cc.RightTo(&psiDense, false)

	// Convert back to the column-major layout used by this package's API.
	return ccors,
		colMajor(pDense.RawMatrix()),
		colMajor(qDense.RawMatrix()),
		colMajor(phiDense.RawMatrix()),
		colMajor(psiDense.RawMatrix()),
		nil
}
// rowMajor returns a row-major copy of the column-major matrix a.
//
// The copy uses a tight stride (Stride == Cols) backed by exactly Rows*Cols
// elements. Any padding implied by a.Stride is therefore not carried over,
// and the result never aliases a.Data.
func rowMajor(a blas64.GeneralCols) blas64.General {
	t := blas64.General{
		Rows: a.Rows,
		Cols: a.Cols,
		// A tight row-major layout needs Rows*Cols elements. len(a.Data) can
		// be larger when a.Stride > a.Rows, which would only waste memory.
		Data:   make([]float64, a.Rows*a.Cols),
		Stride: a.Cols,
	}
	t.From(a)
	return t
}
// colMajor returns a column-major copy of the row-major matrix a.
//
// The copy uses a tight stride (Stride == Rows) backed by exactly Rows*Cols
// elements. Any padding implied by a.Stride is therefore not carried over,
// and the result never aliases a.Data.
func colMajor(a blas64.General) blas64.GeneralCols {
	t := blas64.GeneralCols{
		Rows: a.Rows,
		Cols: a.Cols,
		// A tight column-major layout needs Rows*Cols elements. len(a.Data)
		// can be larger when a.Stride > a.Cols, which would only waste memory.
		Data:   make([]float64, a.Rows*a.Cols),
		Stride: a.Rows,
	}
	t.From(a)
	return t
}