diff --git a/pylift/generate_data.py b/pylift/generate_data.py index 6886ab6..19491a6 100644 --- a/pylift/generate_data.py +++ b/pylift/generate_data.py @@ -126,7 +126,8 @@ def sim_pte(N=1000, n_features=20, beta=None, rho=0, sigma=np.sqrt(2), beta_den= ### Generate x ~ N~p(0, rho) mean = np.zeros(p) - cov = np.identity(n=p) + cov = np.full((p,p), rho) + np.fill_diagonal(cov, 1) x = np.random.multivariate_normal(mean, cov, N) ### Random error from N~1(0,1)