diff --git a/examples/mixed_hmm/__init__.py b/examples/mixed_hmm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/mixed_hmm/experiment.py b/examples/mixed_hmm/experiment.py new file mode 100644 index 0000000000..5e74050b7e --- /dev/null +++ b/examples/mixed_hmm/experiment.py @@ -0,0 +1,191 @@ +from __future__ import absolute_import, division, print_function + +import argparse +import logging +import os +import json +import uuid + +import numpy as np +import torch + +import pyro +from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO +from pyro.optim import Adam +from pyro.util import ignore_jit_warnings + +from model import model_generic, guide_generic +from seal_data import prepare_seal +from shark_data import prepare_shark + + +def aic_num_parameters(config): + + def _size(tensor): + """product of shape""" + s = 1 + for d in tensor.shape: + s = s * d + return s + + num_params = 0 + + for level in ["group", "individual", "timestep"]: + # count random effect parameters + if config[level]["random"] == "discrete": + num_params += _size(pyro.param("probs_e_{}".format(level))) + num_params += _size(pyro.param("theta_{}".format(level))) + elif config[level]["random"] == "continuous": + num_params += _size(pyro.param("loc_{}".format(level))) + num_params += _size(pyro.param("scale_{}".format(level))) + + # count fixed effect parameters + if config[level]["fixed"] is not None: + num_params += _size(pyro.param("beta_{}".format(level))) + + # count likelihood parameters + for coord, coord_config in config["observations"].items(): + num_params += sum([ + _size(pyro.param("{}_param_{}".format(coord, arg_name))) + for arg_name in coord_config["dist"].arg_constraints.keys() + ]) + # count zero-inflation parameters + if coord_config["zi"]: + num_params += _size(pyro.param("{}_zi_param".format(coord))) + + return num_params + + +def aic(model, guide, config): + neg_log_likelihood = TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss(model, guide) + num_params = aic_num_parameters(config) + return 2. * neg_log_likelihood + 2. 
* num_params + + +def run_expt(args): + + data_dir = args["folder"] + dataset = args["dataset"] + seed = args["seed"] + optim = args["optim"] + lr = args["learnrate"] + timesteps = args["timesteps"] + schedule = [] if not args["schedule"] else [int(i) for i in args["schedule"].split(",")] + random_effects = {"group": args["group"], "individual": args["individual"]} + + pyro.enable_validation(args["validation"]) + pyro.set_rng_seed(seed) # reproducible random effect parameter init + + if dataset == "seal": + filename = os.path.join(data_dir, "prep_seal_data.RData") + config = prepare_seal(filename, random_effects) + elif dataset == "shark": + filename = os.path.join(data_dir, "gws_full.xlsx") + config = prepare_shark(filename, random_effects) + + model = lambda: model_generic(config) # for JITing + guide = lambda: guide_generic(config) + + losses = [] + # SGD + if optim == "sgd": + loss_fn = TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss + with pyro.poutine.trace(param_only=True) as param_capture: + loss_fn(model, guide) + params = [site["value"].unconstrained() for site in param_capture.trace.nodes.values()] + optimizer = torch.optim.Adam(params, lr=lr) + + if schedule: + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=schedule, gamma=0.5) + schedule_step_loss = False + else: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') + schedule_step_loss = True + + for t in range(timesteps): + + optimizer.zero_grad() + loss = loss_fn(model, guide) + loss.backward() + optimizer.step() + scheduler.step(loss.item() if schedule_step_loss else t) + losses.append(loss.item()) + + print("Loss: {}, AIC[{}]: ".format(loss.item(), t), + 2. * loss + 2. * aic_num_parameters(config)) + + # LBFGS + elif optim == "lbfgs": + loss_fn = TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss + with pyro.poutine.trace(param_only=True) as param_capture: + loss_fn(model, guide) + params = [site["value"].unconstrained() for site in param_capture.trace.nodes.values()] + optimizer = torch.optim.LBFGS(params, lr=lr) + + if schedule: + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=schedule, gamma=0.5) + schedule_step_loss = False + else: + schedule_step_loss = True + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') + + for t in range(timesteps): + def closure(): + optimizer.zero_grad() + loss = loss_fn(model, guide) + loss.backward() + return loss + loss = optimizer.step(closure) + scheduler.step(loss.item() if schedule_step_loss else t) + losses.append(loss.item()) + print("Loss: {}, AIC[{}]: ".format(loss.item(), t), + 2. * loss + 2. 
* aic_num_parameters(config)) + + else: + raise ValueError("{} not supported optimizer".format(optim)) + + aic_final = aic(model, guide, config) + print("AIC final: {}".format(aic_final)) + + results = {} + results["args"] = args + results["sizes"] = config["sizes"] + results["likelihoods"] = losses + results["likelihood_final"] = losses[-1] + results["aic_final"] = aic_final.item() + results["aic_num_parameters"] = aic_num_parameters(config) + + if args["resultsdir"] is not None: + re_str = "g" + ("n" if (args["group"] is None or args["group"] == "none") else "d" if args["group"] == "discrete" else "c") + re_str += "i" + ("n" if (args["individual"] is None or args["individual"] == "none") else "d" if args["individual"] == "discrete" else "c") + results_filename = "expt_{}_{}_{}.json".format(args["dataset"], re_str, str(uuid.uuid4().hex)[0:5]) + with open(os.path.join(args["resultsdir"], results_filename), "w") as f: + json.dump(results, f) + + return results + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--dataset", default="seal", type=str) + parser.add_argument("-g", "--group", default="none", type=str) + parser.add_argument("-i", "--individual", default="none", type=str) + parser.add_argument("-f", "--folder", default="./", type=str) + parser.add_argument("-o", "--optim", default="sgd", type=str) + parser.add_argument("-lr", "--learnrate", default=0.05, type=float) + parser.add_argument("-t", "--timesteps", default=1000, type=int) + parser.add_argument("-r", "--resultsdir", default="./results", type=str) + parser.add_argument("-s", "--seed", default=101, type=int) + parser.add_argument("--schedule", default="", type=str) + parser.add_argument('--cuda', action='store_true') + parser.add_argument('--jit', action='store_true') + parser.add_argument('--validation', action='store_true') + args = parser.parse_args() + + if args.group == "none": + args.group = None + if args.individual == "none": + args.individual = None + + run_expt(vars(args)) diff --git a/examples/mixed_hmm/model.py b/examples/mixed_hmm/model.py new file mode 100644 index 0000000000..ac0ae64d3f --- /dev/null +++ b/examples/mixed_hmm/model.py @@ -0,0 +1,238 @@ +from __future__ import absolute_import, division, print_function + +import logging + +import numpy as np + +import torch +from torch.distributions import constraints + +import pyro +import pyro.distributions as dist +from pyro import poutine +from pyro.contrib.autoguide import AutoDelta +from pyro.infer import config_enumerate + + +def _index_param(param, ind, dim=-2): + """helper for advanced indexing black magic""" + # assume: dim < 0 + # assume: param.shape[dim:] == event_shape + # assume: index.shape == batch_shape + # assume: param.shape == batch_shape + event_shape + # goal: slice into an event_dim with index + # step 1: unsqueeze event dims in index + for d in range(len(param.shape[dim:])): + ind = ind.unsqueeze(-1) + # step 2: generate dummy indices for all other dimensions of param + inds = [None] * len(param.shape) + for d, sd in enumerate(reversed(param.shape)): + if dim == -d-1: + inds[-d-1] = ind + else: + inds[-d-1] = torch.arange(sd).reshape((sd,) + (1,) * d) + # step 3: use the index and dummy indices to select + res = param[tuple(inds)] + # XXX is this necessary? 
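+    # (answer: yes — after step 1 `ind` broadcasts with a size-1 slot at position `dim`,
+    # so the advanced-indexing result still carries a singleton dimension there)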
+ # step 4: squeeze out the empty event_dim + return res.squeeze(dim) + + +def guide_generic(config): + """generic mean-field guide for continuous random effects""" + N_v = config["sizes"]["random"] + N_state = config["sizes"]["state"] + + if config["group"]["random"] == "continuous": + loc_g = pyro.param("loc_group", lambda: torch.zeros((N_state ** 2,))) + scale_g = pyro.param("scale_group", lambda: torch.ones((N_state ** 2,)), + constraint=constraints.positive) + + # initialize individual-level random effect parameters + N_c = config["sizes"]["group"] + if config["individual"]["random"] == "continuous": + loc_i = pyro.param("loc_individual", lambda: torch.zeros((N_c, N_state ** 2,))) + scale_i = pyro.param("scale_individual", lambda: torch.ones((N_c, N_state ** 2,)), + constraint=constraints.positive) + + N_c = config["sizes"]["group"] + with pyro.plate("group", N_c) as c: + + if config["group"]["random"] == "continuous": + eps_g = pyro.sample("eps_g", dist.Normal(loc_g, scale_g).to_event(1), + ) # infer={"num_samples": 10}) + + N_s = config["sizes"]["individual"] + with pyro.plate("individual", N_s) as s, poutine.mask(mask=config["individual"]["mask"]): + + # individual-level random effects + if config["individual"]["random"] == "continuous": + eps_i = pyro.sample("eps_i", dist.Normal(loc_i, scale_i).to_event(1), + ) # infer={"num_samples": 10}) + + +@config_enumerate +def model_generic(config): + """generic hierarchical mixed-effects hidden markov model""" + + N_v = config["sizes"]["random"] + N_state = config["sizes"]["state"] + + # initialize fixed effect parameters - all the same size + if config["group"]["fixed"] is not None: + N_fg = config["group"]["fixed"].shape[-1] + beta_g = pyro.param("beta_group", lambda: torch.ones((N_fg, N_state ** 2))) + + if config["individual"]["fixed"] is not None: + N_fi = config["individual"]["fixed"].shape[-1] + beta_i = pyro.param("beta_individual", lambda: torch.ones((N_fi, N_state ** 2))) + + if config["timestep"]["fixed"] is not None: + N_ft = config["timestep"]["fixed"].shape[-1] + beta_t = pyro.param("beta_timestep", lambda: torch.ones((N_ft, N_state ** 2))) + + # initialize group-level random effect parameterss + if config["group"]["random"] == "discrete": + probs_e_g = pyro.param("probs_e_group", lambda: torch.randn((N_v,)).abs(), constraint=constraints.simplex) + theta_g = pyro.param("theta_group", lambda: torch.randn((N_v, N_state ** 2))) + elif config["group"]["random"] == "continuous": + loc_g = torch.zeros((N_state ** 2,)) + scale_g = torch.ones((N_state ** 2,)) + else: # none + pass + + # initialize individual-level random effect parameters + N_c = config["sizes"]["group"] + if config["individual"]["random"] == "discrete": + probs_e_i = pyro.param("probs_e_individual", + lambda: torch.randn((N_c, N_v,)).abs(), + constraint=constraints.simplex) + theta_i = pyro.param("theta_individual", + lambda: torch.randn((N_c, N_v, N_state ** 2))) + elif config["individual"]["random"] == "continuous": + loc_i = torch.zeros((N_c, N_state ** 2,)) + scale_i = torch.ones((N_c, N_state ** 2,)) + else: # none + pass + + # initialize likelihood parameters + for coord, coord_config in config["observations"].items(): + if coord_config["zi"]: + pyro.param("{}_zi_param".format(coord), lambda: torch.ones((N_state,2))) + for arg_name, arg_constraint in coord_config["dist"].arg_constraints.items(): + pyro.param("{}_param_{}".format(coord, arg_name), + lambda: torch.randn((N_state,)).abs(), + constraint=arg_constraint) + + # initialize gamma to uniform + gamma = 
torch.zeros((N_state ** 2,))
+
+    N_c = config["sizes"]["group"]
+    with pyro.plate("group", N_c, dim=-1) as c:
+
+        # group-level random effects
+        if config["group"]["random"] == "discrete":
+            # group-level discrete effect
+            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))
+            eps_g = _index_param(theta_g, e_g, dim=-2)
+        elif config["group"]["random"] == "continuous":
+            eps_g = pyro.sample("eps_g", dist.Normal(loc_g, scale_g).to_event(1))  # infer={"num_samples": 10}
+        else:
+            eps_g = 0.
+
+        # add group-level random effect to gamma
+        gamma = gamma + eps_g
+
+        # group-level fixed effects
+        if config["group"]["fixed"] is not None:
+            covariates_g = config["group"]["fixed"]
+            beta_g = pyro.param("beta_group")
+            fixed_g = torch.einsum("...f,fs->...s",
+                                   [covariates_g, beta_g])
+            gamma = gamma + fixed_g
+
+        N_s = config["sizes"]["individual"]
+        with pyro.plate("individual", N_s, dim=-2) as s, poutine.mask(mask=config["individual"]["mask"]):
+
+            # individual-level random effects
+            if config["individual"]["random"] == "discrete":
+                # individual-level discrete effect
+                e_i = pyro.sample("e_i", dist.Categorical(probs_e_i))
+                eps_i = _index_param(theta_i, e_i, dim=-2)
+                # assert eps_i.shape[-3:] == (1, N_c, N_state ** 2) and eps_i.shape[0] == N_v
+            elif config["individual"]["random"] == "continuous":
+                eps_i = pyro.sample("eps_i", dist.Normal(loc_i, scale_i).to_event(1))  # infer={"num_samples": 10}
+            else:
+                eps_i = 0.
+
+            # add individual-level random effect to gamma
+            gamma = gamma + eps_i
+
+            # individual-level fixed effects
+            if config["individual"]["fixed"] is not None:
+                covariates_i = config["individual"]["fixed"]
+                beta_i = pyro.param("beta_individual")
+                fixed_i = torch.einsum("...f,fs->...s",
+                                       [covariates_i, beta_i])
+                gamma = gamma + fixed_i
+
+            # TODO initialize y from stationary distribution?
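+            # `y` holds the current hidden state through the pyro.markov loop below;
+            # starting every sequence in state 0 makes the first row of gamma act as
+            # the initial-state distribution (cf. the TODO above)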
+ y = torch.tensor(0).long() + + N_t = config["sizes"]["timesteps"] + for t in pyro.markov(range(N_t)): + with poutine.mask(mask=config["timestep"]["mask"][..., t]): + # per-timestep fixed effects + gamma_t = gamma # per-timestep variable + if config["timestep"]["fixed"] is not None: + covariates_t = config["timestep"]["fixed"][..., t, :] + beta_t = pyro.param("beta_timestep") + fixed_t = torch.einsum("...f,fs->...s", + [covariates_t, beta_t]) + gamma_t = gamma_t + fixed_t + + # finally, reshape gamma as batch of transition matrices + gamma_t = gamma_t.reshape(tuple(gamma_t.shape[:-1]) + (N_state, N_state)) + + # we've accounted for all effects, now actually compute gamma_y + # gamma_y = _index_gamma(gamma_t, y, t) + gamma_y = _index_param(gamma_t, y, dim=-2) + y = pyro.sample("y_{}".format(t), dist.Categorical(logits=gamma_y)) + + # multivariate observations with different distributions + for coord, coord_config in config["observations"].items(): + coord_params = [ + _index_param(pyro.param("{}_param_{}".format(coord, arg_name)), y, dim=-1) + for arg_name in coord_config["dist"].arg_constraints.keys() + ] + coord_dist = coord_config["dist"](*coord_params) + + if not coord_config["zi"]: + pyro.sample("{}_{}".format(coord, t), + coord_dist, + obs=coord_config["values"][..., t]) + elif coord_config["zi"]: + # zero-inflation with MaskedMixture + coord_zi = _index_param(pyro.param("{}_zi_param".format(coord)), y, dim=-2) + # coord_zi_mask = coord_config["values"][..., t] == 1e-4 + # coord_zi_scale = dist.Categorical(logits=coord_zi).log_prob(coord_zi_mask).exp() + coord_zi_mask = pyro.sample("{}_zi_{}".format(coord, t), + dist.Categorical(logits=coord_zi), + obs=(coord_config["values"][..., t] == 1e-4)) + coord_zi_zero_dist = dist.Delta(v=torch.tensor(1e-4)) + coord_zi_dist = dist.MaskedMixture(coord_zi_mask, coord_dist, coord_zi_zero_dist) + + # do a bit of gross nan error checking... 
+ # if t > 5 and t < 10: + # nan_check_mask = config["timestep"]["mask"][..., t] & config["individual"]["mask"] + # assert not torch.isnan(coord_zi_dist.log_prob(coord_config["values"][..., t]).sum(dim=0).squeeze()[nan_check_mask]).any(), \ + # "nan zi at {}_{}".format(coord, t) + + # assert not (coord_zi_dist.log_prob(coord_config["values"][..., t]).sum(dim=0).squeeze()[nan_check_mask] == 0.).all(), \ + # "zero zi at {}_{}".format(coord, t) + + pyro.sample("{}_{}".format(coord, t), + coord_zi_dist, + obs=coord_config["values"][..., t]) diff --git a/examples/mixed_hmm/seal_data.py b/examples/mixed_hmm/seal_data.py new file mode 100644 index 0000000000..1b59af8a2f --- /dev/null +++ b/examples/mixed_hmm/seal_data.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import, division, print_function + +import logging +import os + +import numpy as np +import pandas as pd +import pyreadr + +import torch + +import pyro +import pyro.distributions as dist + + +def prepare_seal(filename, random_effects): + seal_df = pyreadr.read_r(filename)['hsData'] + obs_keys = ["step", "angle", "omega"] + # data format for z1, z2: + # single tensor with shape (individual, group, time, coords) + observations = torch.zeros((20, 2, 1800, len(obs_keys))).fill_(float("-inf")) + for g, (group, group_df) in enumerate(seal_df.groupby("sex")): + for i, (ind, ind_df) in enumerate(group_df.groupby("ID")): + for o, obs_key in enumerate(obs_keys): + observations[i, g, 0:len(ind_df), o] = torch.tensor(ind_df[obs_key].values) + + observations[torch.isnan(observations)] = float("-inf") + + # make masks + # mask_i should mask out individuals, it applies at all timesteps + mask_i = (observations > float("-inf")).any(dim=-1).any(dim=-1) # time nonempty + + # mask_t handles padding for time series of different length + mask_t = (observations > float("-inf")).all(dim=-1) # include non-inf + + # temporary hack to avoid zero-inflation issues + # observations[observations == 0.] = 1e-4 + observations[(observations == 0.) 
| (observations == float("-inf"))] = 1e-4 + assert not torch.isnan(observations).any() + + # observations = observations[..., 5:11, :] # truncate for testing + + config = { + "sizes": { + "state": 3, + "random": 4, + "group": observations.shape[1], + "individual": observations.shape[0], + "timesteps": observations.shape[2], + }, + "group": {"random": random_effects["group"], "fixed": None}, + "individual": {"random": random_effects["individual"], "fixed": None, "mask": mask_i}, + "timestep": {"random": None, "fixed": None, "mask": mask_t}, + "observations": { + "step": {"dist": dist.Gamma, "zi": True, "values": observations[..., 0]}, + "angle": {"dist": dist.VonMises, "zi": False, "values": observations[..., 1]}, + "omega": {"dist": dist.Beta, "zi": True, "values": observations[..., 2]}, + }, + } + + return config + + + diff --git a/examples/mixed_hmm/shark_data.py b/examples/mixed_hmm/shark_data.py new file mode 100644 index 0000000000..7069d36ba4 --- /dev/null +++ b/examples/mixed_hmm/shark_data.py @@ -0,0 +1,152 @@ +from __future__ import absolute_import, division, print_function + +import logging + +import numpy as np +import pandas as pd +import pyreadr + +import utm + +import torch + +import pyro +import pyro.distributions as dist + + +def _encode_shark_df(tracks_df, summary_df): + """ + This function converts the excel-derived encoding of the original data + to something more uniform and machine-readable for preparing experiments + """ + shark_df = tracks_df.copy() + + # split up the columns + # group column + shark_df["sex"] = shark_df["Shark sex and track number"].str[2:3] + # individual column (track) + shark_df["ID"] = shark_df["Shark sex and track number"] + + # animal column + shark_df["id_name"] = shark_df["Shark sex and track number"].str[0:5] + shark_df["id_name"] = shark_df["id_name"].str.strip("T ") + + # covariates + # tail length column from summary + shark_df.loc[:, "TL"] = pd.Series(np.random.randn(len(shark_df["Latitude"])), index=shark_df.index) + for individual_name in shark_df["id_name"].unique(): + individual_length = summary_df[summary_df["Shark ID"] == individual_name]["TL (cm)"].values[0] + shark_df["TL"][shark_df["id_name"] == individual_name] = individual_length + + # make chum value into something usable + shark_df.loc[:, "chum"] = pd.Series(np.random.randn(len(shark_df["Latitude"])), + index=shark_df.index) + shark_df["chum"][shark_df["Cage Dive Boat"] == 'x'] = 1. + shark_df["chum"][shark_df["Cage Dive Boat"] != 'x'] = 0. + + # time covariates + # XXX gross pandas stuff + shark_df.loc[:, "time_num"] = pd.Series(np.array([t.hour + t.minute / 60. for t in list(shark_df["Time"].values)], dtype=np.float32), index=shark_df.index) + shark_df["sin_time"] = shark_df["time_num"].apply(lambda t: np.sin(t * np.pi * 2. / 288.)) + shark_df["cos_time"] = shark_df["time_num"].apply(lambda t: np.cos(t * np.pi * 2. / 288.)) + + # add back conversion of lat/lon to step/angle + # 1. convert to x/y projection with utm + # 2. compute length of each difference + # 3. 
compute angle between differences + shark_df["step"] = pd.Series(np.zeros((len(shark_df),), dtype=np.float32), + index=shark_df.index) + shark_df["angle"] = pd.Series(np.zeros((len(shark_df),), dtype=np.float32), + index=shark_df.index) + for trackname, track_df in shark_df.groupby("ID"): + track_lat, track_lon = track_df["Latitude"], track_df["Longitude"] + x, y = np.zeros((len(track_lat),)), np.zeros((len(track_lon),)) + for i, (lat, lon) in enumerate(zip(list(track_lat.values), list(track_lon.values))): + x[i], y[i], _, _, = utm.from_latlon(lat, lon) + + xy = np.stack([x, y], axis=-1) / 1000. # km + step = xy[1:] - xy[:-1] + step_length = np.sqrt(np.einsum("ab,ab->a", step, step)) + # dstep = step[1:] - step[:-1] + step_angle = np.arccos( + np.einsum("ab,ab->a", step[1:], step[:-1]) / (step_length[1:] * step_length[:-1])) + + step_length[np.isnan(step_length)] = 0. + step_angle[np.isnan(step_angle)] = 0. + + if len(track_df) > 2: # cover a weird edge case "WSF6 T4 (T5)" + shark_df["step"][shark_df.ID == trackname] = np.concatenate([np.zeros((1,), dtype=np.float32), step_length]) + shark_df["angle"][shark_df.ID == trackname] = np.concatenate([np.zeros((2,), dtype=np.float32), step_angle]) + + # sanity checks + assert (shark_df["step"][shark_df.ID == trackname].values != 0.).any() + assert (shark_df["angle"][shark_df.ID == trackname].values != 0.).any() + + return shark_df + + +def prepare_shark(filename, random_effects): + + tracks_df = pd.read_excel(filename, sheet_name=0) + summary_df = pd.read_excel(filename, sheet_name=1) + + shark_df = _encode_shark_df(tracks_df, summary_df) + obs_keys = ["step", "angle"] + + # data format for z1, z2: + # single tensor with shape (individual, group, time, coords) + observations = torch.zeros((100, 2, 270, len(obs_keys))).fill_(float("-inf")) + for g, (group, group_df) in enumerate(shark_df.groupby("sex")): + for i, (ind, ind_df) in enumerate(group_df.groupby("ID")): + for o, obs_key in enumerate(obs_keys): + observations[i, g, 0:len(ind_df), o] = torch.tensor(ind_df[obs_key].values) + + # make covariates: chum (timestep), time sin/cos (timestep), size (individual) + individual_cov = torch.zeros((100, 2, 1)).fill_(float("-inf")) + timestep_cov = torch.zeros((100, 2, 270, 4)).fill_(float("-inf")) + for g, (group, group_df) in enumerate(shark_df.groupby("sex")): + for i, (ind, ind_df) in enumerate(group_df.groupby("ID")): + individual_cov[i, g, 0:1] = torch.tensor(ind_df["TL"].values[0:1]) + timestep_cov[i, g, 0:len(ind_df), 0] = torch.tensor(ind_df["sin_time"].values) + timestep_cov[i, g, 0:len(ind_df), 1] = torch.tensor(ind_df["cos_time"].values) + # chum is an indicator so we expand as one-hot + timestep_cov[i, g, 0:len(ind_df), 2] = torch.tensor(ind_df["chum"].values) + timestep_cov[i, g, 0:len(ind_df), 3] = torch.tensor(1. - ind_df["chum"].values) + + observations[torch.isnan(observations)] = float("-inf") + + # make masks + # mask_i should mask out individuals, it applies at all timesteps + mask_i = (observations > float("-inf")).any(dim=-1).any(dim=-1) # time nonempty + + # mask_t handles padding for time series of different length + mask_t = (observations > float("-inf")).all(dim=-1) # include non-inf + + # temporary hack to avoid zero-inflation issues + # observations[observations == 0.] = 1e-4 + observations[(observations == 0.) | (observations == float("-inf"))] = 1e-4 + assert not torch.isnan(observations).any() + + timestep_cov[(timestep_cov == 0.) | (timestep_cov == float("-inf"))] = 1e-4 + individual_cov[(individual_cov == 0.) 
| (individual_cov == float("-inf"))] = 1e-4 + + # observations = observations[..., 5:11, :] # truncate for testing + + config = { + "sizes": { + "state": 2, + "random": 3, + "group": observations.shape[1], + "individual": observations.shape[0], + "timesteps": observations.shape[2], + }, + "group": {"random": random_effects["group"], "fixed": None}, + "individual": {"random": random_effects["individual"], "fixed": individual_cov, "mask": mask_i}, + "timestep": {"random": None, "fixed": timestep_cov, "mask": mask_t}, + "observations": { + "step": {"dist": dist.Gamma, "zi": True, "values": observations[..., 0]}, + "angle": {"dist": dist.VonMises, "zi": False, "values": observations[..., 1]}, + }, + } + + return config diff --git a/examples/mixed_hmm/sweep.py b/examples/mixed_hmm/sweep.py new file mode 100644 index 0000000000..d7af80c6e3 --- /dev/null +++ b/examples/mixed_hmm/sweep.py @@ -0,0 +1,32 @@ +from __future__ import absolute_import, division, print_function + +import argparse +import os +import itertools +import collections +import subprocess +import multiprocessing + +from experiment import run_expt + + +config_components = collections.OrderedDict( + dataset = ["seal",], # "seal",], + # dataset = ["shark", "seal"], + group = ["none", "discrete", "continuous"], + individual = ["none", "discrete", "continuous"], + folder = ["./"], + optim = ["sgd",], + learnrate = [0.1, 0.05, 0.01, 0.005], + timesteps = [300, 600,], + resultsdir = ["./results",], + seed = [101, 102, 103, 104, 105], + schedule = ["", "10,20,60", "50,100,400", "100,200,600"], + validation = [False,], +) + +configs = [{k: v for k, v in zip(config_components.keys(), c)} + for c in itertools.product(*list(config_components.values()))] + +with multiprocessing.Pool(24) as p: + p.map(run_expt, configs)
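For reference, a minimal sketch of driving a single experiment programmatically instead of through argparse or sweep.py; the data folder and the hyperparameter values below are illustrative assumptions, not settings prescribed by this patch:

    from experiment import run_expt

    args = {
        "folder": "./",           # assumed to contain prep_seal_data.RData
        "dataset": "seal",
        "seed": 101,
        "optim": "sgd",
        "learnrate": 0.05,
        "timesteps": 300,
        "schedule": "",           # or e.g. "50,100,400" for MultiStepLR milestones
        "group": "discrete",      # None, "discrete", or "continuous"
        "individual": "discrete",
        "validation": False,
        "resultsdir": None,       # skip writing a results JSON
    }

    results = run_expt(args)
    print(results["likelihood_final"], results["aic_final"])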