
Nan in p matrix when running with batch covariates #1

@jg9zk

Description

First of all, thanks so much for making this! I've been wanting to use the ZINB-WaVE method on my dataset, but it was too big to run.

I got your implementation to run without including batch variables. When I add the batch variables, it runs for some of the data, but eventually raises an error.

```python
import numpy as np
import torch
from torch.nn.functional import one_hot

groups = np.unique(adata.obs['GEO']).shape[0]
one_hot_encoding = one_hot(torch.tensor(list(adata.obs['GEO'].cat.codes)).long(),
                           num_classes=groups).reshape((-1, groups))
X = torch.ones((batch_size, 1 + groups))

PATH = './zinb_grad_constants_geo/'
model = ZINB_grad.ZINB_Grad(Y=torch.randint(0, 100, size=size), X=X, K=K, device=device).to(device)

for i, data in enumerate(data_loader):
    batch = data["X"].to(device, dtype=torch.int32)
    X[:, 1:] = one_hot_encoding[(batch_size * i):(batch_size * (i + 1))]

    # Using the alphas, betas, and theta from the dummy model.
    model = ZINB_grad.ZINB_Grad(Y=batch, K=K, X=X, device=device,
                                alpha_mu=model.alpha_mu,
                                alpha_pi=model.alpha_pi,
                                beta_mu=model.beta_mu,
                                beta_pi=model.beta_pi,
                                log_theta=model.log_theta).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=0.001)

    losses, neg_log_liks = ZINB_grad.train_ZINB(batch, optimizer, model, epochs=n_epochs)
```

```
ValueError: Expected parameter probs (Tensor of shape (9556, 3000)) of distribution ZeroInflatedNegativeBinomial(gate_logits: torch.Size([9556, 3000])) to satisfy the constraint HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), but found invalid values:
```

I think the issue is with the matrix supplied to `probs` in `_loss()`, which is calculated in `forward()`. The p matrix sometimes ends up with a NaN in one of its elements; everything else stays within the bounds the model expects.

`self.mu[torch.isnan(p)]` yields `inf`, so something is overflowing there. `self.log_mu[torch.isnan(p)]` is about 88.8, and `exp(88.8)` exceeds the float32 maximum (~3.4e38, so any `log_mu` above ~88.72 overflows).
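That overflow is enough to produce the NaN. Here's a minimal sketch of the suspected failure mode, assuming the usual negative-binomial parameterization `p = mu / (mu + theta)` (the exact formula in ZINB-Grad's `forward()` may differ):

```python
import numpy as np

# exp() overflows float32 once log_mu exceeds ~88.72 (log of the float32 max, ~3.4e38).
log_mu = np.float32(88.8)
mu = np.exp(log_mu)        # overflows to inf
theta = np.float32(1.0)    # arbitrary dispersion, just for illustration
p = mu / (mu + theta)      # inf / inf -> nan, violating the (0, 1] constraint on probs
print(mu, p)
```

Any element of `log_mu` that drifts past ~88.72 during training would poison `p` in exactly this way.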

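One possible workaround, purely illustrative and not part of the ZINB-Grad API, is to clamp log-scale parameters below `log(float32 max) ≈ 88.72` before exponentiating (or to lower the learning rate so `log_mu` never grows that large):

```python
import torch

# Hypothetical mitigation sketch: clamp log-mean values so exp() stays finite.
log_mu = torch.tensor([10.0, 88.8, 95.0])  # made-up values; 88.8 mirrors the report
clamped = torch.clamp(log_mu, max=80.0)    # 80 leaves headroom below ~88.72
mu = torch.exp(clamped)                    # finite for every element
p = mu / (mu + 1.0)                        # stays inside (0, 1], no NaNs
print(torch.isnan(p).any())
```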