diff --git a/src/pycea/tl/ancestral_states.py b/src/pycea/tl/ancestral_states.py index 5939df2..1f8bb33 100755 --- a/src/pycea/tl/ancestral_states.py +++ b/src/pycea/tl/ancestral_states.py @@ -8,7 +8,7 @@ import pandas as pd import treedata as td -from pycea.utils import _check_tree_overlap, get_keyed_node_data, get_keyed_obs_data, get_root, get_trees +from pycea.utils import _check_tree_overlap, get_keyed_node_data, get_keyed_obs_data, get_leaves, get_root, get_trees def _most_common(arr: np.ndarray) -> Any: @@ -45,15 +45,18 @@ def _remove_node_attributes(tree: nx.DiGraph, key: str) -> None: def _reconstruct_fitch_hartigan( - tree: nx.DiGraph, key: str, missing: str | None = None, index: int | None = None + tree: nx.DiGraph, key: str, missing: str | None = None, index: int | None = None, fixed_nodes: set | None = None ) -> None: """Reconstructs ancestral states using the Fitch-Hartigan algorithm.""" + def _is_fixed(node, value): + return fixed_nodes is not None and node in fixed_nodes and value is not None and value != missing + # Recursive function to calculate the downpass def downpass(node): - # Base case: leaf - if tree.out_degree(node) == 0: - value = _get_node_value(tree, node, key, index) + value = _get_node_value(tree, node, key, index) + # Base case: leaf or fixed observed internal node + if tree.out_degree(node) == 0 or _is_fixed(node, value): if value == missing: tree.nodes[node]["value_set"] = missing else: @@ -78,6 +81,11 @@ def downpass(node): # Recursive function to calculate the uppass def uppass(node, parent_state=None): value = _get_node_value(tree, node, key, index) + if _is_fixed(node, value): + # Fixed node: keep its observed value, propagate it to children + for child in tree.successors(node): + uppass(child, value) + return if value is None: if parent_state and parent_state in tree.nodes[node]["value_set"]: value = parent_state @@ -107,6 +115,7 @@ def _reconstruct_sankoff( missing: str | None = None, default: str | None = None, index: int | None = None, + fixed_nodes: set | None = None, ) -> None: """Reconstructs ancestral states using the Sankoff algorithm.""" # Set up @@ -115,16 +124,19 @@ def _reconstruct_sankoff( cost_matrix = costs.to_numpy() value_to_index = {value: i for i, value in enumerate(alphabet)} + def _is_fixed(node, value): + return fixed_nodes is not None and node in fixed_nodes and value is not None and value != missing + # Recursive function to calculate the Sankoff scores def sankoff_scores(node): - # Base case: leaf - if tree.out_degree(node) == 0: - leaf_value = _get_node_value(tree, node, key, index) - if leaf_value == missing: + node_value = _get_node_value(tree, node, key, index) + # Base case: leaf or fixed observed internal node + if tree.out_degree(node) == 0 or _is_fixed(node, node_value): + if node_value == missing: return np.zeros(num_states) else: scores = np.full(num_states, float("inf")) - scores[value_to_index[leaf_value]] = 0 + scores[value_to_index[node_value]] = 0 return scores # Recursive case: internal node else: @@ -143,16 +155,25 @@ def sankoff_scores(node): # Recursive function to traceback the Sankoff scores def traceback(node, parent_value_index): for i, child in enumerate(tree.successors(node)): - child_value_index = tree.nodes[node]["_pointers"][parent_value_index, i] - _set_node_value(tree, child, key, alphabet[child_value_index], index) + child_value = _get_node_value(tree, child, key, index) + if _is_fixed(child, child_value): + # Fixed node: keep its value, traceback using its fixed index + child_value_index = value_to_index[child_value] + else: + child_value_index = tree.nodes[node]["_pointers"][parent_value_index, i] + _set_node_value(tree, child, key, alphabet[child_value_index], index) traceback(child, child_value_index) # Get scores root = [n for n, d in tree.in_degree() if d == 0][0] root_scores = sankoff_scores(root) # Reconstruct ancestral states - root_value_index = np.argmin(root_scores) - _set_node_value(tree, root, key, alphabet[root_value_index], index) + root_value = _get_node_value(tree, root, key, index) + if _is_fixed(root, root_value): + root_value_index = value_to_index[root_value] + else: + root_value_index = np.argmin(root_scores) + _set_node_value(tree, root, key, alphabet[root_value_index], index) traceback(root, root_value_index) # Clean up for node in tree.nodes: @@ -160,12 +181,14 @@ def traceback(node, parent_value_index): del tree.nodes[node]["_pointers"] -def _reconstruct_mean(tree: nx.DiGraph, key: str, index: int | None) -> None: +def _reconstruct_mean(tree: nx.DiGraph, key: str, index: int | None, fixed_nodes: set | None = None) -> None: """Reconstructs ancestral by averaging the values of the children.""" def subtree_mean(node): - if tree.out_degree(node) == 0: - return _get_node_value(tree, node, key, index), 1 + val = _get_node_value(tree, node, key, index) + is_fixed = fixed_nodes is not None and node in fixed_nodes and val is not None + if tree.out_degree(node) == 0 or is_fixed: + return val, 1 else: values, weights = [], [] for child in tree.successors(node): @@ -180,12 +203,16 @@ def subtree_mean(node): subtree_mean(root) -def _reconstruct_list(tree: nx.DiGraph, key: str, sum_func: Callable, index: int | None) -> None: +def _reconstruct_list( + tree: nx.DiGraph, key: str, sum_func: Callable, index: int | None, fixed_nodes: set | None = None +) -> None: """Reconstructs ancestral states by concatenating the values of the children.""" def subtree_list(node): - if tree.out_degree(node) == 0: - return [_get_node_value(tree, node, key, index)] + val = _get_node_value(tree, node, key, index) + is_fixed = fixed_nodes is not None and node in fixed_nodes and val is not None + if tree.out_degree(node) == 0 or is_fixed: + return [val] else: values = [] for child in tree.successors(node): @@ -205,20 +232,21 @@ def _ancestral_states( missing: str | None = None, default: str | None = None, index: int | None = None, + fixed_nodes: set | None = None, ) -> None: """Reconstructs ancestral states for a given attribute using a given method""" if method == "sankoff": if costs is None: raise ValueError("Costs matrix must be provided for Sankoff algorithm.") - _reconstruct_sankoff(tree, key, costs, missing, default, index) + _reconstruct_sankoff(tree, key, costs, missing, default, index, fixed_nodes) elif method == "fitch_hartigan": - _reconstruct_fitch_hartigan(tree, key, missing, index) + _reconstruct_fitch_hartigan(tree, key, missing, index, fixed_nodes) elif method == "mean": - _reconstruct_mean(tree, key, index) + _reconstruct_mean(tree, key, index, fixed_nodes) elif method == "mode": - _reconstruct_list(tree, key, _most_common, index) + _reconstruct_list(tree, key, _most_common, index, fixed_nodes) elif callable(method): - _reconstruct_list(tree, key, method, index) + _reconstruct_list(tree, key, method, index, fixed_nodes) else: raise ValueError(f"Method {method} not recognized.") @@ -261,10 +289,15 @@ def ancestral_states( """Reconstructs ancestral states for an attribute. This function reconstructs ancestral (internal node) states for categorical or - continuous attributes defined on tree leaves. Several reconstruction methods + continuous attributes defined on tree observations. Several reconstruction methods are supported, ranging from simple aggregation rules to the Sankoff and Fitch-Hartigan algorithms for discrete character data, or a custom aggregation function can be provided. + For ``tdata.alignment == "leaves"``, only leaf node values are used as input and all + internal node states are reconstructed. For ``tdata.alignment == "nodes"`` or ``"subset"``, + internal nodes present in ``tdata.obs`` with non-missing values are treated as fixed + constraints and are not overwritten by reconstruction. + Parameters ---------- tdata @@ -334,6 +367,8 @@ def ancestral_states( if dtypes.intersection({"O", "S"}): if method in ["mean"]: raise ValueError(f"Method {method} requires numeric data.") + # Determine fixed internal nodes for nodes/subset alignment + leaves_set = set(get_leaves(t)) # If array add to tree as list if is_array: length = data.shape[1] @@ -343,13 +378,23 @@ def ancestral_states( node_attrs[node] = [None] * length _remove_node_attributes(t, keys_added[0]) nx.set_node_attributes(t, node_attrs, keys_added[0]) + fixed_nodes = None + if tdata.alignment != "leaves": + not_all_nan = ~data.isna().all(axis=1) + fixed_nodes = set(data[not_all_nan].index) - leaves_set for index in range(length): - _ancestral_states(t, keys_added[0], method, costs, missing_state, default_state, index) + _ancestral_states(t, keys_added[0], method, costs, missing_state, default_state, index, fixed_nodes) # If column add to tree as scalar else: for key, key_added in zip(keys, keys_added, strict=False): _remove_node_attributes(t, key_added) nx.set_node_attributes(t, data[key].to_dict(), key_added) - _ancestral_states(t, key_added, method, costs, missing_state, default_state) + fixed_nodes = None + if tdata.alignment != "leaves": + valid = data[key].notna() + if missing_state is not None: + valid = valid & (data[key] != missing_state) + fixed_nodes = set(data[valid].index) - leaves_set + _ancestral_states(t, key_added, method, costs, missing_state, default_state, fixed_nodes=fixed_nodes) if copy: return get_keyed_node_data(tdata, keys_added, tree_keys, slot="obst") diff --git a/src/pycea/tl/fitness.py b/src/pycea/tl/fitness.py index fdaedc3..98f0f1c 100644 --- a/src/pycea/tl/fitness.py +++ b/src/pycea/tl/fitness.py @@ -514,7 +514,9 @@ def fitness( * tdata.obst[tree].nodes[key_added] : `float` - Inferred fitness values for each node. * tdata.obs[key_added] : `float` - - Inferred fitness values for each leaf. + - Inferred fitness values for each observed node. For ``tdata.alignment == "leaves"``, + only leaf nodes are written. For ``"nodes"`` or ``"subset"``, all observed nodes + (including internal nodes) are written. """ tree_keys = tree _check_tree_overlap(tdata, tree_keys) @@ -528,8 +530,12 @@ def fitness( _infer_fitness_lbi(t, depth_key=depth_key, key_added=key_added, sample_n=sample_n, **(method_kwargs or {})) else: raise ValueError(f"method {method!r} not recognized, use 'sbd' or 'lbi'") - leaf_fitness = get_keyed_leaf_data(tdata, key_added, tree_keys) - tdata.obs[key_added] = tdata.obs.index.map(leaf_fitness[key_added]) + if tdata.alignment == "leaves": + node_fitness = get_keyed_leaf_data(tdata, key_added, tree_keys) + else: + node_fitness = get_keyed_node_data(tdata, key_added, tree_keys, slot="obst") + node_fitness.index = node_fitness.index.droplevel(0) + tdata.obs[key_added] = tdata.obs.index.map(node_fitness[key_added]) if copy: df = get_keyed_node_data(tdata, key_added, tree_keys) if len(trees) == 1: diff --git a/src/pycea/tl/tree_distance.py b/src/pycea/tl/tree_distance.py index dd2c3f8..d5fecfe 100755 --- a/src/pycea/tl/tree_distance.py +++ b/src/pycea/tl/tree_distance.py @@ -145,9 +145,11 @@ def tree_distance( ) -> None | sp.sparse.csr_matrix | np.ndarray: r"""Computes tree distances between observations. - This function calculates distances between observations (typically tree leaves) - based on their positions and depths in the tree. It supports *lowest common ancestor (lca)* - and *path* distances. + This function calculates distances between observations based on their positions + and depths in the tree. For ``tdata.alignment == "leaves"``, this computes distances + between leaf nodes. For ``tdata.alignment == "nodes"`` or ``"subset"``, distances are + computed between all observed nodes (leaves and internal nodes in ``tdata.obs``). + It supports *lowest common ancestor (lca)* and *path* distances. Given two nodes :math:`i` and :math:`j` in a rooted tree, with depths :math:`d_i` and :math:`d_j`, and with their lowest common ancestor having @@ -175,8 +177,8 @@ def tree_distance( obs The observations to use: - - If `None`, pairwise distance for tree leaves is stored in `tdata.obsp`. - - If a string, distance to all other tree leaves is `tdata.obs`. + - If `None`, pairwise distance for all observed nodes is stored in `tdata.obsp`. + - If a string, distance to all other observed nodes is stored in `tdata.obs`. - If a sequence, pairwise distance is stored in `tdata.obsp`. - If a sequence of pairs, distance between pairs is stored in `tdata.obsp`. metric diff --git a/src/pycea/tl/tree_neighbors.py b/src/pycea/tl/tree_neighbors.py index b471a02..ab5aacf 100755 --- a/src/pycea/tl/tree_neighbors.py +++ b/src/pycea/tl/tree_neighbors.py @@ -20,13 +20,19 @@ ) -def _lca_neighbors(tree, start_node, n_neighbors, max_dist, depth_key): +def _lca_neighbors(tree, start_node, n_neighbors, max_dist, depth_key, observed_nodes=None): """Find neighbors using LCA distance via a walk-up approach. - Walks from start_node to root, collecting sibling subtree leaves at each level. - All leaves in a sibling subtree share the same LCA distance (depth_key of their + Walks from start_node to root, collecting sibling subtree nodes at each level. + All nodes in a sibling subtree share the same LCA distance (depth_key of their common ancestor with start_node). Processes closest relatives first. - Time complexity: O(n) per leaf. + Time complexity: O(n) per node. + + Parameters + ---------- + observed_nodes + Set of observed node names to collect as neighbors. If None, only leaf nodes + (out_degree == 0) are collected (default leaves-alignment behavior). """ neighbors = [] neighbor_distances = [] @@ -34,6 +40,27 @@ def _lca_neighbors(tree, start_node, n_neighbors, max_dist, depth_key): node = start_node is_finite = n_neighbors != float("inf") + # For nodes/subset alignment: also collect observed descendants of start_node. + # LCA(start_node, descendant) = start_node, so lca_dist = depth[start_node]. + if observed_nodes is not None: + start_depth = tree.nodes[start_node][depth_key] + if start_depth <= max_dist: + desc_candidates = [] + stack = list(tree.successors(start_node)) + while stack: + n = stack.pop() + if n not in seen: + seen.add(n) + if n in observed_nodes: + desc_candidates.append(n) + if tree.out_degree(n) != 0: + stack.extend(tree.successors(n)) + random.shuffle(desc_candidates) + take = desc_candidates[: n_neighbors - len(neighbors)] if is_finite else desc_candidates + for n in take: + neighbors.append(n) + neighbor_distances.append(start_depth) + while len(neighbors) < n_neighbors: parents = list(tree.predecessors(node)) if not parents: @@ -43,23 +70,29 @@ def _lca_neighbors(tree, start_node, n_neighbors, max_dist, depth_key): seen.add(parent) if lca_dist <= max_dist: - # Collect unseen leaves — their LCA with start_node is exactly parent - sibling_leaves = [] + candidates = [] + # For nodes/subset alignment: parent itself is a candidate (LCA(start, parent) = parent) + if observed_nodes is not None and parent in observed_nodes: + candidates.append(parent) + # Collect observed nodes from sibling subtrees stack = list(tree.successors(parent)) while stack: n = stack.pop() if n in seen: continue seen.add(n) - if tree.out_degree(n) == 0: - sibling_leaves.append(n) - else: + is_observed = (observed_nodes is None and tree.out_degree(n) == 0) or ( + observed_nodes is not None and n in observed_nodes + ) + if is_observed: + candidates.append(n) + if tree.out_degree(n) != 0: stack.extend(tree.successors(n)) - random.shuffle(sibling_leaves) - take = sibling_leaves[: n_neighbors - len(neighbors)] if is_finite else sibling_leaves - for leaf in take: - neighbors.append(leaf) + random.shuffle(candidates) + take = candidates[: n_neighbors - len(neighbors)] if is_finite else candidates + for candidate in take: + neighbors.append(candidate) neighbor_distances.append(lca_dist) node = parent @@ -67,8 +100,15 @@ def _lca_neighbors(tree, start_node, n_neighbors, max_dist, depth_key): return neighbors, neighbor_distances -def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, depth_key): - """Breadth-first search for path distance neighbors.""" +def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, depth_key, observed_nodes=None): + """Breadth-first search for path distance neighbors. + + Parameters + ---------- + observed_nodes + Set of observed node names to collect as neighbors. If None, only leaf nodes + (out_degree == 0) are collected (default leaves-alignment behavior). + """ queue = [] heapq.heappush(queue, (0, start_node)) visited = {start_node} @@ -91,36 +131,63 @@ def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, depth_key): # Breadth-first search using direct children only while queue and (len(neighbors) < n_neighbors): distance, node = heapq.heappop(queue) + # For nodes/subset alignment: the popped node itself may be an observed neighbor + # (handles ancestor nodes that were pre-queued; descendants are handled in child loop) + if observed_nodes is not None and node != start_node and node in observed_nodes: + neighbors.append(node) + neighbor_distances.append(distance) + if len(neighbors) >= n_neighbors: + break children = list(tree.successors(node)) random.shuffle(children) for child in children: if child not in visited: child_distance = distance + abs(tree.nodes[node][depth_key] - tree.nodes[child][depth_key]) if child_distance <= max_dist: - if tree.out_degree(child) == 0: + # For leaves alignment: add leaf when discovered as child + if observed_nodes is None and tree.out_degree(child) == 0: neighbors.append(child) neighbor_distances.append(child_distance) if len(neighbors) >= n_neighbors: break - heapq.heappush(queue, (child_distance, child)) + # Push to queue: non-leaves always; observed nodes for nodes alignment + # (observed non-leaves will be added as neighbors when popped) + if tree.out_degree(child) != 0 or observed_nodes is not None: + heapq.heappush(queue, (child_distance, child)) visited.add(child) return neighbors, neighbor_distances -def _tree_neighbors(tree, n_neighbors, max_dist, depth_key, metric, leaves=None): - """Identify neighbors in a given tree.""" +def _tree_neighbors(tree, n_neighbors, max_dist, depth_key, metric, nodes=None, observed_nodes=None): + """Identify neighbors in a given tree. + + Parameters + ---------- + nodes + Nodes to find neighbors for. If None, defaults to all leaf nodes. + observed_nodes + Set of nodes that are considered observable neighbors. If None, only leaves + (out_degree == 0) are returned as neighbors. + """ rows, cols, distances = [], [], [] - if leaves is None: - leaves = [node for node in tree.nodes() if tree.out_degree(node) == 0] - for leaf in leaves: + if nodes is None: + if observed_nodes is not None: + nodes = list(observed_nodes) + else: + nodes = [node for node in tree.nodes() if tree.out_degree(node) == 0] + for node in nodes: if metric == "lca": - leaf_neighbors, leaf_distances = _lca_neighbors(tree, leaf, n_neighbors, max_dist, depth_key) + node_neighbors, node_distances = _lca_neighbors( + tree, node, n_neighbors, max_dist, depth_key, observed_nodes + ) else: - leaf_neighbors, leaf_distances = _bfs_by_distance(tree, leaf, n_neighbors, max_dist, depth_key) - rows.extend([leaf] * len(leaf_neighbors)) - cols.extend(leaf_neighbors) - distances.extend(leaf_distances) + node_neighbors, node_distances = _bfs_by_distance( + tree, node, n_neighbors, max_dist, depth_key, observed_nodes + ) + rows.extend([node] * len(node_neighbors)) + cols.extend(node_neighbors) + distances.extend(node_distances) return rows, cols, distances @@ -167,17 +234,21 @@ def tree_neighbors( ) -> None | tuple[sp.sparse.csr_matrix, sp.sparse.csr_matrix]: """Identifies neighbors in the tree. - For each leaf, this function identifies neighbors according to a chosen + For each observation, this function identifies neighbors according to a chosen tree distance `metric` and either: - * the top-``n_neighbors`` closest leaves (ties broken at random) + * the top-``n_neighbors`` closest observations (ties broken at random) - * all leaves within a distance threshold ``max_dist``. + * all observations within a distance threshold ``max_dist``. Results are stored as sparse connectivities and distances, or returned when - ``copy=True``. You can restrict the operation to a subset of leaves via + ``copy=True``. You can restrict the operation to a subset of observations via ``obs`` and/or to specific trees via ``tree``. + For ``tdata.alignment == "leaves"``, only leaf nodes are considered as neighbors. + For ``tdata.alignment == "nodes"`` or ``"subset"``, all observed nodes (leaves and + internal nodes present in ``tdata.obs``) are considered as neighbors. + Parameters ---------- tdata @@ -191,9 +262,9 @@ def tree_neighbors( obs The observations to use: - - If `None`, neighbors for all leaves are stored in `tdata.obsp`. - - If a string, neighbors of specified leaf are stored in `tdata.obs`. - - If a sequence, neighbors within specified leaves are stored in `tdata.obsp`. + - If `None`, neighbors for all observed nodes are stored in `tdata.obsp`. + - If a string, neighbors of specified observation are stored in `tdata.obs`. + - If a sequence, neighbors within specified observations are stored in `tdata.obsp`. metric The type of tree distance to compute: @@ -239,18 +310,28 @@ def tree_neighbors( _check_tree_overlap(tdata, tree_keys) if update: _check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"]) - # Neighbors of a single leaf + # Neighbors of a single observation if isinstance(obs, str): trees = get_trees(tdata, tree_keys) - leaf_to_tree = {leaf: key for key, tree in trees.items() for leaf in get_leaves(tree)} - if obs not in leaf_to_tree: - raise ValueError(f"Leaf {obs} not found in any tree.") - t = trees[leaf_to_tree[obs]] + if tdata.alignment == "leaves": + node_to_tree = {leaf: key for key, tree in trees.items() for leaf in get_leaves(tree)} + else: + node_to_tree = {node: key for key, tree in trees.items() for node in tree.nodes()} + if obs not in node_to_tree: + raise ValueError(f"Observation {obs} not found in any tree.") + t = trees[node_to_tree[obs]] + obs_set = set(tdata.obs_names) & set(t.nodes()) if tdata.alignment != "leaves" else None connectivities, _, distances = _tree_neighbors( - t, n_neighbors or float("inf"), max_dist or float("inf"), depth_key, metric, leaves=[obs] + t, + n_neighbors or float("inf"), + max_dist or float("inf"), + depth_key, + metric, + nodes=[obs], + observed_nodes=obs_set, ) tdata.obs[f"{key_added}_neighbors"] = tdata.obs_names.isin(connectivities) - # Neighbors for some or all leaves + # Neighbors for some or all observations else: if isinstance(obs, Sequence): tdata_subset = tdata[obs] @@ -261,10 +342,17 @@ def tree_neighbors( raise ValueError("obs must be a string, a sequence of strings, or None.") # For each tree, identify neighbors rows, cols, data = [], [], [] + obs_names_set = set(tdata.obs_names) for _, t in trees.items(): check_tree_has_key(t, depth_key) + observed_nodes = obs_names_set & set(t.nodes()) if tdata.alignment != "leaves" else None tree_rows, tree_cols, tree_data = _tree_neighbors( - t, n_neighbors or float("inf"), max_dist or float("inf"), depth_key, metric + t, + n_neighbors or float("inf"), + max_dist or float("inf"), + depth_key, + metric, + observed_nodes=observed_nodes, ) rows.extend([tdata.obs_names.get_loc(row) for row in tree_rows]) cols.extend([tdata.obs_names.get_loc(col) for col in tree_cols]) diff --git a/tests/test_ancestral_states.py b/tests/test_ancestral_states.py index b0668c5..c9d7212 100755 --- a/tests/test_ancestral_states.py +++ b/tests/test_ancestral_states.py @@ -28,10 +28,15 @@ def tdata(): @pytest.fixture def nodes_tdata(): + # Tree: root -> B, C; C -> D, E + # spatial: root=NaN (reconstruct), B=[0,1] (leaf), C=[1,1] (fixed), D=[2,1] (leaf), E=[4,4] (leaf) tree = nx.DiGraph([("root", "B"), ("root", "C"), ("C", "D"), ("C", "E")]) spatial = np.array([[np.nan, np.nan], [0, 1], [1, 1], [2, 1], [4, 4]]) nodes_tdata = td.TreeData( - obs=pd.DataFrame(index=["root", "B", "C", "D", "E"]), + obs=pd.DataFrame( + {"value": [np.nan, 0, 5, 3, 2], "str_value": [None, "0", "1", "3", "2"]}, + index=["root", "B", "C", "D", "E"], + ), obst={"tree": tree}, obsm={"spatial": spatial}, # type: ignore alignment="nodes", @@ -114,10 +119,28 @@ def test_ancestral_states_sankoff(tdata): def test_ancestral_states_nodes_tdata(nodes_tdata): + # C=[1,1] is a fixed observed internal node; root has NaN (reconstructed) + # root = mean(B=[0,1], C=[1,1]) = [0.5, 1.0] (C treated as fixed, not expanded into D/E) states = ancestral_states(nodes_tdata, "spatial", method="mean", copy=True) - print(nodes_tdata.obst["tree"].nodes["root"]["spatial"]) - print(states) - assert states.loc[("tree", "root"), "spatial"] == [2.0, 2.0] + assert nodes_tdata.obst["tree"].nodes["root"]["spatial"] == [0.5, 1.0] + assert nodes_tdata.obst["tree"].nodes["C"]["spatial"] == [1, 1] # C value preserved + assert states.loc[("tree", "root"), "spatial"] == [0.5, 1.0] + + +def test_ancestral_states_nodes_scalar(nodes_tdata): + # C=5 (fixed), root=NaN (reconstruct from B=0 and C=5) + ancestral_states(nodes_tdata, "value", method="mean", copy=False) + tree = nodes_tdata.obst["tree"] + assert tree.nodes["C"]["value"] == 5 # C preserved + assert tree.nodes["root"]["value"] == pytest.approx(2.5) # mean(B=0, C=5) + assert tree.nodes["B"]["value"] == 0 # leaf unchanged + + +def test_ancestral_states_nodes_fitch(nodes_tdata): + # C="1" (fixed internal); root reconstructed from B="0" and C="1" + ancestral_states(nodes_tdata, "str_value", method="fitch_hartigan", missing_state=None, copy=False) + tree = nodes_tdata.obst["tree"] + assert tree.nodes["C"]["str_value"] == "1" # C value preserved def test_ancestral_states_invalid(tdata): diff --git a/tests/test_fitness.py b/tests/test_fitness.py index 7754043..d907d43 100644 --- a/tests/test_fitness.py +++ b/tests/test_fitness.py @@ -60,5 +60,28 @@ def test_random_state_gives_reproducible_output(tdata): pd.testing.assert_frame_equal(df1.sort_index(), df2.sort_index()) +@pytest.fixture +def nodes_tdata(): + tree = nx.DiGraph([("root", "A"), ("root", "B"), ("B", "C"), ("B", "D"), ("B", "E")]) + nx.set_node_attributes(tree, {"root": 0, "A": 3, "B": 1, "C": 2.5, "D": 2.5, "E": 2.5}, "depth") + tdata = td.TreeData( + obs=pd.DataFrame(index=["root", "A", "B", "C", "D", "E"]), + obst={"tree": tree}, + alignment="nodes", + ) + return tdata + + +def test_fitness_nodes_alignment(nodes_tdata): + fitness(nodes_tdata, method="lbi", copy=False, random_state=42) + # All nodes (leaves + internal) should be written to obs + assert set(nodes_tdata.obs.index) == {"root", "A", "B", "C", "D", "E"} + assert nodes_tdata.obs["fitness"].notna().all() + # obs values match what's stored in the tree + tree = nodes_tdata.obst["tree"] + for node in tree.nodes: + assert nodes_tdata.obs.loc[node, "fitness"] == pytest.approx(tree.nodes[node]["fitness"], abs=1e-10) + + if __name__ == "__main__": pytest.main(["-v", __file__]) diff --git a/tests/test_tree_neighbors.py b/tests/test_tree_neighbors.py index eb73d5e..8d90825 100755 --- a/tests/test_tree_neighbors.py +++ b/tests/test_tree_neighbors.py @@ -94,5 +94,49 @@ def test_tree_neighbors_invalid(tdata): tree_neighbors(tdata, n_neighbors=3, metric="path", depth_key="invalid") +@pytest.fixture +def nodes_tdata(): + # Tree: root(0) -> A(1), B(1); A -> C(2), D(2); B -> E(2) + tree = nx.DiGraph([("root", "A"), ("root", "B"), ("A", "C"), ("A", "D"), ("B", "E")]) + nx.set_node_attributes(tree, {"root": 0, "A": 1, "B": 1, "C": 2, "D": 2, "E": 2}, "depth") + tdata = td.TreeData( + obs=pd.DataFrame(index=["root", "A", "B", "C", "D", "E"]), + obst={"tree": tree}, + alignment="nodes", + ) + return tdata + + +def test_tree_neighbors_nodes_alignment(nodes_tdata): + # path(A, C) = |1+2 - 2*1| = 1; path(A, B) = |1+1 - 2*0| = 2; path(C, E) = |2+2 - 2*0| = 4 + tree_neighbors(nodes_tdata, max_dist=2, metric="path") + dist = nodes_tdata.obsp["tree_distances"] + a_idx = nodes_tdata.obs_names.get_loc("A") + b_idx = nodes_tdata.obs_names.get_loc("B") + c_idx = nodes_tdata.obs_names.get_loc("C") + d_idx = nodes_tdata.obs_names.get_loc("D") + # Internal nodes appear as neighbors + assert dist[a_idx, c_idx] == 1 # A -> C (path = 1) + assert dist[a_idx, d_idx] == 1 # A -> D (path = 1) + assert dist[c_idx, a_idx] == 1 # C -> A (path = 1) + assert dist[a_idx, b_idx] == 2 # A -> B (path = 2) + # matrix is n_obs x n_obs (6 observations including internal nodes) + assert dist.shape == (6, 6) + # LCA metric: lca(A, C) = A (depth 1); lca(C, D) = A (depth 1); lca(C, E) = root (depth 0) + tree_neighbors(nodes_tdata, max_dist=2, metric="lca", key_added="lca") + lca_dist = nodes_tdata.obsp["lca_distances"] + assert lca_dist[a_idx, c_idx] == 1 # lca(A, C) = A, depth = 1 + assert lca_dist[c_idx, d_idx] == 1 # lca(C, D) = A, depth = 1 + + +def test_tree_neighbors_nodes_single_obs(nodes_tdata): + # For a single-string obs, only the queried node is marked True (consistent behavior) + tree_neighbors(nodes_tdata, n_neighbors=3, metric="path", obs="A") + assert nodes_tdata.obs.query("tree_neighbors").index.tolist() == ["A"] + # Internal node (B) is also a valid starting point + tree_neighbors(nodes_tdata, n_neighbors=2, metric="path", obs="B") + assert nodes_tdata.obs.query("tree_neighbors").index.tolist() == ["B"] + + if __name__ == "__main__": pytest.main(["-v", __file__])