diff --git a/ppca/_ppca.py b/ppca/_ppca.py index 45e5f2b..8a0d127 100644 --- a/ppca/_ppca.py +++ b/ppca/_ppca.py @@ -22,7 +22,10 @@ def _standardize(self, X): return (X - self.means) / self.stds - def fit(self, data, d=None, tol=1e-4, min_obs=10, verbose=False): + def fit(self, data, d=None, tol=1e-4, min_obs=10, verbose=False, seed=None): + + if seed is not None: + np.random.seed(seed) self.raw = data self.raw[np.isinf(self.raw)] = np.max(self.raw[np.isfinite(self.raw)]) @@ -45,7 +48,7 @@ def fit(self, data, d=None, tol=1e-4, min_obs=10, verbose=False): if d is None: d = data.shape[1] - + if self.C is None: C = np.random.randn(D, d) else: @@ -131,7 +134,7 @@ def _calc_var(self): def save(self, fpath): np.save(fpath, self.C) - + def load(self, fpath): assert os.path.isfile(fpath)