First of all, thanks so much for making this! I've been wanting to use the ZINB-Wave method on my dataset, but it was too big to run.
I got your implementation to run without including batch variables. When I add in the batch variables, it runs for part of the data, but eventually an error appears.
```python
groups = np.unique(adata.obs['GEO']).shape[0]
one_hot_encoding = one_hot(torch.tensor(list(adata.obs['GEO'].cat.codes)).long(),
                           num_classes=groups).reshape((-1, groups))
X = torch.ones((batch_size, 1 + groups))

PATH = './zinb_grad_constants_geo/'
model = ZINB_grad.ZINB_Grad(Y=torch.randint(0, 100, size=size), X=X, K=K, device=device).to(device)

for i, data in enumerate(data_loader):
    batch = data["X"].to(device, dtype=torch.int32)
    X[:, 1:] = one_hot_encoding[(batch_size * i):(batch_size * (i + 1))]

    # Using the alphas, betas, and theta from the dummy model.
    model = ZINB_grad.ZINB_Grad(Y=batch, K=K, X=X, device=device,
                                alpha_mu=model.alpha_mu,
                                alpha_pi=model.alpha_pi,
                                beta_mu=model.beta_mu,
                                beta_pi=model.beta_pi,
                                log_theta=model.log_theta).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=0.001)
    losses, neg_log_liks = ZINB_grad.train_ZINB(batch, optimizer, model, epochs=n_epochs)
```
```
ValueError: Expected parameter probs (Tensor of shape (9556, 3000)) of distribution ZeroInflatedNegativeBinomial(gate_logits: torch.Size([9556, 3000])) to satisfy the constraint HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), but found invalid values:
```
I think the issue is with the matrix supplied to `probs` in `_loss()`, which is computed in `forward()`. The `p` matrix sometimes ends up with a NaN in one of its elements; everything else stays within the bounds the model expects.

`self.mu[torch.isnan(p)]` yields `Inf`, so the overflow happens there. `self.log_mu[torch.isnan(p)]` is about 88.8, and since float32's maximum is about 3.4e38 (whose natural log is roughly 88.7), exponentiating it overflows PyTorch's default float32 range.
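To confirm the overflow mechanism, here is a minimal sketch. The `clamp` workaround at the end is my own suggestion, not something from `ZINB_grad` itself; it just shows that capping `log_mu` below `log(float32 max)` keeps `mu` finite.

```python
import torch

# float32 overflows once exp(x) exceeds ~3.4e38, i.e. for x above ~88.72.
log_mu = torch.tensor([80.0, 88.8], dtype=torch.float32)
mu = torch.exp(log_mu)
print(mu)  # the second entry is inf

# Possible workaround (an assumption, not part of ZINB_grad):
# clamp log_mu just below log(float32 max) before exponentiating.
max_log = torch.log(torch.tensor(torch.finfo(torch.float32).max)) - 1.0
mu_safe = torch.exp(torch.clamp(log_mu, max=max_log))
print(torch.isfinite(mu_safe).all())  # all entries finite
```

If the NaNs in `p` trace back to this `exp(log_mu)` step, a clamp like this inside `forward()` (or a smaller learning rate so `log_mu` never drifts that high) might avoid the invalid `probs` values.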