Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
# Changelog

## [0.2.5] - XXXXXXXXXX

- The `eps` parameter was previously used for two things: the minimum time interval in
likelihood calculations for the discrete-time algorithms, and the minimum allowed branch length
when forcing positive branch lengths. The latter is now a separate parameter called
`min_branch_length` for all algorithms, while the `eps` parameter is only used for the
discrete-time algorithms. Passing `eps` to `variational_gamma` now raises a `ValueError`;
use `min_branch_length` instead.

- The defaults for `min_branch_length` and `eps` are now 1e-8 rather than 1e-10, to avoid
occasional issues with floating-point error.

## [0.2.4] - 2025-09-18

- Add support for Python 3.13, minimum version is now 3.10.
Expand Down
63 changes: 38 additions & 25 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
DEFAULT_RESCALING_INTERVALS = 1000
DEFAULT_RESCALING_ITERATIONS = 5
DEFAULT_MAX_ITERATIONS = 25
DEFAULT_EPSILON = 1e-10
DEFAULT_EPSILON = 1e-8
DEFAULT_MIN_BRANCH_LENGTH = 1e-8


# Classes for each method
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
allow_unary=None,
record_provenance=None,
constr_iterations=None,
min_branch_length=None,
set_metadata=None,
progress=None,
# deprecated params
Expand Down Expand Up @@ -150,6 +152,13 @@ def __init__(
)
self.constr_iterations = constr_iterations

if min_branch_length is None:
self.min_branch_length = DEFAULT_MIN_BRANCH_LENGTH
else:
if not min_branch_length > 0.0:
raise ValueError("Minimum branch length must be positive")
self.min_branch_length = min_branch_length

self.allow_unary = False if allow_unary is None else allow_unary

if self.prior_grid_func_name is None:
Expand Down Expand Up @@ -188,7 +197,7 @@ def __init__(
self.edges_mutations, self.mutations_edge = util.mutation_span_array(ts)
self.fixed_nodes = np.array(list(ts.samples()))

def get_modified_ts(self, result, eps):
def get_modified_ts(self, result):
# Return a new ts based on the existing one, but with the various
# time-related information correctly set.
ts = self.ts
Expand All @@ -215,7 +224,9 @@ def get_modified_ts(self, result, eps):

# Constrain node ages for positive branch lengths
constr_timing = time.time()
nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations)
nodes.time = util.constrain_ages(
ts, node_mean_t, self.min_branch_length, self.constr_iterations
)
constr_timing -= time.time()
logger.info(f"Constrained node ages in {abs(constr_timing):.2f} seconds")
# Possibly change mutation nodes if phasing singletons
Expand Down Expand Up @@ -282,10 +293,10 @@ def _time_md_array(table, mean, var):
table.metadata_schema = default_schema
table.packset_metadata(_time_md_array(table, mean, var))

def parse_result(self, result, epsilon):
def parse_result(self, result):
# Construct the tree sequence to return and add other stuff we might want to
# return. pst_cols is a dict to be appended to the output posterior dict
ret = [self.get_modified_ts(result, epsilon)]
ret = [self.get_modified_ts(result)]
if self.return_fit:
ret.append(result.fit_object)
if self.return_likelihood:
Expand Down Expand Up @@ -451,7 +462,6 @@ def __init__(self, ts, **kwargs):

