Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 73 additions & 28 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
import treedata as td

from pycea.utils import _check_tree_overlap, get_keyed_node_data, get_keyed_obs_data, get_root, get_trees
from pycea.utils import _check_tree_overlap, get_keyed_node_data, get_keyed_obs_data, get_leaves, get_root, get_trees


def _most_common(arr: np.ndarray) -> Any:
Expand Down Expand Up @@ -45,15 +45,18 @@ def _remove_node_attributes(tree: nx.DiGraph, key: str) -> None:


def _reconstruct_fitch_hartigan(
tree: nx.DiGraph, key: str, missing: str | None = None, index: int | None = None
tree: nx.DiGraph, key: str, missing: str | None = None, index: int | None = None, fixed_nodes: set | None = None
) -> None:
"""Reconstructs ancestral states using the Fitch-Hartigan algorithm."""

def _is_fixed(node, value):
return fixed_nodes is not None and node in fixed_nodes and value is not None and value != missing

# Recursive function to calculate the downpass
def downpass(node):
# Base case: leaf
if tree.out_degree(node) == 0:
value = _get_node_value(tree, node, key, index)
value = _get_node_value(tree, node, key, index)
# Base case: leaf or fixed observed internal node
if tree.out_degree(node) == 0 or _is_fixed(node, value):
if value == missing:
tree.nodes[node]["value_set"] = missing
else:
Expand All @@ -78,6 +81,11 @@ def downpass(node):
# Recursive function to calculate the uppass
def uppass(node, parent_state=None):
value = _get_node_value(tree, node, key, index)
if _is_fixed(node, value):
# Fixed node: keep its observed value, propagate it to children
for child in tree.successors(node):
uppass(child, value)
return
if value is None:
if parent_state and parent_state in tree.nodes[node]["value_set"]:
value = parent_state
Expand Down Expand Up @@ -107,6 +115,7 @@ def _reconstruct_sankoff(
missing: str | None = None,
default: str | None = None,
index: int | None = None,
fixed_nodes: set | None = None,
) -> None:
"""Reconstructs ancestral states using the Sankoff algorithm."""
# Set up
Expand All @@ -115,16 +124,19 @@ def _reconstruct_sankoff(
cost_matrix = costs.to_numpy()
value_to_index = {value: i for i, value in enumerate(alphabet)}

def _is_fixed(node, value):
return fixed_nodes is not None and node in fixed_nodes and value is not None and value != missing

# Recursive function to calculate the Sankoff scores
def sankoff_scores(node):
# Base case: leaf
if tree.out_degree(node) == 0:
leaf_value = _get_node_value(tree, node, key, index)
if leaf_value == missing:
node_value = _get_node_value(tree, node, key, index)
# Base case: leaf or fixed observed internal node
if tree.out_degree(node) == 0 or _is_fixed(node, node_value):
if node_value == missing:
return np.zeros(num_states)
else:
scores = np.full(num_states, float("inf"))
scores[value_to_index[leaf_value]] = 0
scores[value_to_index[node_value]] = 0
return scores
# Recursive case: internal node
else:
Expand All @@ -143,29 +155,40 @@ def sankoff_scores(node):
# Recursive function to traceback the Sankoff scores
def traceback(node, parent_value_index):
for i, child in enumerate(tree.successors(node)):
child_value_index = tree.nodes[node]["_pointers"][parent_value_index, i]
_set_node_value(tree, child, key, alphabet[child_value_index], index)
child_value = _get_node_value(tree, child, key, index)
if _is_fixed(child, child_value):
# Fixed node: keep its value, traceback using its fixed index
child_value_index = value_to_index[child_value]
else:
child_value_index = tree.nodes[node]["_pointers"][parent_value_index, i]
_set_node_value(tree, child, key, alphabet[child_value_index], index)
traceback(child, child_value_index)

# Get scores
root = [n for n, d in tree.in_degree() if d == 0][0]
root_scores = sankoff_scores(root)
# Reconstruct ancestral states
root_value_index = np.argmin(root_scores)
_set_node_value(tree, root, key, alphabet[root_value_index], index)
root_value = _get_node_value(tree, root, key, index)
if _is_fixed(root, root_value):
root_value_index = value_to_index[root_value]
else:
root_value_index = np.argmin(root_scores)
_set_node_value(tree, root, key, alphabet[root_value_index], index)
traceback(root, root_value_index)
# Clean up
for node in tree.nodes:
if "_pointers" in tree.nodes[node]:
del tree.nodes[node]["_pointers"]


def _reconstruct_mean(tree: nx.DiGraph, key: str, index: int | None) -> None:
def _reconstruct_mean(tree: nx.DiGraph, key: str, index: int | None, fixed_nodes: set | None = None) -> None:
"""Reconstructs ancestral states by averaging the values of the children."""

def subtree_mean(node):
if tree.out_degree(node) == 0:
return _get_node_value(tree, node, key, index), 1
val = _get_node_value(tree, node, key, index)
is_fixed = fixed_nodes is not None and node in fixed_nodes and val is not None
if tree.out_degree(node) == 0 or is_fixed:
return val, 1
else:
values, weights = [], []
for child in tree.successors(node):
Expand All @@ -180,12 +203,16 @@ def subtree_mean(node):
subtree_mean(root)


def _reconstruct_list(tree: nx.DiGraph, key: str, sum_func: Callable, index: int | None) -> None:
def _reconstruct_list(
tree: nx.DiGraph, key: str, sum_func: Callable, index: int | None, fixed_nodes: set | None = None
) -> None:
"""Reconstructs ancestral states by concatenating the values of the children."""

def subtree_list(node):
if tree.out_degree(node) == 0:
return [_get_node_value(tree, node, key, index)]
val = _get_node_value(tree, node, key, index)
is_fixed = fixed_nodes is not None and node in fixed_nodes and val is not None
if tree.out_degree(node) == 0 or is_fixed:
return [val]
else:
values = []
for child in tree.successors(node):
Expand All @@ -205,20 +232,21 @@ def _ancestral_states(
missing: str | None = None,
default: str | None = None,
index: int | None = None,
fixed_nodes: set | None = None,
) -> None:
"""Reconstructs ancestral states for a given attribute using a given method"""
if method == "sankoff":
if costs is None:
raise ValueError("Costs matrix must be provided for Sankoff algorithm.")
_reconstruct_sankoff(tree, key, costs, missing, default, index)
_reconstruct_sankoff(tree, key, costs, missing, default, index, fixed_nodes)
elif method == "fitch_hartigan":
_reconstruct_fitch_hartigan(tree, key, missing, index)
_reconstruct_fitch_hartigan(tree, key, missing, index, fixed_nodes)
elif method == "mean":
_reconstruct_mean(tree, key, index)
_reconstruct_mean(tree, key, index, fixed_nodes)
elif method == "mode":
_reconstruct_list(tree, key, _most_common, index)
_reconstruct_list(tree, key, _most_common, index, fixed_nodes)
elif callable(method):
_reconstruct_list(tree, key, method, index)
_reconstruct_list(tree, key, method, index, fixed_nodes)
else:
raise ValueError(f"Method {method} not recognized.")

Expand Down Expand Up @@ -261,10 +289,15 @@ def ancestral_states(
"""Reconstructs ancestral states for an attribute.

This function reconstructs ancestral (internal node) states for categorical or
continuous attributes defined on tree leaves. Several reconstruction methods
continuous attributes defined on tree observations. Several reconstruction methods
are supported, ranging from simple aggregation rules to the Sankoff and Fitch-Hartigan
algorithms for discrete character data, or a custom aggregation function can be provided.

For ``tdata.alignment == "leaves"``, only leaf node values are used as input and all
internal node states are reconstructed. For ``tdata.alignment == "nodes"`` or ``"subset"``,
internal nodes present in ``tdata.obs`` with non-missing values are treated as fixed
constraints and are not overwritten by reconstruction.

Parameters
----------
tdata
Expand Down Expand Up @@ -334,6 +367,8 @@ def ancestral_states(
if dtypes.intersection({"O", "S"}):
if method in ["mean"]:
raise ValueError(f"Method {method} requires numeric data.")
# Determine fixed internal nodes for nodes/subset alignment
leaves_set = set(get_leaves(t))
# If array add to tree as list
if is_array:
length = data.shape[1]
Expand All @@ -343,13 +378,23 @@ def ancestral_states(
node_attrs[node] = [None] * length
_remove_node_attributes(t, keys_added[0])
nx.set_node_attributes(t, node_attrs, keys_added[0])
fixed_nodes = None
if tdata.alignment != "leaves":
not_all_nan = ~data.isna().all(axis=1)
fixed_nodes = set(data[not_all_nan].index) - leaves_set
for index in range(length):
_ancestral_states(t, keys_added[0], method, costs, missing_state, default_state, index)
_ancestral_states(t, keys_added[0], method, costs, missing_state, default_state, index, fixed_nodes)
Comment on lines +383 to +386

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Recompute fixed internal nodes per array dimension

For array-valued inputs, `fixed_nodes` is computed once from `~data.isna().all(axis=1)` and reused for every `index`, which means a partially observed internal node (for example `[1, NaN]`) is treated as fixed even on its missing dimensions. Those missing dimensions are then not reconstructed and can propagate NaN/sentinel values to ancestors instead of being inferred from descendants; fixed-node selection needs to be index-aware (and honor `missing_state`).

Useful? React with 👍 / 👎.

# If column add to tree as scalar
else:
for key, key_added in zip(keys, keys_added, strict=False):
_remove_node_attributes(t, key_added)
nx.set_node_attributes(t, data[key].to_dict(), key_added)
_ancestral_states(t, key_added, method, costs, missing_state, default_state)
fixed_nodes = None
if tdata.alignment != "leaves":
valid = data[key].notna()
if missing_state is not None:
valid = valid & (data[key] != missing_state)
fixed_nodes = set(data[valid].index) - leaves_set
_ancestral_states(t, key_added, method, costs, missing_state, default_state, fixed_nodes=fixed_nodes)
if copy:
return get_keyed_node_data(tdata, keys_added, tree_keys, slot="obst")
12 changes: 9 additions & 3 deletions src/pycea/tl/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,9 @@ def fitness(
* tdata.obst[tree].nodes[key_added] : `float`
- Inferred fitness values for each node.
* tdata.obs[key_added] : `float`
- Inferred fitness values for each leaf.
- Inferred fitness values for each observed node. For ``tdata.alignment == "leaves"``,
only leaf nodes are written. For ``"nodes"`` or ``"subset"``, all observed nodes
(including internal nodes) are written.
"""
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
Expand All @@ -528,8 +530,12 @@ def fitness(
_infer_fitness_lbi(t, depth_key=depth_key, key_added=key_added, sample_n=sample_n, **(method_kwargs or {}))
else:
raise ValueError(f"method {method!r} not recognized, use 'sbd' or 'lbi'")
leaf_fitness = get_keyed_leaf_data(tdata, key_added, tree_keys)
tdata.obs[key_added] = tdata.obs.index.map(leaf_fitness[key_added])
if tdata.alignment == "leaves":
node_fitness = get_keyed_leaf_data(tdata, key_added, tree_keys)
else:
node_fitness = get_keyed_node_data(tdata, key_added, tree_keys, slot="obst")
node_fitness.index = node_fitness.index.droplevel(0)
tdata.obs[key_added] = tdata.obs.index.map(node_fitness[key_added])
if copy:
df = get_keyed_node_data(tdata, key_added, tree_keys)
if len(trees) == 1:
Expand Down
12 changes: 7 additions & 5 deletions src/pycea/tl/tree_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ def tree_distance(
) -> None | sp.sparse.csr_matrix | np.ndarray:
r"""Computes tree distances between observations.

This function calculates distances between observations (typically tree leaves)
based on their positions and depths in the tree. It supports *lowest common ancestor (lca)*
and *path* distances.
This function calculates distances between observations based on their positions
and depths in the tree. For ``tdata.alignment == "leaves"``, this computes distances
between leaf nodes. For ``tdata.alignment == "nodes"`` or ``"subset"``, distances are
computed between all observed nodes (leaves and internal nodes in ``tdata.obs``).
It supports *lowest common ancestor (lca)* and *path* distances.

Given two nodes :math:`i` and :math:`j` in a rooted tree, with depths
:math:`d_i` and :math:`d_j`, and with their lowest common ancestor having
Expand Down Expand Up @@ -175,8 +177,8 @@ def tree_distance(
obs
The observations to use:

- If `None`, pairwise distance for tree leaves is stored in `tdata.obsp`.
- If a string, distance to all other tree leaves is `tdata.obs`.
- If `None`, pairwise distance for all observed nodes is stored in `tdata.obsp`.
- If a string, distance to all other observed nodes is stored in `tdata.obs`.
- If a sequence, pairwise distance is stored in `tdata.obsp`.
- If a sequence of pairs, distance between pairs is stored in `tdata.obsp`.
metric
Expand Down
Loading