From 4a960ffe7fd8709b45f447ef57f237e3b4438367 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 21 Jan 2026 21:13:27 -0800 Subject: [PATCH] Disambiguate eps into min_branch_length; change default to 1e-8 --- CHANGELOG.md | 11 +++++++++ tsdate/core.py | 63 ++++++++++++++++++++++++++++++-------------------- 2 files changed, 49 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be6b8f73..81c7b861 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [0.2.5] - XXXXXXXXXX + +- The `eps` parameter was previously used for two things: the minimum time interval in + likelihood calculations for the discrete-time algorithms, and the minimum allowed branch length + when forcing positive branch lengths. The latter is now a separate parameter called + `min_branch_length` for all algorithms, while the `eps` parameter is only used for the + discrete-time algorithms. Passing `eps` to `variational_gamma` now raises a `ValueError`. + +- The default `min_branch_length` and `eps` have been set to 1e-8 rather than 1e-10, to avoid + occasional issues with floating-point error. + ## [0.2.4] - 2025-09-18 - Add support for Python 3.13, minimum version is now 3.10. 
diff --git a/tsdate/core.py b/tsdate/core.py index b2e1c895..02bdbfdc 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -41,7 +41,8 @@ DEFAULT_RESCALING_INTERVALS = 1000 DEFAULT_RESCALING_ITERATIONS = 5 DEFAULT_MAX_ITERATIONS = 25 -DEFAULT_EPSILON = 1e-10 +DEFAULT_EPSILON = 1e-8 +DEFAULT_MIN_BRANCH_LENGTH = 1e-8 # Classes for each method @@ -87,6 +88,7 @@ def __init__( allow_unary=None, record_provenance=None, constr_iterations=None, + min_branch_length=None, set_metadata=None, progress=None, # deprecated params @@ -150,6 +152,13 @@ def __init__( ) self.constr_iterations = constr_iterations + if min_branch_length is None: + self.min_branch_length = DEFAULT_MIN_BRANCH_LENGTH + else: + if not min_branch_length > 0.0: + raise ValueError("Minimum branch length must be positive") + self.min_branch_length = min_branch_length + self.allow_unary = False if allow_unary is None else allow_unary if self.prior_grid_func_name is None: @@ -188,7 +197,7 @@ def __init__( self.edges_mutations, self.mutations_edge = util.mutation_span_array(ts) self.fixed_nodes = np.array(list(ts.samples())) - def get_modified_ts(self, result, eps): + def get_modified_ts(self, result): # Return a new ts based on the existing one, but with the various # time-related information correctly set. 
ts = self.ts @@ -215,7 +224,9 @@ def get_modified_ts(self, result, eps): # Constrain node ages for positive branch lengths constr_timing = time.time() - nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations) + nodes.time = util.constrain_ages( + ts, node_mean_t, self.min_branch_length, self.constr_iterations + ) constr_timing -= time.time() logger.info(f"Constrained node ages in {abs(constr_timing):.2f} seconds") # Possibly change mutation nodes if phasing singletons @@ -282,10 +293,10 @@ def _time_md_array(table, mean, var): table.metadata_schema = default_schema table.packset_metadata(_time_md_array(table, mean, var)) - def parse_result(self, result, epsilon): + def parse_result(self, result): # Construct the tree sequence to return and add other stuff we might want to # return. pst_cols is a dict to be appended to the output posterior dict - ret = [self.get_modified_ts(result, epsilon)] + ret = [self.get_modified_ts(result)] if self.return_fit: ret.append(result.fit_object) if self.return_likelihood: @@ -451,7 +462,6 @@ def __init__(self, ts, **kwargs): def run( self, - eps, max_iterations, max_shape, rescaling_intervals, @@ -570,9 +580,8 @@ def maximization( the prior parameters for each node-to-be-dated. Note that different estimation methods may require different types of prior, as described in the documentation for each estimation method. - :param float eps: The error factor in time difference calculations, and the - minimum distance separating parent and child ages in the returned tree sequence. - Default: None, treated as 1e-10. + :param float eps: The error factor in time difference calculations. Default: None, + treated as 1e-8. :param int num_threads: The number of threads to use when precalculating likelihoods. A simpler unthreaded algorithm is used unless this is >= 1. 
Default: None :param string probability_space: Should the internal algorithm save @@ -587,7 +596,7 @@ def maximization( - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with updated node times based on the posterior mean, corrected where necessary to ensure that parents are strictly older than all their children by an amount - given by the ``eps`` parameter. + given by the ``min_branch_length`` parameter. - **marginal_likelihood** (:py:class:`float`) -- (Only returned if ``return_likelihood`` is ``True``) The marginal likelihood of the mutation data given the inferred node times. @@ -615,7 +624,7 @@ def maximization( cache_inside=cache_inside, probability_space=probability_space, ) - return dating_method.parse_result(result, eps) + return dating_method.parse_result(result) def inside_outside( @@ -692,9 +701,8 @@ def inside_outside( the prior parameters for each node-to-be-dated. Note that different estimation methods may require different types of prior, as described in the documentation for each estimation method. - :param float eps: The error factor in time difference calculations, and the - minimum distance separating parent and child ages in the returned tree sequence. - Default: None, treated as 1e-10. + :param float eps: The error factor in time difference calculations. Default: None, + treated as 1e-8. :param int num_threads: The number of threads to use when precalculating likelihoods. A simpler unthreaded algorithm is used unless this is >= 1. Default: None :param bool outside_standardize: Should the likelihoods be standardized during the @@ -720,7 +728,7 @@ def inside_outside( - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with updated node times based on the posterior mean, corrected where necessary to ensure that parents are strictly older than all their children by an amount - given by the ``eps`` parameter. + given by the ``min_branch_length`` parameter. 
- **fit** (:class:`~discrete.BeliefPropagation`) -- (Only returned if ``return_fit`` is ``True``) The underlying object used to run the dating inference. This can then be queried e.g. using @@ -757,14 +765,13 @@ def inside_outside( cache_inside=cache_inside, probability_space=probability_space, ) - return dating_method.parse_result(result, eps) + return dating_method.parse_result(result) def variational_gamma( tree_sequence, *, mutation_rate, - eps=None, max_iterations=None, rescaling_intervals=None, rescaling_iterations=None, @@ -773,10 +780,12 @@ def variational_gamma( max_shape=None, regularise_roots=None, singletons_phased=None, + # deprecated parameters + eps=None, **kwargs, ): """ - variational_gamma(tree_sequence, *, mutation_rate, eps=None, max_iterations=None,\ + variational_gamma(tree_sequence, *, mutation_rate, max_iterations=None,\ rescaling_intervals=None, rescaling_iterations=None,\ match_segregating_sites=None, **kwargs) @@ -797,8 +806,6 @@ def variational_gamma( :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. :param float mutation_rate: The estimated mutation rate per unit of genome per unit time. - :param float eps: The minimum distance separating parent and child ages in - the returned tree sequence. Default: None, treated as 1e-10 :param int max_iterations: The number of iterations used in the expectation propagation algorithm. Default: None, treated as 25. :param float rescaling_intervals: For time rescaling, the number of time @@ -820,7 +827,7 @@ def variational_gamma( - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with updated node times based on the posterior mean, corrected where necessary to ensure that parents are strictly older than all their children by an amount - given by the ``eps`` parameter. + given by the ``min_branch_length`` parameter. - **fit** (:class:`~variational.ExpectationPropagation`) -- (Only returned if ``return_fit`` is ``True``). 
The underlying object used to run the dating inference. This can then be queried e.g. using @@ -830,8 +837,6 @@ def variational_gamma( the mutation data given the inferred node times. Not currently implemented for this method (set to ``None``) """ - if eps is None: - eps = DEFAULT_EPSILON if max_iterations is None: max_iterations = DEFAULT_MAX_ITERATIONS if max_shape is None: @@ -848,6 +853,11 @@ def variational_gamma( regularise_roots = True if singletons_phased is None: singletons_phased = True + if eps is not None: + raise ValueError( + "The `eps` parameter has been disambiguated and is no longer used " + "for the variational gamma algorithm; use `min_branch_length` instead" + ) if tree_sequence.num_mutations == 0: raise ValueError( "No mutations present: these are required for the variational_gamma method" @@ -856,7 +866,6 @@ def variational_gamma( tree_sequence, mutation_rate=mutation_rate, **kwargs ) result = dating_method.run( - eps=eps, max_iterations=max_iterations, max_shape=max_shape, rescaling_intervals=rescaling_intervals, @@ -865,7 +874,7 @@ def variational_gamma( regularise_roots=regularise_roots, singletons_phased=singletons_phased, ) - return dating_method.parse_result(result, eps) + return dating_method.parse_result(result) estimation_methods = { @@ -893,6 +902,7 @@ def date( time_units=None, method=None, constr_iterations=None, + min_branch_length=None, set_metadata=None, return_fit=None, return_likelihood=None, @@ -943,6 +953,8 @@ def date( :param int constr_iterations: The maximum number of constrained least squares iterations to use prior to forcing positive branch lengths. Default: None, treated as 0. + :param float min_branch_length: The minimum distance separating parent and + child ages in the returned tree sequence. Default: None, treated as 1e-8 :param bool set_metadata: Should unconstrained times be stored in table metadata, in the form of ``"mn"`` (mean) and ``"vr"`` (variance) fields? If ``False``, do not store metadata. 
If ``True``, force metadata to be set (if no schema @@ -984,6 +996,7 @@ def date( time_units=time_units, progress=progress, constr_iterations=constr_iterations, + min_branch_length=min_branch_length, return_fit=return_fit, return_likelihood=return_likelihood, allow_unary=allow_unary,