def run(
self,
eps,
max_iterations,
max_shape,
rescaling_intervals,
Expand Down Expand Up @@ -570,9 +580,8 @@ def maximization(
the prior parameters for each node-to-be-dated. Note that different estimation
methods may require different types of prior, as described in the documentation
for each estimation method.
:param float eps: The error factor in time difference calculations, and the
minimum distance separating parent and child ages in the returned tree sequence.
Default: None, treated as 1e-10.
:param float eps: The error factor in time difference calculations. Default: None,
treated as 1e-8.
:param int num_threads: The number of threads to use when precalculating likelihoods.
A simpler unthreaded algorithm is used unless this is >= 1. Default: None
:param string probability_space: Should the internal algorithm save
Expand All @@ -587,7 +596,7 @@ def maximization(
- **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with
updated node times based on the posterior mean, corrected where necessary to
ensure that parents are strictly older than all their children by an amount
given by the ``eps`` parameter.
given by the ``min_branch_length`` parameter.
- **marginal_likelihood** (:py:class:`float`) -- (Only returned if
``return_likelihood`` is ``True``) The marginal likelihood of
the mutation data given the inferred node times.
Expand Down Expand Up @@ -615,7 +624,7 @@ def maximization(
cache_inside=cache_inside,
probability_space=probability_space,
)
return dating_method.parse_result(result, eps)
return dating_method.parse_result(result)


def inside_outside(
Expand Down Expand Up @@ -692,9 +701,8 @@ def inside_outside(
the prior parameters for each node-to-be-dated. Note that different estimation
methods may require different types of prior, as described in the documentation
for each estimation method.
:param float eps: The error factor in time difference calculations, and the
minimum distance separating parent and child ages in the returned tree sequence.
Default: None, treated as 1e-10.
:param float eps: The error factor in time difference calculations. Default: None,
treated as 1e-8.
:param int num_threads: The number of threads to use when precalculating likelihoods.
A simpler unthreaded algorithm is used unless this is >= 1. Default: None
:param bool outside_standardize: Should the likelihoods be standardized during the
Expand All @@ -720,7 +728,7 @@ def inside_outside(
- **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with
updated node times based on the posterior mean, corrected where necessary to
ensure that parents are strictly older than all their children by an amount
given by the ``eps`` parameter.
given by the ``min_branch_length`` parameter.
- **fit** (:class:`~discrete.BeliefPropagation`) -- (Only returned if
``return_fit`` is ``True``) The underlying object used to run the dating
inference. This can then be queried e.g. using
Expand Down Expand Up @@ -757,14 +765,13 @@ def inside_outside(
cache_inside=cache_inside,
probability_space=probability_space,
)
return dating_method.parse_result(result, eps)
return dating_method.parse_result(result)


def variational_gamma(
tree_sequence,
*,
mutation_rate,
eps=None,
max_iterations=None,
rescaling_intervals=None,
rescaling_iterations=None,
Expand All @@ -773,10 +780,12 @@ def variational_gamma(
max_shape=None,
regularise_roots=None,
singletons_phased=None,
# deprecated parameters
eps=None,
**kwargs,
):
"""
variational_gamma(tree_sequence, *, mutation_rate, eps=None, max_iterations=None,\
variational_gamma(tree_sequence, *, mutation_rate, max_iterations=None,\
rescaling_intervals=None, rescaling_iterations=None,\
match_segregating_sites=None, **kwargs)

Expand All @@ -797,8 +806,6 @@ def variational_gamma(
:param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated.
:param float mutation_rate: The estimated mutation rate per unit of genome per
unit time.
:param float eps: The minimum distance separating parent and child ages in
the returned tree sequence. Default: None, treated as 1e-10
:param int max_iterations: The number of iterations used in the expectation
propagation algorithm. Default: None, treated as 25.
:param float rescaling_intervals: For time rescaling, the number of time
Expand All @@ -820,7 +827,7 @@ def variational_gamma(
- **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with
updated node times based on the posterior mean, corrected where necessary to
ensure that parents are strictly older than all their children by an amount
given by the ``eps`` parameter.
given by the ``min_branch_length`` parameter.
- **fit** (:class:`~variational.ExpectationPropagation`) -- (Only returned
if ``return_fit`` is ``True``). The underlying object used to run the dating
inference. This can then be queried e.g. using
Expand All @@ -830,8 +837,6 @@ def variational_gamma(
the mutation data given the inferred node times. Not currently
implemented for this method (set to ``None``)
"""
if eps is None:
eps = DEFAULT_EPSILON
if max_iterations is None:
max_iterations = DEFAULT_MAX_ITERATIONS
if max_shape is None:
Expand All @@ -848,6 +853,11 @@ def variational_gamma(
regularise_roots = True
if singletons_phased is None:
singletons_phased = True
if eps is not None:
raise ValueError(
"The `eps` parameter has been disambiguated and is no longer used "
"for the variational gamma algorithm; use `min_branch_length` instead"
)
if tree_sequence.num_mutations == 0:
raise ValueError(
"No mutations present: these are required for the variational_gamma method"
Expand All @@ -856,7 +866,6 @@ def variational_gamma(
tree_sequence, mutation_rate=mutation_rate, **kwargs
)
result = dating_method.run(
eps=eps,
max_iterations=max_iterations,
max_shape=max_shape,
rescaling_intervals=rescaling_intervals,
Expand All @@ -865,7 +874,7 @@ def variational_gamma(
regularise_roots=regularise_roots,
singletons_phased=singletons_phased,
)
return dating_method.parse_result(result, eps)
return dating_method.parse_result(result)


estimation_methods = {
Expand Down Expand Up @@ -893,6 +902,7 @@ def date(
time_units=None,
method=None,
constr_iterations=None,
min_branch_length=None,
set_metadata=None,
return_fit=None,
return_likelihood=None,
Expand Down Expand Up @@ -943,6 +953,8 @@ def date(
:param int constr_iterations: The maximum number of constrained least
squares iterations to use prior to forcing positive branch lengths.
Default: None, treated as 0.
:param float min_branch_length: The minimum distance separating parent and
child ages in the returned tree sequence. Default: None, treated as 1e-8
:param bool set_metadata: Should unconstrained times be stored in table metadata,
in the form of ``"mn"`` (mean) and ``"vr"`` (variance) fields? If ``False``,
do not store metadata. If ``True``, force metadata to be set (if no schema
Expand Down Expand Up @@ -984,6 +996,7 @@ def date(
time_units=time_units,
progress=progress,
constr_iterations=constr_iterations,
min_branch_length=min_branch_length,
return_fit=return_fit,
return_likelihood=return_likelihood,
allow_unary=allow_unary,
Expand Down
Loading