-
Notifications
You must be signed in to change notification settings - Fork 1
/
NeurocircuitX_mix.py
48 lines (41 loc) · 2.37 KB
/
NeurocircuitX_mix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# !/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Author : Ziyuan Ye
@Email : [email protected]
'''
import numpy as np
import os
def stdlize(data: np.ndarray) -> np.ndarray:
    """Min-max scale *data* to [0, 1] and return it as a column vector.

    The global minimum and maximum of the whole input are used, so a
    multi-dimensional input is scaled as one pool of values and then
    flattened (row-major) into shape ``(data.size, 1)``.

    Args:
        data: Array-like of numeric values, any shape.

    Returns:
        Float array of shape ``(data.size, 1)``. Empty input yields an empty
        ``(0, 1)`` array. If all values are identical (zero range) the result
        is all zeros instead of the NaNs a 0/0 division would produce.
    """
    arr = np.asarray(data)
    if arr.size == 0:
        return np.empty((0, 1))
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        # Constant input: 0/0 would give NaN (plus a RuntimeWarning) and
        # poison the CSV written downstream; map every value to 0 instead.
        return np.zeros((arr.size, 1))
    # Vectorised scaling: min/max are computed once, not once per element.
    return ((arr - lo) / span).reshape(-1, 1)
def mix_acc(root=None, categories=None, n_regions=379, weight=0.5):
    """Blend ablation and keep saliency weights per category and save them.

    For every category ``cate``:

    1. Load ``<root>/ablation/<cate>.csv`` and ``<root>/keep/<cate>.csv``
       with ``np.loadtxt``.
    2. Combine their first ``n_regions`` entries as
       ``weight * ablation + (1 - weight) * keep``.
    3. Min-max scale the result with ``stdlize``.
    4. Write it, one value per line, to ``<root>/mix/<cate>.csv``.

    Prints ``finish`` when all categories are done.

    Paths are built with ``os.path.join`` so they resolve on POSIX as well
    as Windows.

    Args:
        root: Directory containing the ``ablation`` and ``keep`` sub-folders.
            Defaults to ``./result_cv/MMP/mlp_mixer/saliency_result/yeo17``.
            Other models were stored under
            ``./result_cv/<model>/saliency_result/yeo17``, with ``<model>``
            one of gcn, gat, stgcn or stpgcn.
        categories: Category names to process. Defaults to the 23 names
            listed in the body.
        n_regions: Number of leading entries to blend. NOTE(review): 379 is
            presumably the region count of the MMP atlas in the default
            path -- confirm.
        weight: Weight given to the ablation map. The keep map gets
            ``1 - weight``; the default 0.5 is a plain average.

    Raises:
        ValueError: If a loaded map has fewer than ``n_regions`` entries.
    """
    if root is None:
        root = os.path.join('.', 'result_cv', 'MMP', 'mlp_mixer',
                            'saliency_result', 'yeo17')
    if categories is None:
        categories = ['0bk_body', '0bk_faces', '0bk_places', '0bk_tools',
                      '2bk_body', '2bk_faces', '2bk_places', '2bk_tools',
                      'loss', 'win', 't', 'lf', 'lh', 'rh', 'rf', 'math',
                      'story', 'mental', 'rnd', 'match', 'relation',
                      'fear', 'neut']

    out_dir = os.path.join(root, 'mix')
    # np.savetxt does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    for cate in categories:
        file_name = '{}.csv'.format(cate)
        # atleast_1d: a single-value file loads as a 0-d array, which has no len().
        ab_weight = np.atleast_1d(np.loadtxt(os.path.join(root, 'ablation', file_name)))
        ke_weight = np.atleast_1d(np.loadtxt(os.path.join(root, 'keep', file_name)))
        for kind, arr in (('ablation', ab_weight), ('keep', ke_weight)):
            if len(arr) < n_regions:
                raise ValueError('{} map for {!r} has {} entries; expected at least {}'
                                 .format(kind, cate, len(arr), n_regions))
        # Only the first n_regions entries are blended, as before.
        mixed = ab_weight[:n_regions] * weight + ke_weight[:n_regions] * (1 - weight)
        np.savetxt(os.path.join(out_dir, file_name), stdlize(mixed), delimiter=",")
    print('finish')
# Run the blending only when executed as a script, not on import.
if __name__ == '__main__':
    mix_acc()