Merge pull request #15 from kelizhang/master
2021.5.31: Adjustments to the project structure and code optimizations.
kelizhang authored May 31, 2021
2 parents c13e167 + 9ca28a2 commit 744f693
Showing 81 changed files with 2,081 additions and 1,198 deletions.
2 changes: 1 addition & 1 deletion gcastle/README.cn.md
@@ -92,4 +92,4 @@ print(mt.metrics)
* Algorithm library expansion: new causal structure learning algorithms such as `GES`, `HPCI`, and `TTPM`, along with simple configurable scripts to help invoke and run them.
* Real-world datasets: a batch of time-series and event-sequence datasets from real AIOps scenarios will be released over time, with ground-truth causal graph annotations drawn from domain-expert experience.

We welcome everyone to use `gCastle`. The project is still at an early stage, and contributors of all experience levels are welcome. For any questions or suggestions, including bug fixes, algorithm contributions, and documentation improvements, please open an issue in the community and we will respond promptly.
We welcome everyone to use `gCastle`. The project is still at an early stage, and contributors of all experience levels are welcome. For any questions or suggestions, including bug fixes, algorithm contributions, and documentation improvements, please open an issue in the community and we will respond promptly.
2 changes: 1 addition & 1 deletion gcastle/castle/algorithms/__init__.py
@@ -22,7 +22,7 @@
from .gradient import NotearsSob
from .gradient import NotearsLowRank
from .gradient import GOLEM
from .gradient import GraN_DAG, Parameters
from .gradient import GraN_DAG
from .gradient import GAE
from .gradient import MCSL
from .gradient import RL
2 changes: 1 addition & 1 deletion gcastle/castle/algorithms/gradient/__init__.py
@@ -19,7 +19,7 @@
from .notears.low_rank import NotearsLowRank
from .notears.golem import GOLEM

from .gran_dag.gran_dag import GraN_DAG, Parameters
from .gran_dag.gran_dag import GraN_DAG
from .graph_auto_encoder.gae import GAE
from .masked_csl.mcsl import MCSL

2 changes: 1 addition & 1 deletion gcastle/castle/algorithms/gradient/corl1/corl1.py
@@ -51,7 +51,7 @@ class CORL1(BaseLearner):
>>> from castle.metrics import MetricsDAG
>>> true_dag, X = load_dataset(name='iid_test')
>>> n = CORL1()
>>> n.learn(X, dag=true_dag)
>>> n.learn(X)
>>> GraphDAG(n.causal_matrix, true_dag)
>>> met = MetricsDAG(n.causal_matrix, true_dag)
>>> print(met.metrics)
2 changes: 1 addition & 1 deletion gcastle/castle/algorithms/gradient/corl2/corl2.py
@@ -50,7 +50,7 @@ class CORL2(BaseLearner):
>>> from castle.metrics import MetricsDAG
>>> true_dag, X = load_dataset(name='iid_test')
>>> n = CORL2()
>>> n.learn(X, dag=true_dag)
>>> n.learn(X)
>>> GraphDAG(n.causal_matrix, true_dag)
>>> met = MetricsDAG(n.causal_matrix, true_dag)
>>> print(met.metrics)
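The CORL1 and CORL2 docstring updates above make the same API simplification: `learn()` no longer takes a `dag` argument, since the ground-truth graph is only needed afterwards for plotting and metrics. A minimal runnable sketch of the updated CORL2 workflow (CORL1 is identical); the `castle.common`/`castle.datasets` import paths and the `CORL2` export from `castle.algorithms` are assumptions based on the library's other doctests:

```python
# Sketch of the post-change usage, following the docstrings above.
from castle.algorithms import CORL2     # assumed export path
from castle.common import GraphDAG      # assumed import path
from castle.datasets import load_dataset
from castle.metrics import MetricsDAG

true_dag, X = load_dataset(name='iid_test')

n = CORL2()
n.learn(X)  # the dag= keyword has been dropped from learn()

# the true DAG is used only here, for visual and metric comparison
GraphDAG(n.causal_matrix, true_dag)
met = MetricsDAG(n.causal_matrix, true_dag)
print(met.metrics)
```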
2 changes: 1 addition & 1 deletion gcastle/castle/algorithms/gradient/gran_dag/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .gran_dag import GraN_DAG, Parameters
from .gran_dag import GraN_DAG
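All three hunks above make the same change: the `Parameters` container is no longer exported alongside `GraN_DAG`, so configuration presumably goes straight to the estimator. A hedged sketch of what usage might look like after this change; the `input_dim` keyword is an illustrative assumption, not a verified signature:

```python
# Hypothetical post-change usage: only the class is imported, and
# hyperparameters are passed to the constructor rather than through
# a separate Parameters object. Check the class docstring for the
# actual constructor arguments.
from castle.algorithms import GraN_DAG
from castle.datasets import load_dataset

true_dag, X = load_dataset(name='iid_test')

model = GraN_DAG(input_dim=X.shape[1])  # assumed kwarg
model.learn(X)
print(model.causal_matrix)
```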
2 changes: 0 additions & 2 deletions gcastle/castle/algorithms/gradient/gran_dag/base/__init__.py
@@ -1,5 +1,3 @@
from .base_model import BaseModel
from .base_model import LearnableModel
from .base_model import NonlinearGauss
from .base_model import NonlinearGaussANM

29 changes: 0 additions & 29 deletions gcastle/castle/algorithms/gradient/gran_dag/base/base_model.py
@@ -13,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle

import numpy as np
import torch
import torch.nn as nn
@@ -86,7 +83,6 @@ def forward_given_params(self, x, weights, biases):
the parameters of each variable conditional
"""

num_zero_weights = 0
for k in range(self.hidden_num + 1):
# apply affine operator
if k == 0:
@@ -95,18 +91,13 @@ def forward_given_params(self, x, weights, biases):
else:
x = torch.einsum("tij,btj->bti", weights[k], x) + biases[k]

# count num of zeros
num_zero_weights += weights[k].numel() - torch.nonzero(weights[k]).size()[0]

# apply non-linearity
if k != self.hidden_num:
if self.nonlinear == "leaky-relu":
x = F.leaky_relu(x)
else:
x = torch.sigmoid(x)

self.zero_weights_ratio = num_zero_weights / float(self.numel_weights)

return torch.unbind(x, 1)

def get_w_adj(self):
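A note on the `torch.einsum("tij,btj->bti", weights[k], x)` contraction in the hunk above: it applies one affine map per conditional network in a single batched call, with `t` networks, batch index `b`, and input/output feature indices `j`/`i`. A tiny self-contained sketch of that contraction (shapes are illustrative assumptions):

```python
import torch

# t parallel affine maps, one per conditional network:
# weights: (t, out_dim, in_dim), x: (batch, t, in_dim)
t, batch, in_dim, out_dim = 3, 4, 5, 2
weights = torch.randn(t, out_dim, in_dim)
x = torch.randn(batch, t, in_dim)

# out[b, t, i] = sum_j weights[t, i, j] * x[b, t, j]
out = torch.einsum("tij,btj->bti", weights, x)
assert out.shape == (batch, t, out_dim)

# equivalent per-network loop, for clarity
ref = torch.stack([x[:, k] @ weights[k].T for k in range(t)], dim=1)
assert torch.allclose(out, ref, atol=1e-6)
```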
@@ -218,26 +209,6 @@ def get_grad_norm(self, mode="wbx"):

return torch.sqrt(grad_norm)

def save_parameters(self, exp_path, mode="wbx"):
"""
Parameters
----------
exp_path : str
path for saving model parameters
mode : str
w=weights, b=biases, x=extra_params (order is irrelevant)
"""
params = self.get_parameters(mode=mode)
# save
with open(os.path.join(exp_path, "params_"+mode), 'wb') as f:
pickle.dump(params, f)

def load_parameters(self, exp_path, mode="wbx"):
with open(os.path.join(exp_path, "params_"+mode), 'rb') as f:
params = pickle.load(f)
self.set_parameters(params, mode=mode)

def get_distribution(self, density_params):
raise NotImplementedError

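The deleted `save_parameters`/`load_parameters` methods pickled the tensor list returned by `get_parameters` to a `params_<mode>` file. If equivalent checkpointing is still needed outside the class after this change, a minimal stand-alone sketch using `torch.save`/`torch.load`; these helper functions are hypothetical and not part of castle, though `get_parameters`/`set_parameters` are the same model methods the removed code called:

```python
import os
import torch

def save_parameters(model, exp_path, mode="wbx"):
    """Hypothetical helper: persist the tensors returned by
    model.get_parameters() with torch.save instead of pickle.
    Mode flags as in the removed code: w=weights, b=biases,
    x=extra_params (order is irrelevant)."""
    params = model.get_parameters(mode=mode)
    torch.save(params, os.path.join(exp_path, "params_" + mode))

def load_parameters(model, exp_path, mode="wbx"):
    """Hypothetical counterpart: restore the saved tensors."""
    params = torch.load(os.path.join(exp_path, "params_" + mode))
    model.set_parameters(params, mode=mode)
```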
