diff --git a/.gitignore b/.gitignore index 63f8cb70..233688f9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ target/** *.code-workspace tmp* temp* +# The `temp*` glob above matches `templates/` too; keep the Train-FIRE workflow templates dir tracked. +!Train-FIRE/workflow/templates/ +!Train-FIRE/workflow/templates/** dists/** .vscode/ .test/small_unaligned_fiberseq.bam @@ -24,4 +27,13 @@ py-ft/test.tbl tests/data/shuffle.chr20.hifi.bed.gz benchmarks/2025-07-08-new-torch-version/bin/ local-test-files/ -pg-test-data/ \ No newline at end of file +pg-test-data/ +.snakemake/ +.claude/ +results/* +Train-FIRE/results/* +Train-FIRE/pixi.toml +temp/** +scripts/* +pixi.lock +Train-FIRE/GM12878-fire-v0.1-filtered.cram.crai diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bef978f..1eea9f59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,35 @@ All notable changes to this project will be documented in this file. +## [Unreleased] + +### ๐Ÿš€ New subcommands + +- **`ft call-peaks`**: peak caller for FIRE pileup BED output. Identifies enriched regions, applies optional FIRE fiber-level filters, and emits per-peak statistics. +- **`ft mock-fire`**: generate mock FIRE-style data for testing and benchmarking pipelines. +- **`ft benchmark`**: lightweight benchmarking harness for measuring throughput on key subcommands. + +### ๐Ÿ”ง Improvements + +- **`ft pileup`**: accepts multiple regions in a single invocation; default `--frac-fibers` filter applied to drop low-coverage windows. +- **FIRE fiber-level filters hoisted to global `FiberFilters`**: `--fire-filter`, `--min-msp`, `--min-ave-msp-size`, and `--skip-no-m6a` are now available as global flags shared by `ft fire`, `ft pileup`, and `ft call-peaks`. `--fire-filter` is a convenience preset; individual flags override its defaults. +- **Train-FIRE Snakemake pipeline** added under `Train-FIRE/` for end-to-end FIRE model training. 
+- Build/dev-loop improvements: thin-LTO and a dedicated test profile for faster iteration. + +### ๐Ÿ› Bug fixes + +- **fa:Z extras pairing on reverse-strand reads with overlapping/shared-start peaks (follow-up to #103)**: when multiple BED annotations share a query start position on a reverse-strand read, the post-hoc `ann_vals.reverse()` step in `FiberAnnotations::from_bam_tags` did not produce a permutation equivalent to the stable sort over flipped coordinates, scrambling the fa:Z extras (e.g. peak names/lengths) relative to their fs/fl annotations on extraction. Fix pre-pairs extras with annotations before the flip+sort so the permutation is consistent. BAM encodings were always correct; the bug only affected extraction. Affects v0.8.0โ€“v0.8.2. + +### ๐Ÿงช Tests + +- `tests/call_peaks_test.rs`: end-to-end coverage of the new peak caller. +- `tests/fibertig_test.rs`: regression test for the fa:Z reverse-strand shared-start pairing bug, using Anna Minkina's exact failing input. +- Unit tests in `src/utils/input_bam.rs` covering `passes_fire_filter` semantics: skip-no-m6a behavior, min-msp, min-ave-msp-size, the `--fire-filter` combo, and explicit-flag overrides. + +### โšก Performance + +- Per-record FIRE feature extraction in `ft fire -f` is now parallelized (#104). 
+ ### [0.8.2] #### Bug Fixes diff --git a/Cargo.lock b/Cargo.lock index 7e8fc854..754d38db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,27 +158,6 @@ dependencies = [ "syn 2.0.104", ] -[[package]] -name = "argminmax" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70f13d10a41ac8d2ec79ee34178d61e6f47a29c2edfe7ef1721c7383b0359e65" -dependencies = [ - "num-traits", -] - -[[package]] -name = "array-init-cursor" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed51fe0f224d1d4ea768be38c51f9f831dee9d05c163c11fba0b8c44387b1fc3" - -[[package]] -name = "arrayref" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" - [[package]] name = "arrayvec" version = "0.7.6" @@ -222,54 +201,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.104", -] - -[[package]] -name = "async-trait" -version = "0.1.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.104", -] - -[[package]] -name = "atoi_simd" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2a49e05797ca52e312a0c658938b7d00693ef037799ef7187678f212d7684cf" -dependencies = [ - "debug_unsafe", -] 
- -[[package]] -name = "atomic-waker" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" - [[package]] name = "atomic_float" version = "1.1.0" @@ -406,20 +337,10 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" dependencies = [ - "bincode_derive", "serde", "unty", ] -[[package]] -name = "bincode_derive" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" -dependencies = [ - "virtue", -] - [[package]] name = "bio" version = "1.6.0" @@ -529,19 +450,6 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" -[[package]] -name = "blake3" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq 0.3.1", -] - [[package]] name = "block" version = "0.1.6" @@ -557,12 +465,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "boxcar" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36f64beae40a84da1b4b26ff2761a5b895c12adc41dc25aaee1c4f2bbfe97a6e" - [[package]] name = "bstr" version = "1.12.0" @@ -1022,9 +924,6 @@ name = "bytes" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" -dependencies = [ - "serde", -] [[package]] name = "bytesize" @@ -1154,28 +1053,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" -[[package]] -name = "chrono" -version = "0.4.42" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" -dependencies = [ - "iana-time-zone", - "num-traits", - "serde", - "windows-link 0.2.0", -] - -[[package]] -name = "chrono-tz" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" -dependencies = [ - "chrono", - "phf", -] - [[package]] name = "cipher" version = "0.4.4" @@ -1297,17 +1174,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "comfy-table" -version = "7.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b03b7db8e0b4b2fdad6c551e634134e99ec000e5c8c3b6856c65e8bbaded7a3b" -dependencies = [ - "crossterm 0.29.0", - "unicode-segmentation", - "unicode-width 0.2.0", -] - [[package]] name = "compact_str" version = "0.8.1" @@ -1322,21 +1188,6 @@ dependencies = [ "static_assertions", ] -[[package]] -name = "compact_str" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" -dependencies = [ - "castaway", - "cfg-if", - "itoa", - "rustversion", - "ryu", - "serde", - "static_assertions", -] - [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1365,12 +1216,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" -[[package]] -name = "constant_time_eq" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" - [[package]] name = "core-foundation" version = "0.9.4" @@ -1381,16 +1226,6 @@ dependencies = [ "libc", ] -[[package]] 
-name = "core-foundation" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1404,7 +1239,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" dependencies = [ "bitflags 1.3.2", - "core-foundation 0.9.4", + "core-foundation", "libc", ] @@ -1471,15 +1306,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "crossbeam-queue" -version = "0.3.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -1502,20 +1328,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "crossterm" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" -dependencies = [ - "bitflags 2.9.1", - "crossterm_winapi", - "document-features", - "parking_lot", - "rustix 1.0.8", - "winapi", -] - [[package]] name = "crossterm_winapi" version = "0.9.1" @@ -2020,12 +1832,6 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" -[[package]] -name = "debug_unsafe" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85d3cef41d236720ed453e102153a53e4cc3d2fde848c0078a50cf249e8e3e5b" - [[package]] name = "deranged" version = "0.4.0" @@ -2238,12 +2044,6 @@ dependencies = [ "shared_thread", ] -[[package]] -name = "dyn-clone" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" - [[package]] name = "dyn-stack" version = "0.10.0" @@ -2418,12 +2218,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "ethnum" -version = "1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca81e6b4777c89fd810c25a4be2b1bd93ea034fbe58e6a75216a34c6b82c539b" - [[package]] name = "event-listener" version = "5.4.0" @@ -2472,12 +2266,6 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" -[[package]] -name = "fast-float2" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55" - [[package]] name = "fastrand" version = "2.3.0" @@ -2501,7 +2289,7 @@ checksum = "835a3dc7d1ec9e75e2b5fb4ba75396837112d2060b03f7d43bc1897c7f7211da" [[package]] name = "fibertools-rs" -version = "0.8.2" +version = "0.9.0" dependencies = [ "anstyle", "anyhow", @@ -2528,13 +2316,11 @@ dependencies = [ "noodles", "num", "ordered-float 3.9.2", - "polars", "predicates", "rand 0.8.5", "rayon", "regex", "rust-htslib", - "rust-lapper", "serde", "serde_json", "serde_yaml", @@ -2647,31 +2433,6 @@ dependencies = [ "quick-error 1.2.3", ] -[[package]] -name = "fs4" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" -dependencies = [ - "rustix 1.0.8", - "windows-sys 0.59.0", -] - -[[package]] -name = "futures" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" 
version = "0.3.31" @@ -2679,7 +2440,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", - "futures-sink", ] [[package]] @@ -2688,17 +2448,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" -[[package]] -name = "futures-executor" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - [[package]] name = "futures-io" version = "0.3.31" @@ -2753,13 +2502,10 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ - "futures-channel", "futures-core", - "futures-io", "futures-macro", "futures-sink", "futures-task", - "memchr", "pin-project-lite", "pin-utils", "slab", @@ -3055,11 +2801,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", - "js-sys", "libc", "r-efi", "wasi 0.14.2+wasi-0.2.4", - "wasm-bindgen", ] [[package]] @@ -3278,25 +3022,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "h2" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" -dependencies = [ - "atomic-waker", - "bytes", - "fnv", - "futures-core", - "futures-sink", - "http", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "half" version = "2.6.0" @@ -3336,7 +3061,6 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash", - "rayon", "serde", ] @@ -3376,12 +3100,6 
@@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - [[package]] name = "hexf-parse" version = "0.2.1" @@ -3483,7 +3201,6 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2", "http", "http-body", "httparse", @@ -3492,24 +3209,6 @@ dependencies = [ "pin-project-lite", "smallvec", "tokio", - "want", -] - -[[package]] -name = "hyper-rustls" -version = "0.27.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" -dependencies = [ - "http", - "hyper", - "hyper-util", - "rustls", - "rustls-native-certs", - "rustls-pki-types", - "tokio", - "tokio-rustls", - "tower-service", ] [[package]] @@ -3518,46 +3217,14 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" dependencies = [ - "base64", "bytes", - "futures-channel", "futures-core", - "futures-util", "http", "http-body", "hyper", - "ipnet", - "libc", - "percent-encoding", "pin-project-lite", - "socket2", "tokio", "tower-service", - "tracing", -] - -[[package]] -name = "iana-time-zone" -version = "0.1.64" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "log", - "wasm-bindgen", - "windows-core 0.61.2", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", ] [[package]] @@ -3726,7 +3393,6 @@ checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", "hashbrown 0.15.4", - "serde", ] [[package]] @@ -3794,34 +3460,18 @@ dependencies = [ ] [[package]] -name = "ipnet" -version = "2.11.0" +name = "is_terminal_polyfill" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] -name = "iri-string" -version = "0.7.8" +name = "itertools" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" dependencies = [ - "memchr", - "serde", -] - -[[package]] -name = "is_terminal_polyfill" -version = "1.70.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" - -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", + "either", ] [[package]] @@ -4169,31 +3819,6 @@ dependencies = [ "hashbrown 0.15.4", ] -[[package]] -name = "lru-slab" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" - -[[package]] -name = "lz4" -version = "1.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" -dependencies = [ - "lz4-sys", -] - -[[package]] -name = "lz4-sys" 
-version = "1.11.1+lz4-1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "lzma-sys" version = "0.1.20" @@ -4374,7 +3999,7 @@ dependencies = [ "log", "num-traits", "once_cell", - "rustc-hash 1.1.0", + "rustc-hash", "spirv", "strum 0.26.3", "thiserror 2.0.12", @@ -4605,15 +4230,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" -[[package]] -name = "now" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" -dependencies = [ - "chrono", -] - [[package]] name = "ntapi" version = "0.4.1" @@ -4832,41 +4448,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "object_store" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efc4f07659e11cd45a341cd24d71e683e3be65d9ff1f8150061678fe60437496" -dependencies = [ - "async-trait", - "base64", - "bytes", - "chrono", - "form_urlencoded", - "futures", - "http", - "http-body-util", - "humantime", - "hyper", - "itertools 0.14.0", - "parking_lot", - "percent-encoding", - "quick-xml", - "rand 0.9.1", - "reqwest", - "ring", - "serde", - "serde_json", - "serde_urlencoded", - "thiserror 2.0.12", - "tokio", - "tracing", - "url", - "walkdir", - "wasm-bindgen-futures", - "web-time", -] - [[package]] name = "once_cell" version = "1.21.3" @@ -4895,12 +4476,6 @@ dependencies = [ "strum 0.27.1", ] -[[package]] -name = "openssl-probe" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" - [[package]] name = "openssl-src" version = "300.5.1+3.5.1" @@ -5037,24 +4612,6 @@ dependencies = [ "indexmap", ] 
-[[package]] -name = "phf" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" -dependencies = [ - "phf_shared", -] - -[[package]] -name = "phf_shared" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" -dependencies = [ - "siphasher", -] - [[package]] name = "pin-project" version = "1.1.10" @@ -5093,16 +4650,6 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" -[[package]] -name = "planus" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3daf8e3d4b712abe1d690838f6e29fb76b76ea19589c4afa39ec30e12f62af71" -dependencies = [ - "array-init-cursor", - "hashbrown 0.15.4", -] - [[package]] name = "png" version = "0.17.16" @@ -5116,513 +4663,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "polars" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5f7feb5d56b954e691dff22a8b2d78d77433dcc93c35fe21c3777fdc121b697" -dependencies = [ - "getrandom 0.2.16", - "getrandom 0.3.3", - "polars-arrow", - "polars-core", - "polars-error", - "polars-io", - "polars-lazy", - "polars-ops", - "polars-parquet", - "polars-sql", - "polars-time", - "polars-utils", - "version_check", -] - -[[package]] -name = "polars-arrow" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b4fed2343961b3eea3db2cee165540c3e1ad9d5782350cc55a9e76cf440148" -dependencies = [ - "atoi_simd", - "bitflags 2.9.1", - "bytemuck", - "chrono", - "chrono-tz", - "dyn-clone", - "either", - "ethnum", - "getrandom 0.2.16", - "getrandom 0.3.3", - "hashbrown 0.15.4", - "itoa", - "lz4", - "num-traits", - "polars-arrow-format", - "polars-error", 
- "polars-schema", - "polars-utils", - "serde", - "simdutf8", - "streaming-iterator", - "strum_macros 0.27.1", - "version_check", - "zstd 0.13.3", -] - -[[package]] -name = "polars-arrow-format" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a556ac0ee744e61e167f34c1eb0013ce740e0ee6cd8c158b2ec0b518f10e6675" -dependencies = [ - "planus", - "serde", -] - -[[package]] -name = "polars-compute" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "138785beda4e4a90a025219f09d0d15a671b2be9091513ede58e05db6ad4413f" -dependencies = [ - "atoi_simd", - "bytemuck", - "chrono", - "either", - "fast-float2", - "hashbrown 0.15.4", - "itoa", - "num-traits", - "polars-arrow", - "polars-error", - "polars-utils", - "rand 0.9.1", - "ryu", - "serde", - "skiplist", - "strength_reduce", - "strum_macros 0.27.1", - "version_check", -] - -[[package]] -name = "polars-core" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e77b1f08ef6dbb032bb1d0d3365464be950df9905f6827a95b24c4ca5518901d" -dependencies = [ - "bitflags 2.9.1", - "boxcar", - "bytemuck", - "chrono", - "chrono-tz", - "comfy-table", - "either", - "hashbrown 0.15.4", - "indexmap", - "itoa", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-dtype", - "polars-error", - "polars-row", - "polars-schema", - "polars-utils", - "rand 0.9.1", - "rand_distr 0.5.1", - "rayon", - "regex", - "serde", - "serde_json", - "strum_macros 0.27.1", - "uuid", - "version_check", - "xxhash-rust", -] - -[[package]] -name = "polars-dtype" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89c43d0ea57168be4546c4d8064479ed8b29a9c79c31a0c7c367ee734b9b7158" -dependencies = [ - "boxcar", - "hashbrown 0.15.4", - "polars-arrow", - "polars-error", - "polars-utils", - "serde", - "uuid", -] - -[[package]] -name = "polars-error" -version = "0.51.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9cb5d98f59f8b94673ee391840440ad9f0d2170afced95fc98aa86f895563c0" -dependencies = [ - "object_store", - "parking_lot", - "polars-arrow-format", - "regex", - "signal-hook", - "simdutf8", -] - -[[package]] -name = "polars-expr" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343931b818cf136349135ba11dbc18c27683b52c3477b1ba8ca606cf5ab1965c" -dependencies = [ - "bitflags 2.9.1", - "hashbrown 0.15.4", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-io", - "polars-ops", - "polars-plan", - "polars-row", - "polars-time", - "polars-utils", - "rand 0.9.1", - "rayon", - "recursive", -] - -[[package]] -name = "polars-io" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10388c64b8155122488229a881d1c6f4fdc393bc988e764ab51b182fcb2307e4" -dependencies = [ - "async-trait", - "atoi_simd", - "blake3", - "bytes", - "chrono", - "fast-float2", - "flate2", - "fs4", - "futures", - "glob", - "hashbrown 0.15.4", - "home", - "itoa", - "memchr", - "memmap2", - "num-traits", - "object_store", - "percent-encoding", - "polars-arrow", - "polars-core", - "polars-error", - "polars-parquet", - "polars-schema", - "polars-time", - "polars-utils", - "rayon", - "regex", - "reqwest", - "ryu", - "serde", - "serde_json", - "simdutf8", - "tokio", - "tokio-util", - "url", - "zstd 0.13.3", -] - -[[package]] -name = "polars-lazy" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fb6e2c6c2fa4ea0c660df1c06cf56960c81e7c2683877995bae3d4e3d408147" -dependencies = [ - "bitflags 2.9.1", - "chrono", - "either", - "memchr", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-expr", - "polars-io", - "polars-mem-engine", - "polars-ops", - "polars-plan", - "polars-stream", - "polars-time", - "polars-utils", - "rayon", - "version_check", -] - -[[package]] -name 
= "polars-mem-engine" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20a856e98e253587c28d8132a5e7e5a75cb2c44731ca090f1481d45f1d123771" -dependencies = [ - "memmap2", - "polars-arrow", - "polars-core", - "polars-error", - "polars-expr", - "polars-io", - "polars-ops", - "polars-plan", - "polars-time", - "polars-utils", - "rayon", - "recursive", -] - -[[package]] -name = "polars-ops" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf6062173fdc9ba05775548beb66e76643a148d9aeadc9984ed712bc4babd76" -dependencies = [ - "argminmax", - "base64", - "bytemuck", - "chrono", - "chrono-tz", - "either", - "hashbrown 0.15.4", - "hex", - "indexmap", - "libm", - "memchr", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-schema", - "polars-utils", - "rayon", - "regex", - "regex-syntax", - "strum_macros 0.27.1", - "unicode-normalization", - "unicode-reverse", - "version_check", -] - -[[package]] -name = "polars-parquet" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1d769180dec070df0dc4b89299b364bf2cfe32b218ecc4ddd8f1a49ae60669" -dependencies = [ - "async-stream", - "base64", - "bytemuck", - "ethnum", - "futures", - "hashbrown 0.15.4", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-error", - "polars-parquet-format", - "polars-utils", - "serde", - "simdutf8", - "streaming-decompression", -] - -[[package]] -name = "polars-parquet-format" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1" -dependencies = [ - "async-trait", - "futures", -] - -[[package]] -name = "polars-plan" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd3a2e33ae4484fe407ab2d2ba5684f0889d1ccf3ad6b844103c03638e6d0a0" -dependencies = [ - 
"bitflags 2.9.1", - "bytemuck", - "bytes", - "chrono", - "chrono-tz", - "either", - "hashbrown 0.15.4", - "memmap2", - "num-traits", - "percent-encoding", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-io", - "polars-ops", - "polars-time", - "polars-utils", - "rayon", - "recursive", - "regex", - "sha2", - "strum_macros 0.27.1", - "version_check", -] - -[[package]] -name = "polars-row" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18734f17e0e348724df3ae65f3ee744c681117c04b041cac969dfceb05edabc0" -dependencies = [ - "bitflags 2.9.1", - "bytemuck", - "polars-arrow", - "polars-compute", - "polars-dtype", - "polars-error", - "polars-utils", -] - -[[package]] -name = "polars-schema" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e6c1ab13e04d5167661a9854ed1ea0482b2ed9b8a0f1118dabed7cd994a85e3" -dependencies = [ - "indexmap", - "polars-error", - "polars-utils", - "serde", - "version_check", -] - -[[package]] -name = "polars-sql" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e7766da02cc1d464994404d3e88a7a0ccd4933df3627c325480fbd9bbc0a11" -dependencies = [ - "bitflags 2.9.1", - "hex", - "polars-core", - "polars-error", - "polars-lazy", - "polars-ops", - "polars-plan", - "polars-time", - "polars-utils", - "rand 0.9.1", - "regex", - "serde", - "sqlparser", -] - -[[package]] -name = "polars-stream" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31f6c6ca1ea01f9dea424d167e4f33f5ec44cd67fbfac9efd40575ed20521f14" -dependencies = [ - "async-channel", - "async-trait", - "atomic-waker", - "bitflags 2.9.1", - "crossbeam-channel", - "crossbeam-deque", - "crossbeam-queue", - "crossbeam-utils", - "futures", - "memmap2", - "parking_lot", - "percent-encoding", - "pin-project-lite", - "polars-arrow", - "polars-core", - "polars-error", - 
"polars-expr", - "polars-io", - "polars-mem-engine", - "polars-ops", - "polars-parquet", - "polars-plan", - "polars-utils", - "rand 0.9.1", - "rayon", - "recursive", - "slotmap", - "tokio", - "tokio-util", - "version_check", -] - -[[package]] -name = "polars-time" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6a3a6e279a7a984a0b83715660f9e880590c6129ec2104396bfa710bcd76dee" -dependencies = [ - "atoi_simd", - "bytemuck", - "chrono", - "chrono-tz", - "now", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-ops", - "polars-utils", - "rayon", - "regex", - "strum_macros 0.27.1", -] - -[[package]] -name = "polars-utils" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57b267021b0e5422d7fbc70fd79e51b9f9a8466c585779373a18b0199e973f29" -dependencies = [ - "bincode", - "bytemuck", - "bytes", - "compact_str 0.9.0", - "either", - "flate2", - "foldhash", - "hashbrown 0.15.4", - "indexmap", - "libc", - "memmap2", - "num-traits", - "polars-error", - "rand 0.9.1", - "raw-cpuid 11.5.0", - "rayon", - "regex", - "rmp-serde", - "serde", - "serde_json", - "serde_stacker", - "slotmap", - "stacker", - "uuid", - "version_check", -] - [[package]] name = "portable-atomic" version = "1.11.1" @@ -5800,15 +4840,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "psm" -version = "0.1.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" -dependencies = [ - "cc", -] - [[package]] name = "pulp" version = "0.18.22" @@ -5839,87 +4870,22 @@ dependencies = [ name = "qoi" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" -dependencies = [ - "bytemuck", -] - -[[package]] -name = "quick-error" -version = "1.2.3" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" - -[[package]] -name = "quick-error" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" - -[[package]] -name = "quick-xml" -version = "0.38.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" -dependencies = [ - "memchr", - "serde", -] - -[[package]] -name = "quinn" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" -dependencies = [ - "bytes", - "cfg_aliases", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash 2.1.1", - "rustls", - "socket2", - "thiserror 2.0.12", - "tokio", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-proto" -version = "0.11.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" -dependencies = [ - "bytes", - "getrandom 0.3.3", - "lru-slab", - "rand 0.9.1", - "ring", - "rustc-hash 2.1.1", - "rustls", - "rustls-pki-types", - "slab", - "thiserror 2.0.12", - "tinyvec", - "tracing", - "web-time", +checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" +dependencies = [ + "bytemuck", ] [[package]] -name = "quinn-udp" -version = "0.5.14" +name = "quick-error" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" -dependencies = [ - "cfg_aliases", - "libc", - "once_cell", - "socket2", - "tracing", - "windows-sys 0.60.2", -] +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + +[[package]] +name = "quick-error" +version = 
"2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quote" @@ -6051,8 +5017,8 @@ checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b" dependencies = [ "bitflags 2.9.1", "cassowary", - "compact_str 0.8.1", - "crossterm 0.28.1", + "compact_str", + "crossterm", "indoc", "instability", "itertools 0.13.0", @@ -6171,26 +5137,6 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" -[[package]] -name = "recursive" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" -dependencies = [ - "recursive-proc-macro-impl", - "stacker", -] - -[[package]] -name = "recursive-proc-macro-impl" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" -dependencies = [ - "quote", - "syn 2.0.104", -] - [[package]] name = "redox_syscall" version = "0.5.13" @@ -6252,48 +5198,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" -[[package]] -name = "reqwest" -version = "0.12.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" -dependencies = [ - "base64", - "bytes", - "futures-core", - "futures-util", - "h2", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-rustls", - "hyper-util", - "js-sys", - "log", - "percent-encoding", - "pin-project-lite", - "quinn", - "rustls", - "rustls-native-certs", - "rustls-pki-types", - "serde", - "serde_json", - "serde_urlencoded", - "sync_wrapper", - "tokio", - 
"tokio-rustls", - "tokio-util", - "tower", - "tower-http", - "tower-service", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "wasm-streams", - "web-sys", -] - [[package]] name = "rgb" version = "0.8.52" @@ -6418,15 +5322,6 @@ dependencies = [ "url", ] -[[package]] -name = "rust-lapper" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2274b9cc4f205bc0945b7be3e4fc1a102b0b7119ba6482faaedb9c4f76dde5d1" -dependencies = [ - "num-traits", -] - [[package]] name = "rustc-demangle" version = "0.1.25" @@ -6439,12 +5334,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" -[[package]] -name = "rustc-hash" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" - [[package]] name = "rustc_version" version = "0.1.7" @@ -6504,25 +5393,12 @@ dependencies = [ "zeroize", ] -[[package]] -name = "rustls-native-certs" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" -dependencies = [ - "openssl-probe", - "rustls-pki-types", - "schannel", - "security-framework", -] - [[package]] name = "rustls-pki-types" version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ - "web-time", "zeroize", ] @@ -6596,15 +5472,6 @@ dependencies = [ "regex", ] -[[package]] -name = "schannel" -version = "0.1.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" -dependencies = [ - "windows-sys 0.61.0", -] - [[package]] name = "scheduled-thread-pool" version = "0.2.7" @@ -6620,29 +5487,6 @@ version = "1.2.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "security-framework" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b369d18893388b345804dc0007963c99b7d665ae71d275812d828c6f089640" -dependencies = [ - "bitflags 2.9.1", - "core-foundation 0.10.1", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "0.1.20" @@ -6753,17 +5597,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_stacker" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4936375d50c4be7eff22293a9344f8e46f323ed2b3c243e52f89138d9bb0f4a" -dependencies = [ - "serde", - "serde_core", - "stacker", -] - [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -6912,28 +5745,6 @@ dependencies = [ "quote", ] -[[package]] -name = "simdutf8" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" - -[[package]] -name = "siphasher" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" - -[[package]] -name = "skiplist" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f354fd282d3177c2951004953e2fdc4cb342fa159bbee8b829852b6a081c8ea1" -dependencies = [ - "rand 0.9.1", - "thiserror 2.0.12", -] - [[package]] name = "slab" version = "0.4.10" @@ -6993,34 +5804,12 @@ dependencies = [ "bitflags 2.9.1", ] -[[package]] 
-name = "sqlparser" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" -dependencies = [ - "log", -] - [[package]] name = "stable_deref_trait" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "stacker" -version = "0.1.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" -dependencies = [ - "cc", - "cfg-if", - "libc", - "psm", - "windows-sys 0.59.0", -] - [[package]] name = "static_assertions" version = "1.1.0" @@ -7040,27 +5829,6 @@ dependencies = [ "rand 0.8.5", ] -[[package]] -name = "streaming-decompression" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" -dependencies = [ - "fallible-streaming-iterator", -] - -[[package]] -name = "streaming-iterator" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" - -[[package]] -name = "strength_reduce" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" - [[package]] name = "strsim" version = "0.10.0" @@ -7169,9 +5937,6 @@ name = "sync_wrapper" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" -dependencies = [ - "futures-core", -] [[package]] name = "synstructure" @@ -7494,16 +6259,6 @@ dependencies = [ "syn 2.0.104", ] -[[package]] -name = "tokio-rustls" -version = "0.26.3" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "05f63835928ca123f1bef57abbcd23bb2ba0ac9ae1235f1e65bda0d06e7786bd" -dependencies = [ - "rustls", - "tokio", -] - [[package]] name = "tokio-tungstenite" version = "0.26.2" @@ -7528,20 +6283,6 @@ dependencies = [ "tungstenite 0.27.0", ] -[[package]] -name = "tokio-util" -version = "0.7.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "futures-util", - "pin-project-lite", - "tokio", -] - [[package]] name = "toml" version = "0.8.23" @@ -7646,24 +6387,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "tower-http" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" -dependencies = [ - "bitflags 2.9.1", - "bytes", - "futures-util", - "http", - "http-body", - "iri-string", - "pin-project-lite", - "tower", - "tower-layer", - "tower-service", -] - [[package]] name = "tower-layer" version = "0.3.3" @@ -7752,12 +6475,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22048bc95dfb2ffd05b1ff9a756290a009224b60b2f0e7525faeee7603851e63" -[[package]] -name = "try-lock" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" - [[package]] name = "tungstenite" version = "0.26.2" @@ -7834,15 +6551,6 @@ dependencies = [ "tinyvec", ] -[[package]] -name = "unicode-reverse" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -7952,7 +6660,6 @@ dependencies = [ "getrandom 0.3.3", "js-sys", "rand 0.9.1", - "serde", 
"wasm-bindgen", ] @@ -8054,12 +6761,6 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" -[[package]] -name = "virtue" -version = "0.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" - [[package]] name = "void" version = "1.0.2" @@ -8085,15 +6786,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "want" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" -dependencies = [ - "try-lock", -] - [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -8180,19 +6872,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "wasm-streams" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "web-sys" version = "0.3.77" @@ -8286,7 +6965,7 @@ dependencies = [ "portable-atomic", "profiling", "raw-window-handle", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", "thiserror 2.0.12", "wgpu-core-deps-apple", @@ -8466,7 +7145,7 @@ dependencies = [ "windows-collections", "windows-core 0.61.2", "windows-future", - "windows-link 0.1.3", + "windows-link", "windows-numerics", ] @@ -8512,7 +7191,7 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement 0.60.0", "windows-interface 0.59.1", - "windows-link 0.1.3", + "windows-link", "windows-result 0.3.4", "windows-strings 0.4.2", ] @@ -8524,7 +7203,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" 
dependencies = [ "windows-core 0.61.2", - "windows-link 0.1.3", + "windows-link", "windows-threading", ] @@ -8600,12 +7279,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" -[[package]] -name = "windows-link" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" - [[package]] name = "windows-numerics" version = "0.2.0" @@ -8613,7 +7286,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" dependencies = [ "windows-core 0.61.2", - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -8640,7 +7313,7 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -8659,7 +7332,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -8689,15 +7362,6 @@ dependencies = [ "windows-targets 0.53.2", ] -[[package]] -name = "windows-sys" -version = "0.61.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e201184e40b2ede64bc2ea34968b28e33622acdbbf37104f0e4a33f7abe657aa" -dependencies = [ - "windows-link 0.2.0", -] - [[package]] name = "windows-targets" version = "0.52.6" @@ -8736,7 +7400,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -8877,12 +7541,6 @@ version = "0.8.27" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "6fd8403733700263c6eb89f192880191f1b83e332f7a20371ddcf421c4a337c7" -[[package]] -name = "xxhash-rust" -version = "0.8.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" - [[package]] name = "yoke" version = "0.7.5" @@ -9020,7 +7678,7 @@ dependencies = [ "aes", "byteorder", "bzip2", - "constant_time_eq 0.1.5", + "constant_time_eq", "crc32fast", "crossbeam-utils", "flate2", @@ -9028,7 +7686,7 @@ dependencies = [ "pbkdf2", "sha1", "time", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] @@ -9058,16 +7716,7 @@ version = "0.11.2+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" dependencies = [ - "zstd-safe 5.0.2+zstd.1.5.2", -] - -[[package]] -name = "zstd" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" -dependencies = [ - "zstd-safe 7.2.4", + "zstd-safe", ] [[package]] @@ -9080,15 +7729,6 @@ dependencies = [ "zstd-sys", ] -[[package]] -name = "zstd-safe" -version = "7.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" -dependencies = [ - "zstd-sys", -] - [[package]] name = "zstd-sys" version = "2.0.15+zstd.1.5.7" diff --git a/Cargo.toml b/Cargo.toml index e978181d..a37d481d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,9 @@ license = "MIT" name = "fibertools-rs" readme = "README.md" repository = "https://github.com/fiberseq/fibertools-rs" -version = "0.8.2" +version = "0.9.0" # exclude py-ft and test data from cargo publish since they are too large -exclude = ["py-ft/", "tests/data/"] +exclude = ["py-ft/", "tests/data/", "Train-FIRE/"] [workspace] exclude = ["py-ft"] @@ -50,7 
+50,6 @@ rayon = "1.10" linear-map = "1.2.0" regex = "1.9.1" rust-htslib = "0.46" -rust-lapper = "1.2.0" serde = { version = "1.0.104", features = ["derive"], optional = false } @@ -69,7 +68,6 @@ burn = { version = "0.18.0", optional = true, features = [ num = "0.4.3" rand = "0.8.5" noodles = { version = "0.100.0", features = ["fasta", "bed"] } -polars = { version = "0.51.0", features = ["csv", "decompress", "lazy"] } [build-dependencies] burn-import = { version = "0.18.0", default-features = false, features = [ @@ -98,6 +96,8 @@ burn = ["dep:burn"] [profile.dist] inherits = "release" split-debuginfo = "packed" +lto = true +codegen-units = 1 # generated by 'cargo wizard' [profile.dev] @@ -114,7 +114,13 @@ incremental = true [profile.release] strip = "debuginfo" -codegen-units = 1 +codegen-units = 16 debug = false -lto = true +lto = "thin" panic = "abort" + +[profile.profiling] +inherits = "release" +debug = true +codegen-units = 16 +lto = "thin" diff --git a/Train-FIRE/.tests/config.yaml b/Train-FIRE/.tests/config.yaml new file mode 100644 index 00000000..23274508 --- /dev/null +++ b/Train-FIRE/.tests/config.yaml @@ -0,0 +1,85 @@ +# Tiny smoke-test config. Trains a toy XGBoost model end-to-end from a +# 1.3MB BAM (~154 chr19 reads with m6A tags) and a synthetic fai so the +# whole pipeline โ€” regions, features, training, comparison โ€” runs in +# well under a minute with no network access. + +reference: + fai: .tests/data/ref.fa.fai + fasta: .tests/data/ref.fa + genome: hg38 + +# Local BAM โ€” no s3 streaming. +training_bam: .tests/data/sample.bam +s3_endpoint: null +# `-s 0.99` keeps ~99% of reads; -s 1.0 would be parsed as seed=1, frac=0. +sample_rate: 0.99 + +# test_region held out of training. Not present in the test BAM, so the +# apply/trackhub steps won't find reads and should be skipped via train_only. 
+test_region: chr1 +test_bam: null + +n_sites: 10 +merge_distance: 147 + +exclude_beds: + - .tests/data/exclude.bed.gz + +# No alt/random/MT/sex chroms exist in the test fai, so the regex can be +# anything that doesn't match 'chr19'. +exclude_chroms_regex: "^chrEBV$" + +train_defaults: + # Lower the MSP length cutoff so the tiny NAPA read set yields enough + # rows for mokapot's qvality spline (needs >=3 unique score bins). + min_msp_length_for_positive_fire_call: 25 + min_msp_length_for_negative_fire_call: 25 + # Smoke test has only ~100 PSMs; mokapot can't find any at 0.05 FDR. Use + # loose thresholds to exercise the pipeline end-to-end, not train a real model. + train_fdr: 0.5 + test_fdr: 0.5 + subset_max_train: 5000 + direction: msp_len_times_m6a_fc + grid_search: false + # A little depth + a few dozen trees is needed for mokapot's confidence + # spline to see >=3 unique score bins on the tiny smoke dataset. + n_estimators: [50] + max_depth: [6] + min_child_weight_fracs: [0.01] + colsample_bytree: [1.0] + gamma: [1] + learning_rate: [0.3] + early_stopping_rounds: 0 + early_stopping_val_frac: 0.15 + balance_train: true + mokapot_max_iter: 3 + +trackhub: + name: test-hub + short_label: "test" + long_label: "smoke test" + email: test@example.com + +region_sets: + smoke: + positive_beds: + - path: .tests/data/positives.bed.gz + negative_exclude_beds: [] + +experiments: + smoke_fixed: + region_set: smoke + + # Second experiment: same shape as smoke_fixed but with early-stopping + + # mokapot_max_iter knobs on. Keeps balance_train=true so pos/neg signal + # isn't perfectly separable and mokapot's triqler spline has enough + # unique score bins to converge. 
+ smoke_es: + region_set: smoke + train: + n_estimators: [100] + max_depth: [6] + learning_rate: [0.2] + early_stopping_rounds: 20 + early_stopping_val_frac: 0.15 + mokapot_max_iter: 5 diff --git a/Train-FIRE/.tests/data/exclude.bed.gz b/Train-FIRE/.tests/data/exclude.bed.gz new file mode 100644 index 00000000..1c78e9cb Binary files /dev/null and b/Train-FIRE/.tests/data/exclude.bed.gz differ diff --git a/Train-FIRE/.tests/data/positives.bed.gz b/Train-FIRE/.tests/data/positives.bed.gz new file mode 100644 index 00000000..20bc8c5a Binary files /dev/null and b/Train-FIRE/.tests/data/positives.bed.gz differ diff --git a/Train-FIRE/.tests/data/ref.fa b/Train-FIRE/.tests/data/ref.fa new file mode 100644 index 00000000..81a4fd77 --- /dev/null +++ b/Train-FIRE/.tests/data/ref.fa @@ -0,0 +1,2 @@ +>chr19 +N diff --git a/Train-FIRE/.tests/data/ref.fa.fai b/Train-FIRE/.tests/data/ref.fa.fai new file mode 100644 index 00000000..ce2dc9df --- /dev/null +++ b/Train-FIRE/.tests/data/ref.fa.fai @@ -0,0 +1 @@ +chr19 58617616 6 60 61 diff --git a/Train-FIRE/.tests/data/sample.bam b/Train-FIRE/.tests/data/sample.bam new file mode 100644 index 00000000..5bcfd1f8 Binary files /dev/null and b/Train-FIRE/.tests/data/sample.bam differ diff --git a/Train-FIRE/.tests/data/sample.bam.bai b/Train-FIRE/.tests/data/sample.bam.bai new file mode 100644 index 00000000..06b055a9 Binary files /dev/null and b/Train-FIRE/.tests/data/sample.bam.bai differ diff --git a/Train-FIRE/README.md b/Train-FIRE/README.md new file mode 100644 index 00000000..69bc295a --- /dev/null +++ b/Train-FIRE/README.md @@ -0,0 +1,74 @@ +# Train-FIRE โ€” model training sweep + +Snakemake workflow (following [SmkTemplate](https://github.com/mrvollger/SmkTemplate)) +for training multiple FIRE models under different training configurations and +visualizing them together in a single UCSC track hub. 
+ +## Layout + +``` +config/config.yaml shared inputs + region_sets + experiment list +workflow/Snakefile entrypoint +workflow/rules/*.smk region building, features, training, hub +workflow/scripts/*.py training + aggregation + hub assembly +workflow/envs/env.yml conda env for every rule +workflow/profiles/default/*.yaml snakemake profile +resources/mixed-positives/ input peak BEDs +results/shared/ shared sampled BAM + feature table +results/region_sets// positives/negatives/training-data (shared across experiments) +results/experiments// per-experiment trained models + FIRE calls +results/trackhub/ UCSC track hub (hub.txt / trackDb.txt / bb/) +results/comparison/ FDR overlay plot + metrics TSV +``` + +## Quickstart + +```bash +# from Train-FIRE/ +pixi install +pixi run test # dry run +pixi run snakemake -j 32 --use-conda # full run +pixi run snakemake -j 32 train_only # train+compare, skip trackhub +``` + +## Adding an experiment + +Experiments reference a named **region_set** (positives + negative-exclude +beds) defined under `region_sets:`. Every experiment sharing a region_set +shares the regions/features/training-data pipeline โ€” only the XGBoost +training step runs per-experiment. + +```yaml +region_sets: + my_regions: + positive_beds: + - path: resources/mixed-positives/ATAC.bed.gz + awk_filter: "$5 >= 1500" + - path: resources/mixed-positives/peaks_CTCF_ENCFF951PEM.bed.gz + negative_exclude_beds: + - path: resources/mixed-positives/GM12878-fire-v0.1-peaks.bed.gz + +experiments: + my_new_model: + region_set: my_regions + train: + n_estimators: [300] + max_depth: [12] +``` + +Feature extraction is shared: `ft fire -f` runs once over the union of every +region_set's positive + negative regions, and each region_set labels that +shared table via `bedtools intersect`. 
+ +## Test region + +`test_region` in the config (default `chr19`) is held out from training +(injected into the exclusion bed automatically) and used as the sole region +for applying every trained model + trackhub visualization. + +## Track hub + +`results/trackhub/` is a standalone UCSC Track Hub. `rsync` it to a +web-accessible location and point the UCSC Genome Browser at `/hub.txt`. +One bigBed9 FIRE-element track per experiment, visibility `squish`, colored +per model. diff --git a/Train-FIRE/config/config.yaml b/Train-FIRE/config/config.yaml new file mode 100644 index 00000000..f3642c9b --- /dev/null +++ b/Train-FIRE/config/config.yaml @@ -0,0 +1,214 @@ +# Shared inputs for FIRE model training sweep +# Paths are resolved relative to the Train-FIRE/ directory unless absolute. + +reference: + # Auto-fetched by rule fetch_reference from the UCSC goldenPath server. + # Override with a local path (and drop the rule's output) if you already have it. + fai: resources/ref/hg38.analysisSet.fa.fai + fasta: resources/ref/hg38.analysisSet.fa + genome: hg38 + +# Input CRAM/BAM used for training feature extraction. +# May be a local path OR an s3://... URL. For s3, set `s3_endpoint` below +# and the pipeline will download the file once to results/shared/input.cram. +training_bam: s3://stergachis/public/FIRE/broadly-consented/GM12878/GM12878-fire-v0.1-filtered.cram +s3_endpoint: https://s3.kopah.uw.edu +sample_rate: 0.10 + +# Held-out region for model application + trackhub visualization. +# This chromosome is ALSO injected into the exclusion bed so no training +# data is drawn from it. +test_region: chr19 +# Optional separate BAM/CRAM for visualization (defaults to training_bam) +test_bam: null + +# Site budget and peak merging +n_sites: 500000 +# Draw this many length-matched negatives per positive. >1 gives mokapot's +# FDR/q-value estimator and triqler's spline a richer decoy pool. Training +# itself still sees balance_train:true downsampling if enabled. 
+neg_multiplier: 10 +merge_distance: 147 + +# Regions always excluded from training (gaps, SDs, etc.) +exclude_beds: + - resources/exclude/hg38.gap.bed.gz + - resources/exclude/SDs.merged.hg38.bed.gz + +# Chromosome name regex used to drop alt/random/MT/sex chroms from training +exclude_chroms_regex: "_|chrEBV|chrMT|chrX|chrY|chrM" + +# Defaults applied to every experiment unless overridden in its own `train:` block +train_defaults: + min_msp_length_for_positive_fire_call: 85 + min_msp_length_for_negative_fire_call: 85 + train_fdr: 0.05 + test_fdr: 0.05 + subset_max_train: 10000000 + direction: msp_len_times_m6a_fc + grid_search: true + n_estimators: [200, 300] + max_depth: [9, 15] + min_child_weight_fracs: [0.001, 0.005] + colsample_bytree: [0.5, 1.0] + gamma: [1] + # Lower eta + more trees + early stopping generally improves model quality. + # Defaults preserve legacy behavior (eta=0.3, no early stopping, balanced train, + # mokapot max_iter=15). Override per-experiment to opt in. + learning_rate: [0.3] + early_stopping_rounds: 0 # >0 enables early stopping against a held-out val split + early_stopping_val_frac: 0.15 # fraction of train held out for the ES val set + balance_train: true # downsample majority class in train; set false to let scale_pos_weight handle imbalance + mokapot_max_iter: 15 # number of mokapot iterative refit rounds + mokapot_override: false # if true, downgrade "model performs worse than direction" from error to warning + +# Trackhub metadata +trackhub: + name: fire-models + short_label: "FIRE model sweep" + long_label: "FIRE model sweep across training configurations" + email: mrvollger@genetics.utah.edu + +# ------------------------------------------------------------------ +# Region sets. A `region_set` is the (positives, negative_exclude_beds) +# input pair. Every experiment references one by name, and the whole +# regions + training-data pipeline runs ONCE per region_set (not per +# experiment) so we don't redo identical bedtools work. 
+# +# positive_beds: used for +1 labels +# negative_exclude_beds: NOT used as positives, but their regions are +# also excluded when sampling negatives so we +# don't place negatives on top of likely signal +# ------------------------------------------------------------------ +region_sets: + default: + positive_beds: + - path: resources/mixed-positives/GM12878-fire-v0.1-peaks-20-percent.bed.gz + - path: resources/mixed-positives/GM12878_DHS.bed.gz + - path: resources/mixed-positives/hotspots_ENCFF452DZE.bed.gz + - path: resources/mixed-positives/hotspots_ENCFF828AUX.bed.gz + - path: resources/mixed-positives/peaks_CTCF_ENCFF951PEM.bed.gz + - path: resources/mixed-positives/peaks_ENCFF073ORT.bed.gz + - path: resources/mixed-positives/peaks_ENCFF598KWZ.bed.gz + negative_exclude_beds: + - path: resources/mixed-positives/ATAC.bed.gz + - path: resources/mixed-positives/GM12878-fire-v0.1-peaks.bed.gz + + +# ------------------------------------------------------------------ +# Experiments. Each is a distinct training config. +# `region_set:` names an entry above. Experiments that share a +# region_set share positives/negatives/features/training-data. +# ------------------------------------------------------------------ +experiments: + + # Baseline: train_defaults (moderate grid search over n_estimators/max_depth/etc.) + default_grid: + region_set: default + + default_grid_unbalanced: + region_set: default + train: + balance_train: false + train_fdr: 0.20 + + # Fast, non-CV baseline: a single shallow XGBoost model with no grid search. + # Useful reference for how much the grid search is buying. This config is + # weak enough that XGBoost may not beat the `direction` feature alone, so + # we allow mokapot's "model performs worse" check to be a warning here. 
+ shallow_fixed: + region_set: default + train: + grid_search: false + n_estimators: [100] + max_depth: [6] + min_child_weight_fracs: [0.01] + colsample_bytree: [1.0] + gamma: [1] + mokapot_override: true + + # Higher-capacity grid: deeper trees + more boosters. + # Risk is overfitting; compare to default_grid to see if capacity helps. + deep_wide_grid: + region_set: default + train: + n_estimators: [400, 600] + max_depth: [15, 21] + min_child_weight_fracs: [0.001, 0.005] + colsample_bytree: [0.5, 1.0] + gamma: [1] + + deep_wide_grid_unbalanced: + region_set: default + train: + n_estimators: [400, 600] + max_depth: [15, 21] + min_child_weight_fracs: [0.001, 0.005] + colsample_bytree: [0.5, 1.0] + gamma: [1] + balance_train: false + train_fdr: 0.20 + + # Strong-regularization grid: higher gamma + bigger leaf hessian, + # smaller column sample. Tests the overfitting hypothesis. + strong_reg_grid: + region_set: default + train: + n_estimators: [200, 300] + max_depth: [9] + min_child_weight_fracs: [0.005, 0.02] + colsample_bytree: [0.5, 0.75] + gamma: [5, 10] + + strong_reg_grid_unbalanced: + region_set: default + train: + n_estimators: [200, 300] + max_depth: [9] + min_child_weight_fracs: [0.005, 0.02] + colsample_bytree: [0.5, 0.75] + gamma: [5, 10] + balance_train: false + train_fdr: 0.20 + + # Slow-learn grid: low eta with a large n_estimators cap, early stopping + # picks the actual tree count per fold. 
+ slow_learn_es: + region_set: default + train: + n_estimators: [2000] + max_depth: [9, 15] + min_child_weight_fracs: [0.001, 0.005] + colsample_bytree: [0.5, 1.0] + gamma: [1] + learning_rate: [0.05, 0.1] + early_stopping_rounds: 50 + + slow_learn_es_unbalanced: + region_set: default + train: + n_estimators: [2000] + max_depth: [9, 15] + min_child_weight_fracs: [0.001, 0.005] + colsample_bytree: [0.5, 1.0] + gamma: [1] + learning_rate: [0.05, 0.1] + early_stopping_rounds: 50 + balance_train: false + train_fdr: 0.20 + + # Keep all training data (no balance downsample) and give mokapot more + # refit iterations. scale_pos_weight handles class imbalance inside XGBoost. + full_data_long_mokapot: + region_set: default + train: + n_estimators: [2000] + max_depth: [9, 15] + min_child_weight_fracs: [0.001, 0.005] + colsample_bytree: [0.5, 1.0] + gamma: [1] + learning_rate: [0.05] + early_stopping_rounds: 50 + balance_train: false + mokapot_max_iter: 30 + train_fdr: 0.20 diff --git a/Train-FIRE/pixi.toml b/Train-FIRE/pixi.toml new file mode 100644 index 00000000..3af8541c --- /dev/null +++ b/Train-FIRE/pixi.toml @@ -0,0 +1,39 @@ +[workspace] +authors = ["Mitchell R. Vollger "] +channels = ["conda-forge", "bioconda"] +description = "FIRE model training sweep + trackhub" +name = "Train-FIRE" +platforms = ["osx-64", "linux-64"] +version = "0.1.0" + +[tasks] +fmt = "ruff format . && taplo format pixi.toml && snakefmt workflow/" +dry-run = { cmd = ["snakemake", "--configfile", "config/config.yaml", "-n"] } +# End-to-end smoke test on the tiny .tests/ dataset: ~150 chr19 reads at +# the NAPA locus (positives) + ~250 reads from a quiet chr19 region at +# ~chr19:9.09-9.11Mb (negatives, no overlap with any peak/gap/SD). +# Runs regions -> features -> training-data -> xgboost+mokapot end-to-end. +# First invocation materializes the conda env. 
+test = { cmd = [ + "snakemake", + "--configfile", + ".tests/config.yaml", + "-j", + "4", + "--use-conda", + "--rerun-incomplete", + "train_only", +] } +test-clean = "rm -rf results .snakemake" +# Primary production entrypoint (also lets users `cd` into the project root). +snakemake = { cmd = "cd $INIT_CWD && snakemake --configfile $PIXI_PROJECT_ROOT/config/config.yaml -s $PIXI_PROJECT_ROOT/workflow/Snakefile" } + + +[dependencies] +conda = "*" +snakemake = "==9.17.1" +snakefmt = "*" +ruff = "*" +taplo = "*" +snakemake-executor-plugin-slurm = "*" +snakemake-logger-plugin-snkmt = "*" diff --git a/Train-FIRE/resources/exclude/SDs.merged.hg38.bed.gz b/Train-FIRE/resources/exclude/SDs.merged.hg38.bed.gz new file mode 100644 index 00000000..a6f8fd32 Binary files /dev/null and b/Train-FIRE/resources/exclude/SDs.merged.hg38.bed.gz differ diff --git a/Train-FIRE/resources/exclude/hg38.gap.bed.gz b/Train-FIRE/resources/exclude/hg38.gap.bed.gz new file mode 100644 index 00000000..b05ad873 Binary files /dev/null and b/Train-FIRE/resources/exclude/hg38.gap.bed.gz differ diff --git a/Train-FIRE/resources/mixed-positives/ATAC.bed.gz b/Train-FIRE/resources/mixed-positives/ATAC.bed.gz new file mode 100644 index 00000000..f9c66a17 Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/ATAC.bed.gz differ diff --git a/Train-FIRE/resources/mixed-positives/GM12878-fire-v0.1-peaks-20-percent.bed.gz b/Train-FIRE/resources/mixed-positives/GM12878-fire-v0.1-peaks-20-percent.bed.gz new file mode 100644 index 00000000..905cfd92 Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/GM12878-fire-v0.1-peaks-20-percent.bed.gz differ diff --git a/Train-FIRE/resources/mixed-positives/GM12878-fire-v0.1-peaks.bed.gz b/Train-FIRE/resources/mixed-positives/GM12878-fire-v0.1-peaks.bed.gz new file mode 100644 index 00000000..2618268f Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/GM12878-fire-v0.1-peaks.bed.gz differ diff --git 
a/Train-FIRE/resources/mixed-positives/GM12878_DHS.bed.gz b/Train-FIRE/resources/mixed-positives/GM12878_DHS.bed.gz new file mode 100644 index 00000000..df9f7e93 Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/GM12878_DHS.bed.gz differ diff --git a/Train-FIRE/resources/mixed-positives/download.sh b/Train-FIRE/resources/mixed-positives/download.sh new file mode 100755 index 00000000..7f8a78d8 --- /dev/null +++ b/Train-FIRE/resources/mixed-positives/download.sh @@ -0,0 +1,9 @@ +#GM12878 hotspots: +wget https://www.encodeproject.org/files/ENCFF828AUX/@@download/ENCFF828AUX.bed.gz -O hotspots_ENCFF828AUX.bed.gz +wget https://www.encodeproject.org/files/ENCFF452DZE/@@download/ENCFF452DZE.bed.gz -O hotspots_ENCFF452DZE.bed.gz +#GM12878 peaks: +wget https://www.encodeproject.org/files/ENCFF598KWZ/@@download/ENCFF598KWZ.bed.gz -O peaks_ENCFF598KWZ.bed.gz +wget https://www.encodeproject.org/files/ENCFF073ORT/@@download/ENCFF073ORT.bed.gz -O peaks_ENCFF073ORT.bed.gz +# CTCF peaks +wget https://www.encodeproject.org/files/ENCFF951PEM/@@download/ENCFF951PEM.bed.gz -O peaks_CTCF_ENCFF951PEM.bed.gz + diff --git a/Train-FIRE/resources/mixed-positives/hotspots_ENCFF452DZE.bed.gz b/Train-FIRE/resources/mixed-positives/hotspots_ENCFF452DZE.bed.gz new file mode 100644 index 00000000..c107b5de Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/hotspots_ENCFF452DZE.bed.gz differ diff --git a/Train-FIRE/resources/mixed-positives/hotspots_ENCFF828AUX.bed.gz b/Train-FIRE/resources/mixed-positives/hotspots_ENCFF828AUX.bed.gz new file mode 100644 index 00000000..74cc1bda Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/hotspots_ENCFF828AUX.bed.gz differ diff --git a/Train-FIRE/resources/mixed-positives/peaks_CTCF_ENCFF951PEM.bed.gz b/Train-FIRE/resources/mixed-positives/peaks_CTCF_ENCFF951PEM.bed.gz new file mode 100644 index 00000000..452b3fde Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/peaks_CTCF_ENCFF951PEM.bed.gz 
differ diff --git a/Train-FIRE/resources/mixed-positives/peaks_ENCFF073ORT.bed.gz b/Train-FIRE/resources/mixed-positives/peaks_ENCFF073ORT.bed.gz new file mode 100644 index 00000000..ca4a10b3 Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/peaks_ENCFF073ORT.bed.gz differ diff --git a/Train-FIRE/resources/mixed-positives/peaks_ENCFF598KWZ.bed.gz b/Train-FIRE/resources/mixed-positives/peaks_ENCFF598KWZ.bed.gz new file mode 100644 index 00000000..ffa92c5f Binary files /dev/null and b/Train-FIRE/resources/mixed-positives/peaks_ENCFF598KWZ.bed.gz differ diff --git a/Train-FIRE/workflow/Snakefile b/Train-FIRE/workflow/Snakefile new file mode 100644 index 00000000..f239c01a --- /dev/null +++ b/Train-FIRE/workflow/Snakefile @@ -0,0 +1,66 @@ +from snakemake.utils import min_version + +min_version("9.17") + + +# Config is always supplied at the CLI via `--configfile`, so a test run +# with `--configfile .tests/config.yaml` fully replaces experiments + +# region_sets instead of deep-merging with the production config. 
+if not config: + raise WorkflowError( + "no config loaded โ€” pass --configfile config/config.yaml " + "(or --configfile .tests/config.yaml for the smoke test)" + ) + + +EXPERIMENTS = list(config["experiments"].keys()) +REGION_SET_NAMES = list(config.get("region_sets", {}).keys()) +MODELS = ["baseline"] + EXPERIMENTS + + +wildcard_constraints: + exp="|".join(EXPERIMENTS), + model="|".join(MODELS), + rs="|".join(REGION_SET_NAMES) if REGION_SET_NAMES else "NEVER_MATCHES_ANYTHING", + chrom=r"[A-Za-z0-9_.-]+", + + +include: "rules/common.smk" +include: "rules/resources.smk" +include: "rules/regions.smk" +include: "rules/features.smk" +include: "rules/train.smk" +include: "rules/apply.smk" +include: "rules/trackhub.smk" + + +rule all: + input: + expand("results/experiments/{exp}/FIRE.gbdt.json", exp=EXPERIMENTS), + expand("results/experiments/{exp}/FIRE.FDR.pdf", exp=EXPERIMENTS), + expand("results/models/{model}/fire-fibers.bb", model=MODELS), + expand("results/models/{model}/fire-fiber-decorators.bb", model=MODELS), + "results/comparison/fdr_overlay.pdf", + "results/comparison/metrics.tsv", + "results/trackhub/hub.txt", + "results/trackhub/trackDb.txt", + + +rule train_only: + input: + expand("results/experiments/{exp}/FIRE.gbdt.json", exp=EXPERIMENTS), + expand("results/experiments/{exp}/FIRE.FDR.pdf", exp=EXPERIMENTS), + "results/comparison/fdr_overlay.pdf", + + +rule data_only: + """Pipeline smoke target: runs everything up to and including the +per-region_set training-data.bed.gz. 
Skips XGBoost/mokapot training, which +requires real multi-locus features to produce a usable model.""" + input: + expand("results/region_sets/{rs}/training-data.bed.gz", rs=REGION_SET_NAMES), + + +rule clean: + shell: + "rm -rf results" diff --git a/Train-FIRE/workflow/envs/env.yml b/Train-FIRE/workflow/envs/env.yml new file mode 100644 index 00000000..30f1f547 --- /dev/null +++ b/Train-FIRE/workflow/envs/env.yml @@ -0,0 +1,18 @@ +channels: + - conda-forge + - bioconda +dependencies: + - bedtools + - samtools ==1.21 + - htslib ==1.21 + - ucsc-bedtobigbed + - rustybam ==0.1.34 + - fibertools-rs + - python ==3.11 + - numpy + - pandas + - scikit-learn + - matplotlib + - xgboost + - mokapot ==0.9 + - pysam diff --git a/Train-FIRE/workflow/profiles/default/config.yaml b/Train-FIRE/workflow/profiles/default/config.yaml new file mode 100644 index 00000000..710df496 --- /dev/null +++ b/Train-FIRE/workflow/profiles/default/config.yaml @@ -0,0 +1,11 @@ +rerun-incomplete: True +show-failed-logs: True +rerun-triggers: mtime +restart-times: 0 +software-deployment-method: + - apptainer + - conda +printshellcmds: True +cores: 32 +local-cores: 4 +logger: snkmt diff --git a/Train-FIRE/workflow/rules/apply.smk b/Train-FIRE/workflow/rules/apply.smk new file mode 100644 index 00000000..2b5f81b7 --- /dev/null +++ b/Train-FIRE/workflow/rules/apply.smk @@ -0,0 +1,128 @@ +rule test_region_bam: + """Extract test_region (default chr19) from the test alignment. 
Streams directly from s3 when given an s3:// URL.""" + input: + aln=aln_input(TEST_BAM), + fasta=FASTA, + output: + bam="results/shared/test_region.bam", + csi="results/shared/test_region.bam.csi", + conda: + "../envs/env.yml" + threads: 8 + resources: + mem_mb=get_mem_mb, + params: + src=TEST_BAM, + region=TEST_REGION, + shell: + r""" + samtools view -@ {threads} -b --reference {input.fasta} \ + {params.src} {params.region} \ + -o {output.bam} --write-index + """ + + +rule apply_fire_bam: + """Annotate the test-region BAM with FIRE calls from this experiment's model.""" + input: + bam="results/shared/test_region.bam", + model="results/experiments/{exp}/FIRE.gbdt.json", + conf="results/experiments/{exp}/FIRE.conf.json", + output: + bam="results/experiments/{exp}/fire.bam", + csi="results/experiments/{exp}/fire.bam.csi", + conda: + "../envs/env.yml" + threads: 16 + resources: + mem_mb=get_mem_mb, + shell: + r""" + ft fire -t {threads} \ + --model {input.model} \ + --fdr-table {input.conf} \ + {input.bam} \ + | samtools sort -@ {threads} --write-index -o {output.bam}##idx##{output.csi} - + """ + + +def fire_bam_for_decorate(wildcards): + """ + Source BAM for ft track-decorators: + - baseline -> the pre-annotated input CRAM subset to test_region + (FIRE calls from whatever model was baked into the input) + - per-experiment -> apply_fire_bam output with the freshly trained model + """ + if wildcards.model == "baseline": + return "results/shared/test_region.bam" + return f"results/experiments/{wildcards.model}/fire.bam" + + +rule decorate_fibers: + """Produce the base bed12 track + decorator overlay via `ft track-decorators`.""" + input: + bam=fire_bam_for_decorate, + output: + base="results/models/{model}/fire-fibers.bed.gz", + dec="results/models/{model}/fire-fiber-decorators.bed.gz", + base_unsorted=temp("results/models/{model}/fire-fibers.unsorted.bed"), + conda: + "../envs/env.yml" + threads: 8 + resources: + mem_mb=get_mem_mb, + shell: + r""" + samtools view -@ 
{threads} -u {input.bam} \ + | ft track-decorators -t {threads} --bed12 {output.base_unsorted} \ + | grep -v '^#' | grep -vw 'NUC' \ + | sort -k1,1 -k2,2n -k3,3n -k4,4 \ + | bgzip -@ {threads} > {output.dec} + + sort -k1,1 -k2,2n -k3,3n -k4,4 {output.base_unsorted} \ + | bgzip -@ {threads} > {output.base} + """ + + +rule base_to_bigbed: + """bigBed 12+ for the FIRE-fibers base track.""" + input: + bed="results/models/{model}/fire-fibers.bed.gz", + sizes="results/shared/chrom.sizes", + bed_as=workflow.source_path("../templates/bed12_filter.as"), + output: + bb="results/models/{model}/fire-fibers.bb", + bed=temp("results/models/{model}/fire-fibers.bed"), + conda: + "../envs/env.yml" + threads: 4 + resources: + mem_mb=get_mem_mb, + shell: + r""" + zcat -f {input.bed} > {output.bed} + bedToBigBed -allow1bpOverlap -type=bed12+ -as={input.bed_as} \ + {output.bed} {input.sizes} {output.bb} + """ + + +rule decorator_to_bigbed: + """bigBed 12+ decorator overlay track.""" + input: + bed="results/models/{model}/fire-fiber-decorators.bed.gz", + sizes="results/shared/chrom.sizes", + dec_as=workflow.source_path("../templates/decoration.as"), + output: + bb="results/models/{model}/fire-fiber-decorators.bb", + bed=temp("results/models/{model}/fire-fiber-decorators.bed"), + conda: + "../envs/env.yml" + threads: 4 + resources: + mem_mb=get_mem_mb, + shell: + r""" + zcat -f {input.bed} > {output.bed} + bedToBigBed -allow1bpOverlap -type=bed12+ -as={input.dec_as} \ + {output.bed} {input.sizes} {output.bb} + """ diff --git a/Train-FIRE/workflow/rules/common.smk b/Train-FIRE/workflow/rules/common.smk new file mode 100644 index 00000000..5b6373db --- /dev/null +++ b/Train-FIRE/workflow/rules/common.smk @@ -0,0 +1,155 @@ +import os +import re +import shlex + + +def get_mem_mb(wildcards, attempt): + if attempt < 3: + return attempt * 1024 * 8 + return attempt * 1024 * 16 + + +def _expand(p): + return os.path.expanduser(os.path.expandvars(str(p))) + + +FAI = 
_expand(config["reference"]["fai"]) +FASTA = _expand(config["reference"]["fasta"]) +GENOME = config["reference"]["genome"] + +S3_ENDPOINT = config.get("s3_endpoint", None) +SAMPLE_RATE = float(config["sample_rate"]) + + +def is_remote(p): + s = str(p) + return s.startswith(("s3://", "http://", "https://")) + + +def resolve_alignment(src): + """ + Convert an s3://bucket/key URL into https:///bucket/key + using the configured s3_endpoint. htslib can read the resulting HTTPS + URL directly via libcurl with no auth, which works for public buckets + where its S3 transport fails ('Resource temporarily unavailable'). + Local paths and existing http(s) URLs pass through unchanged. + """ + s = _expand(src) + if s.startswith("s3://"): + if not S3_ENDPOINT: + raise ValueError("training_bam is s3:// but config.s3_endpoint is not set") + return S3_ENDPOINT.rstrip("/") + "/" + s[len("s3://") :] + return s + + +TRAINING_BAM = resolve_alignment(config["training_bam"]) +TEST_BAM = resolve_alignment(config["test_bam"] or config["training_bam"]) + + +def aln_input(src): + """If local, track as a snakemake input; if remote (http/https), skip.""" + return [] if is_remote(src) else [src] + + +TEST_REGION = str(config["test_region"]) +N_SITES = int(config["n_sites"]) +NEG_MULTIPLIER = int(config.get("neg_multiplier", 1)) +MERGE_DIST = int(config["merge_distance"]) +EXCLUDE_CHROMS_RE = config["exclude_chroms_regex"] +EXCLUDE_BEDS = [_expand(p) for p in config["exclude_beds"]] + + +def features_shards(wildcards): + """Resolve per-chrom feature shards at DAG-eval time so we can read the + chrom list from the FAI even when fetch_reference produces it.""" + chroms_file = checkpoints.build_chrom_list.get(**wildcards).output.chroms + with open(chroms_file) as fh: + chroms = [line.strip() for line in fh if line.strip()] + return expand("results/shared/features_per_chrom/{chrom}.tsv.gz", chrom=chroms) + + +TRAIN_DEFAULTS = config["train_defaults"] + + +REGION_SETS = config.get("region_sets", {}) 
+ + +def rs_cfg(rs): + return REGION_SETS[rs] + + +def rs_positive_specs(rs): + """ + Return list of (path, awk_filter_or_None) for a region_set's positive beds. + Accepts either bare strings or {path:, awk_filter:} dicts in config. + """ + out = [] + for item in rs_cfg(rs)["positive_beds"]: + if isinstance(item, str): + out.append((_expand(item), None)) + else: + out.append((_expand(item["path"]), item.get("awk_filter"))) + return out + + +def rs_positive_paths(rs): + return [p for p, _ in rs_positive_specs(rs)] + + +def positive_source_cmds(rs): + """ + Bash snippet that concatenates every positive bed, applying any + per-source awk_filter, and emits 3-column output. Uses `zcat -f` + so both gzipped and plain-text beds work transparently. + """ + parts = [] + for path, awkf in rs_positive_specs(rs): + q = shlex.quote(path) + if awkf: + parts.append(f"zcat -f {q} | awk {shlex.quote(awkf)} | cut -f1-3") + else: + parts.append(f"zcat -f {q} | cut -f1-3") + return "; ".join(parts) + + +def rs_negative_exclude_paths(rs): + """Extra beds whose regions are excluded from negative sampling but not used as positives.""" + out = [] + for item in rs_cfg(rs).get("negative_exclude_beds", []) or []: + out.append(_expand(item) if isinstance(item, str) else _expand(item["path"])) + return out + + +def rs_neg_mask_paths(rs): + """Union of positives + negative_exclude_beds used to keep negatives off signal.""" + return rs_positive_paths(rs) + rs_negative_exclude_paths(rs) + + +def exp_cfg(exp): + return config["experiments"][exp] + + +def exp_region_set(exp): + """Name of the region_set this experiment draws positives/negatives from.""" + rs = exp_cfg(exp).get("region_set") + if not rs: + raise ValueError(f"experiment '{exp}' is missing required field 'region_set'") + if rs not in REGION_SETS: + raise ValueError(f"experiment '{exp}' references undefined region_set '{rs}'") + return rs + + +def exp_train_params(exp): + """Merge train_defaults with per-experiment overrides.""" + 
params = dict(TRAIN_DEFAULTS) + params.update(exp_cfg(exp).get("train", {}) or {}) + return params + + +def all_rs_region_files(): + """Union of regions across region_sets (positives + negatives) needed for feature extraction.""" + paths = [] + for rs in REGION_SETS.keys(): + paths.append(f"results/region_sets/{rs}/positives.bed.gz") + paths.append(f"results/region_sets/{rs}/negatives.bed.gz") + return paths diff --git a/Train-FIRE/workflow/rules/features.smk b/Train-FIRE/workflow/rules/features.smk new file mode 100644 index 00000000..d000cd75 --- /dev/null +++ b/Train-FIRE/workflow/rules/features.smk @@ -0,0 +1,144 @@ +rule sample_bam: + """Fractional subsample of the training alignment (local or remote URL) into a BAM.""" + input: + aln=aln_input(TRAINING_BAM), + fasta=FASTA, + output: + bam="results/shared/sample.bam", + csi="results/shared/sample.bam.csi", + conda: + "../envs/env.yml" + threads: 16 + resources: + mem_mb=get_mem_mb, + params: + src=TRAINING_BAM, + rate=SAMPLE_RATE, + shell: + r""" + samtools view -@ {threads} -b -s {params.rate} \ + --reference {input.fasta} \ + {params.src} \ + -o {output.bam} --write-index + """ + + +rule union_regions: + """Concat every region_set's positives + negatives so we extract features from the BAM exactly once.""" + input: + beds=all_rs_region_files(), + fai=FAI, + output: + bed="results/shared/union_regions.bed.gz", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + shell: + r""" + zcat -f {input.beds} \ + | cut -f1-3 \ + | bedtools sort -g {input.fai} \ + | bedtools merge \ + | bgzip > {output.bed} + """ + + +rule regions_bam: + """Subset the sampled BAM to reads overlapping any training region. 
Indexes +the output so extract_features can scatter samtools views by chromosome.""" + input: + bam="results/shared/sample.bam", + regions="results/shared/union_regions.bed.gz", + output: + bam="results/shared/regions.bam", + csi="results/shared/regions.bam.csi", + conda: + "../envs/env.yml" + threads: 16 + resources: + mem_mb=get_mem_mb, + shell: + r""" + samtools view -@ {threads} -b -M \ + -L <(zcat -f {input.regions}) \ + {input.bam} -o {output.bam} --write-index + """ + + +rule extract_features_chrom: + """Run ft fire -f on a single chromosome's reads. Scattering across chroms +lets Snakemake schedule parallel jobs and avoids the single BAM reader +bottleneck of one ft fire process streaming the whole BAM.""" + input: + bam="results/shared/regions.bam", + csi="results/shared/regions.bam.csi", + output: + tsv=temp("results/shared/features_per_chrom/{chrom}.tsv.gz"), + conda: + "../envs/env.yml" + threads: 4 + resources: + mem_mb=get_mem_mb, + params: + min_msp=TRAIN_DEFAULTS["min_msp_length_for_positive_fire_call"], + shell: + r""" + samtools view -@ 1 -u {input.bam} {wildcards.chrom} \ + | ft fire -t {threads} --min-msp-length-for-positive-fire-call {params.min_msp} -f - \ + | awk 'NR == 1 || ($3 > $2 && $3 - $2 < 10000)' \ + | bgzip -@ 2 > {output.tsv} + """ + + +rule extract_features: + """Gather the per-chrom feature tables into one. 
awk enforces a single +header line so repeated headers from per-chrom shards are deduped, and +empty shards (chroms with no reads) contribute nothing.""" + input: + shards=features_shards, + output: + feats="results/shared/features.tsv.gz", + conda: + "../envs/env.yml" + threads: 4 + resources: + mem_mb=get_mem_mb, + shell: + r""" + for f in {input.shards}; do + [ -s "$f" ] || continue + bgzip -dc "$f" + done \ + | awk 'NR == 1 {{ hdr=$0; print; next }} $0 != hdr {{ print }}' \ + | bgzip -@ {threads} > {output.feats} + """ + + +rule build_training_data: + """Label features: +1 if row overlaps positives (f>=0.25); -1 if overlaps negatives and no unfiltered positive.""" + input: + feats="results/shared/features.tsv.gz", + positives="results/region_sets/{rs}/positives.bed.gz", + negatives="results/region_sets/{rs}/negatives.bed.gz", + mask="results/region_sets/{rs}/neg_mask.bed.gz", + output: + bed="results/region_sets/{rs}/training-data.bed.gz", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + shell: + r""" + header=$(zcat -f {input.feats} | head -n 1 || true) + ( + printf '%s\tLabel\n' "$header" + bedtools intersect -f 0.25 -u -a {input.feats} -b {input.positives} \ + | sed 's/$/\t1/' | awk '/^chr/' + bedtools intersect -f 0.25 -u -a {input.feats} -b {input.negatives} \ + | bedtools intersect -v -a - -b {input.mask} \ + | sed 's/$/\t-1/' | awk '/^chr/' + ) | bgzip -@ 8 > {output.bed} + echo "[{wildcards.rs}] training label counts:" >&2 + zcat -f {output.bed} | tail -n +2 | awk -F'\t' '{{print $NF}}' | sort | uniq -c >&2 + """ diff --git a/Train-FIRE/workflow/rules/regions.smk b/Train-FIRE/workflow/rules/regions.smk new file mode 100644 index 00000000..717540b1 --- /dev/null +++ b/Train-FIRE/workflow/rules/regions.smk @@ -0,0 +1,138 @@ +rule build_exclude: + """Union of user-supplied exclusion beds + regex-excluded chroms + the held-out test chromosome.""" + input: + beds=EXCLUDE_BEDS, + fai=FAI, + output: + bed="results/shared/exclude.bed.gz", + 
conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + params: + chrom_re=EXCLUDE_CHROMS_RE, + test_region=TEST_REGION, + shell: + r""" + ( \ + zcat -f {input.beds} | cut -f1-3; \ + awk -v OFS='\t' 'BEGIN{{re="{params.chrom_re}"}} $1 ~ re {{print $1,0,$2}}' {input.fai}; \ + awk -v OFS='\t' '$1 == "{params.test_region}" {{print $1,0,$2}}' {input.fai} \ + ) \ + | cut -f1-3 \ + | awk 'NR==FNR{{v[$1]; next}} ($1 in v)' {input.fai} - \ + | bedtools sort -g {input.fai} \ + | bedtools merge \ + | bgzip > {output.bed} + """ + + +rule build_positives: + """Union region_set's positive beds (with optional awk_filter), merge, subsample to n_sites, subtract exclude.""" + input: + beds=lambda wc: rs_positive_paths(wc.rs), + exclude="results/shared/exclude.bed.gz", + fai=FAI, + output: + bed="results/region_sets/{rs}/positives.bed.gz", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + params: + n_sites=N_SITES, + merge_dist=MERGE_DIST, + sources=lambda wc: positive_source_cmds(wc.rs), + shell: + r""" + ( {params.sources} ) \ + | awk 'NR==FNR{{v[$1]; next}} ($1 in v)' {input.fai} - \ + | bedtools sort -g {input.fai} \ + | bedtools merge -d {params.merge_dist} \ + | (shuf | head -n {params.n_sites} || true) \ + | bedtools sort -g {input.fai} \ + | bedtools subtract -a - -b {input.exclude} \ + | bgzip > {output.bed} + printf "[{wildcards.rs}] positives: " >&2 + zcat -f {output.bed} | wc -l >&2 + """ + + +rule build_neg_mask: + """Union of positives + negative_exclude_beds; negatives will be kept off these regions.""" + input: + beds=lambda wc: rs_neg_mask_paths(wc.rs), + exclude="results/shared/exclude.bed.gz", + fai=FAI, + output: + bed="results/region_sets/{rs}/neg_mask.bed.gz", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + params: + merge_dist=MERGE_DIST, + shell: + r""" + zcat -f {input.beds} \ + | cut -f1-3 \ + | awk 'NR==FNR{{v[$1]; next}} ($1 in v)' {input.fai} - \ + | bedtools sort -g {input.fai} \ + | bedtools merge -d 
{params.merge_dist} \ + | bedtools subtract -a - -b {input.exclude} \ + | bgzip > {output.bed} + """ + + +rule build_complement_negatives: + """Everything not in the negative-exclusion mask, minus exclude.""" + input: + mask="results/region_sets/{rs}/neg_mask.bed.gz", + exclude="results/shared/exclude.bed.gz", + fai=FAI, + output: + bed="results/region_sets/{rs}/complement_negatives.bed.gz", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + shell: + r""" + bedtools complement -i {input.mask} -g {input.fai} \ + | bedtools subtract -a - -b {input.exclude} \ + | bgzip > {output.bed} + """ + + +rule build_negatives: + """Shuffle positives into the complement so negatives length-match positives on the same chromosome. + Runs the shuffle neg_multiplier times with distinct seeds so the decoy pool is K x positives, giving + mokapot's FDR estimator and triqler's spline a richer null distribution. Seeds' outputs are concatenated + without merging so the final count is ~K x positives.""" + input: + positives="results/region_sets/{rs}/positives.bed.gz", + mask="results/region_sets/{rs}/neg_mask.bed.gz", + complement="results/region_sets/{rs}/complement_negatives.bed.gz", + fai=FAI, + output: + bed="results/region_sets/{rs}/negatives.bed.gz", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + params: + neg_multiplier=NEG_MULTIPLIER, + shell: + r""" + for seed in $(seq 42 $((42 + {params.neg_multiplier} - 1))); do + bedtools shuffle \ + -excl {input.mask} \ + -incl {input.complement} \ + -i {input.positives} \ + -chrom -seed "$seed" -g {input.fai} + done \ + | sort -k1,1 -k2,2n \ + | bgzip > {output.bed} + printf "[{wildcards.rs}] negatives: " >&2 + zcat -f {output.bed} | wc -l >&2 + """ diff --git a/Train-FIRE/workflow/rules/resources.smk b/Train-FIRE/workflow/rules/resources.smk new file mode 100644 index 00000000..b0eaafd8 --- /dev/null +++ b/Train-FIRE/workflow/rules/resources.smk @@ -0,0 +1,55 @@ +rule chrom_sizes: + """UCSC chrom.sizes from the 
.fai (cols 1, 2). Shared across rules.""" + input: + fai=FAI, + output: + sizes="results/shared/chrom.sizes", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + shell: + r"cut -f1,2 {input.fai} > {output.sizes}" + + +checkpoint build_chrom_list: + """Snapshot the chroms that actually have reads in regions.bam so +extract_features can scatter per-chrom. Deriving from idxstats (vs the FAI) +avoids launching empty jobs for chroms with no training-region coverage.""" + input: + bam="results/shared/regions.bam", + csi="results/shared/regions.bam.csi", + output: + chroms="results/shared/chroms.txt", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + params: + exclude_re=EXCLUDE_CHROMS_RE, + shell: + r""" + samtools idxstats {input.bam} \ + | awk '$1 != "*" && $3 > 0 {{ print $1 }}' \ + | grep -Ev {params.exclude_re:q} > {output.chroms} + """ + + +rule fetch_reference: + """Download hg38 analysis set from UCSC (once) and index it.""" + output: + fa="resources/ref/hg38.analysisSet.fa", + fai="resources/ref/hg38.analysisSet.fa.fai", + conda: + "../envs/env.yml" + threads: 4 + resources: + mem_mb=get_mem_mb, + params: + url="https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/analysisSet/hg38.analysisSet.fa.gz", + shell: + r""" + curl -L --fail --retry 3 {params.url} \ + | gunzip -c > {output.fa} + samtools faidx {output.fa} + """ diff --git a/Train-FIRE/workflow/rules/trackhub.smk b/Train-FIRE/workflow/rules/trackhub.smk new file mode 100644 index 00000000..4cd4e672 --- /dev/null +++ b/Train-FIRE/workflow/rules/trackhub.smk @@ -0,0 +1,34 @@ +rule build_trackhub: + input: + base=expand("results/models/{model}/fire-fibers.bb", model=MODELS), + dec=expand("results/models/{model}/fire-fiber-decorators.bb", model=MODELS), + sizes="results/shared/chrom.sizes", + output: + hub="results/trackhub/hub.txt", + genomes="results/trackhub/genomes.txt", + trackdb="results/trackhub/trackDb.txt", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, 
+ params: + name=config["trackhub"]["name"], + short=config["trackhub"]["short_label"], + long=config["trackhub"]["long_label"], + email=config["trackhub"]["email"], + genome=GENOME, + models=MODELS, + script=workflow.source_path("../scripts/make_trackhub.py"), + shell: + r""" + python {params.script} \ + --hub-dir results/trackhub \ + --name "{params.name}" \ + --short-label "{params.short}" \ + --long-label "{params.long}" \ + --email "{params.email}" \ + --genome {params.genome} \ + --models {params.models} \ + --results-root results/models \ + --chrom-sizes {input.sizes} + """ diff --git a/Train-FIRE/workflow/rules/train.smk b/Train-FIRE/workflow/rules/train.smk new file mode 100644 index 00000000..808f68fe --- /dev/null +++ b/Train-FIRE/workflow/rules/train.smk @@ -0,0 +1,74 @@ +rule train_model: + """Train one FIRE model on its region_set's training-data.bed.gz.""" + input: + training=lambda wc: f"results/region_sets/{exp_region_set(wc.exp)}/training-data.bed.gz", + output: + gbdt="results/experiments/{exp}/FIRE.gbdt.json", + xgb="results/experiments/{exp}/FIRE.xgb.json", + conf="results/experiments/{exp}/FIRE.conf.json", + fdr_pdf="results/experiments/{exp}/FIRE.FDR.pdf", + feat_pdf="results/experiments/{exp}/FIRE.feature.importance.pdf", + metrics="results/experiments/{exp}/metrics.json", + conda: + "../envs/env.yml" + threads: 96 + resources: + mem_mb=get_mem_mb, + params: + outdir=lambda wc: f"results/experiments/{wc.exp}", + p=lambda wc: exp_train_params(wc.exp), + script=workflow.source_path("../scripts/train-fire-model.py"), + outer_jobs=12, + inner_jobs=8, + shell: + r""" + export OMP_NUM_THREADS={params.inner_jobs} + python {params.script} \ + {input.training} \ + --outdir {params.outdir} \ + --outer-jobs {params.outer_jobs} \ + --inner-jobs {params.inner_jobs} \ + --train-fdr {params.p[train_fdr]} \ + --test-fdr {params.p[test_fdr]} \ + --subset-max-train {params.p[subset_max_train]} \ + --direction {params.p[direction]} \ + 
--min-msp-length-for-positive-fire-call {params.p[min_msp_length_for_positive_fire_call]} \ + --min-msp-length-for-negative-fire-call {params.p[min_msp_length_for_negative_fire_call]} \ + $( [ "{params.p[grid_search]}" = "True" ] && echo --grid-search ) \ + --n-estimators-grid "{params.p[n_estimators]}" \ + --max-depth-grid "{params.p[max_depth]}" \ + --min-child-weight-fracs "{params.p[min_child_weight_fracs]}" \ + --colsample-bytree-grid "{params.p[colsample_bytree]}" \ + --gamma-grid "{params.p[gamma]}" \ + --learning-rate-grid "{params.p[learning_rate]}" \ + --early-stopping-rounds {params.p[early_stopping_rounds]} \ + --early-stopping-val-frac {params.p[early_stopping_val_frac]} \ + --mokapot-max-iter {params.p[mokapot_max_iter]} \ + $( [ "{params.p[balance_train]}" = "True" ] && echo --balance-train || echo --no-balance-train ) \ + $( [ "{params.p[mokapot_override]}" = "True" ] && echo --mokapot-override ) + """ + + +rule compare_models: + """Aggregate per-experiment FDR curves + metrics into a single plot + TSV.""" + input: + metrics=expand("results/experiments/{exp}/metrics.json", exp=EXPERIMENTS), + confs=expand("results/experiments/{exp}/FIRE.conf.json", exp=EXPERIMENTS), + output: + pdf="results/comparison/fdr_overlay.pdf", + tsv="results/comparison/metrics.tsv", + conda: + "../envs/env.yml" + resources: + mem_mb=get_mem_mb, + params: + exps=EXPERIMENTS, + script=workflow.source_path("../scripts/compare_models.py"), + shell: + r""" + python {params.script} \ + --root results/experiments \ + --experiments {params.exps} \ + --out-pdf {output.pdf} \ + --out-tsv {output.tsv} + """ diff --git a/Train-FIRE/workflow/scripts/compare_models.py b/Train-FIRE/workflow/scripts/compare_models.py new file mode 100644 index 00000000..38fca392 --- /dev/null +++ b/Train-FIRE/workflow/scripts/compare_models.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +"""Aggregate per-experiment FIRE.conf.json + metrics.json into one plot and TSV.""" + +import argparse +import json +from 
pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import pandas as pd + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--root", required=True, help="results/experiments dir") + ap.add_argument("--experiments", nargs="+", required=True) + ap.add_argument("--out-pdf", required=True) + ap.add_argument("--out-tsv", required=True) + args = ap.parse_args() + + root = Path(args.root) + Path(args.out_pdf).parent.mkdir(parents=True, exist_ok=True) + + rows = [] + fig, ax = plt.subplots(1, 1, figsize=(7, 5)) + + for exp in args.experiments: + conf_path = root / exp / "FIRE.conf.json" + m_path = root / exp / "metrics.json" + if not conf_path.exists() or not m_path.exists(): + continue + metrics = json.loads(m_path.read_text()) + rows.append(metrics) + + conf = json.loads(conf_path.read_text()) + df = pd.DataFrame(conf["data"], columns=conf["columns"]) + df = df.sort_values("mokapot score") + ax.plot(df["mokapot q-value"], range(len(df), 0, -1), label=exp) + + ax.set_xlabel("q-value") + ax.set_ylabel("Cumulative PSMs") + ax.set_xscale("log") + ax.legend(frameon=False, fontsize=8) + ax.set_title("FIRE model FDR curves") + plt.tight_layout() + plt.savefig(args.out_pdf) + plt.close() + + pd.DataFrame(rows).to_csv(args.out_tsv, sep="\t", index=False) + + +if __name__ == "__main__": + main() diff --git a/Train-FIRE/workflow/scripts/make_trackhub.py b/Train-FIRE/workflow/scripts/make_trackhub.py new file mode 100644 index 00000000..648871be --- /dev/null +++ b/Train-FIRE/workflow/scripts/make_trackhub.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Author: Mitchell R. Vollger +""" +Assemble a UCSC Track Hub with one decorator track per model. 
+ +Flat layout (avoids double-nesting with hubCheck): + / + hub.txt + genomes.txt + trackDb.txt + chrom.sizes + bb/.fire-fibers.bb + bb/.fire-fiber-decorators.bb +""" + +import argparse +import shutil +from pathlib import Path + + +BASELINE_COLOR = "0,0,0" +PALETTE = [ + (166, 54, 3), + (217, 95, 14), + (54, 144, 192), + (34, 94, 168), + (5, 112, 176), + (35, 139, 69), + (116, 196, 118), + (49, 130, 189), + (140, 81, 10), + (128, 0, 128), +] + + +def color(i): + return ",".join(str(c) for c in PALETTE[i % len(PALETTE)]) + + +TRACK_TEMPLATE = """track {model} +shortLabel {short_label} +longLabel {long_label} +type bigBed 12 + +itemRgb on +visibility squish +color {color} +bigDataUrl bb/{model}.fire-fibers.bb +decorator.default.bigDataUrl bb/{model}.fire-fiber-decorators.bb +decorator.default.filterValues.keywords 5mC,m6A,NUC,LINKER,FIRE +decorator.default.filterValuesDefault.keywords LINKER,FIRE +""" + + +def render_block(model, is_baseline, palette_idx): + if is_baseline: + return TRACK_TEMPLATE.format( + model=model, + short_label="baseline (input CRAM)", + long_label="FIRE fibers from the model baked into the input CRAM (no retraining)", + color=BASELINE_COLOR, + ) + return TRACK_TEMPLATE.format( + model=model, + short_label=model, + long_label=f"FIRE fibers, trained model={model}", + color=color(palette_idx), + ) + + +def main(): + ap = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + ap.add_argument("--hub-dir", required=True) + ap.add_argument("--name", required=True) + ap.add_argument("--short-label", required=True) + ap.add_argument("--long-label", required=True) + ap.add_argument("--email", required=True) + ap.add_argument("--genome", required=True) + ap.add_argument("--models", nargs="+", required=True) + ap.add_argument("--results-root", required=True) + ap.add_argument("--chrom-sizes", required=True) + args = ap.parse_args() + + hub = Path(args.hub_dir) + (hub / "bb").mkdir(parents=True, 
exist_ok=True) + + shutil.copyfile(args.chrom_sizes, hub / "chrom.sizes") + + (hub / "hub.txt").write_text( + f"hub {args.name}\n" + f"shortLabel {args.short_label}\n" + f"longLabel {args.long_label}\n" + f"genomesFile genomes.txt\n" + f"email {args.email}\n" + ) + (hub / "genomes.txt").write_text(f"genome {args.genome}\ntrackDb trackDb.txt\n") + + ordered = (["baseline"] if "baseline" in args.models else []) + [ + m for m in args.models if m != "baseline" + ] + + root = Path(args.results_root) + blocks = [] + palette_idx = 0 + for model in ordered: + for name in ("fire-fibers.bb", "fire-fiber-decorators.bb"): + shutil.copyfile(root / model / name, hub / "bb" / f"{model}.{name}") + is_baseline = model == "baseline" + blocks.append(render_block(model, is_baseline, palette_idx)) + if not is_baseline: + palette_idx += 1 + + (hub / "trackDb.txt").write_text("\n".join(blocks)) + + +if __name__ == "__main__": + main() diff --git a/Train-FIRE/workflow/scripts/train-fire-model.py b/Train-FIRE/workflow/scripts/train-fire-model.py new file mode 100644 index 00000000..6b6c3861 --- /dev/null +++ b/Train-FIRE/workflow/scripts/train-fire-model.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python +""" +Train a FIRE XGBoost classifier via mokapot and emit: + /FIRE.xgb.bin + /FIRE.gbdt.json + /FIRE.conf.json + /FIRE.FDR.pdf + /FIRE.feature.importance.pdf + /metrics.json + +Adapted from Train-FIRE/train-fire-model.py with CLI-driven hyperparameters +and a configurable output directory so multiple experiments can run in parallel. 
+""" + +from __future__ import print_function + +import argparse +import ast +import json +import logging +import os +from pathlib import Path +from typing import List, Optional + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import mokapot +import numpy as np +import pandas as pd +import xgboost as xgb +from sklearn.model_selection import GridSearchCV, train_test_split +from xgboost import XGBClassifier + + +RANDOM_SEED = 42 +np.random.seed(RANDOM_SEED) + + +class XGBEarlyStop(XGBClassifier): + """XGBClassifier that holds out an internal validation split for early stopping. + + mokapot calls estimator.fit(X, y) without an eval_set, and GridSearchCV + clones the estimator per fold, so we carve the val split off inside fit(). + """ + + def __init__(self, *, val_frac=0.15, **kwargs): + super().__init__(**kwargs) + self.val_frac = val_frac + + def get_xgb_params(self): + # val_frac is a wrapper hyperparam, not an XGBoost booster param. + # Strip it so XGBoost doesn't warn "Parameters ... are not used". 
+ params = super().get_xgb_params() + params.pop("val_frac", None) + return params + + def fit(self, X, y, **kwargs): + stratify = y if len(np.unique(y)) > 1 else None + Xt, Xv, yt, yv = train_test_split( + X, + y, + test_size=self.val_frac, + stratify=stratify, + random_state=RANDOM_SEED, + ) + kwargs.pop("eval_set", None) + kwargs.pop("verbose", None) + return super().fit(Xt, yt, eval_set=[(Xv, yv)], verbose=False, **kwargs) + + +def parse_list(s: str, cast=float) -> List: + """Parse a CLI-supplied list like '[0.5, 1]', '0.5,1', or '0.5 1' into [cast(x), ...].""" + if s is None: + return [] + s = s.strip() + if s.startswith("["): + return [cast(x) for x in ast.literal_eval(s)] + if "," in s: + return [cast(x) for x in s.split(",") if x.strip() != ""] + return [cast(x) for x in s.split() if x.strip() != ""] + + +def convert_to_gbdt(input_model: str, output_file: str) -> int: + """Convert an xgboost model to the JSON format fibertools-rs reads (base_score + tree dump).""" + booster = xgb.Booster() + booster.load_model(input_model) + config = json.loads(booster.save_config()) + # XGBoost >=2 serializes base_score as a bracketed vector string like + # '[4.2412016E-1]'; older versions gave a bare float. Handle both. 
+ raw = config["learner"]["learner_model_param"]["base_score"] + parsed = json.loads(raw) if isinstance(raw, str) and raw.startswith("[") else raw + base_score = float(parsed[0] if isinstance(parsed, list) else parsed) + tmp_file = output_file + ".mid" + booster.dump_model(tmp_file, dump_format="json") + with open(output_file, "w") as out: + out.write(repr(base_score) + "\n") + with open(tmp_file) as f: + out.write(f.read()) + os.remove(tmp_file) + return 0 + + +def save_model(model, test_conf, outdir: Path, max_fdr: float) -> dict: + conf_df = test_conf.confidence_estimates["psms"] + conf_df = pd.concat([conf_df, test_conf.decoy_confidence_estimates["psms"]]) + simple = ( + conf_df[["mokapot score", "mokapot q-value"]] + .drop_duplicates() + .sort_values(by=["mokapot score", "mokapot q-value"]) + ) + simple.loc[simple["mokapot q-value"] > max_fdr, "mokapot q-value"] = 1.0 + g = simple.sort_values(["mokapot q-value", "mokapot score"]).groupby( + "mokapot q-value" + ) + simple = ( + pd.concat([g.head(1), g.tail(1)]) + .drop_duplicates() + .sort_values("mokapot score") + .reset_index(drop=True) + ) + + model.estimator.save_model(str(outdir / "FIRE.xgb.json")) + convert_to_gbdt(str(outdir / "FIRE.xgb.json"), str(outdir / "FIRE.gbdt.json")) + (outdir / "FIRE.conf.json").write_text(simple.to_json(orient="split", index=False)) + + model.estimator.get_booster().feature_names = model.features + ax = xgb.plot_importance(model.estimator) + ax.figure.set_size_inches(10, 8) + ax.figure.tight_layout() + ax.figure.savefig(outdir / "FIRE.feature.importance.pdf") + plt.close(ax.figure) + + _fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + test_conf.plot_qvalues(c="#24B8A0", ax=ax, label="mokapot", threshold=max_fdr * 2) + ax.axvline(max_fdr, color="#343131", linestyle="--") + ax.legend(frameon=False) + plt.tight_layout() + plt.savefig(outdir / "FIRE.FDR.pdf") + plt.close() + + psms = test_conf.confidence_estimates["psms"] + n_at_fdr = int((psms["mokapot q-value"] <= max_fdr).sum()) 
+ return { + "n_test_psms": int(len(psms)), + "n_test_psms_at_fdr": n_at_fdr, + "max_fdr": max_fdr, + } + + +def _make_xgb(args, scale_pos_weight, n_jobs, **fixed): + """Build a single XGB estimator, wrapping with early-stopping if enabled.""" + common = dict( + eval_metric="auc", + scale_pos_weight=scale_pos_weight, + seed=RANDOM_SEED, + n_jobs=n_jobs, + ) + common.update(fixed) + if args.early_stopping_rounds > 0: + return XGBEarlyStop( + val_frac=args.early_stopping_val_frac, + early_stopping_rounds=args.early_stopping_rounds, + **common, + ) + return XGBClassifier(**common) + + +def train_classifier(train_df, test_df, args, scale_pos_weight): + if args.grid_search: + mcw = (len(train_df) * np.array(args.min_child_weight_fracs)).astype(int) + grid = { + "n_estimators": args.n_estimators_grid, + "scale_pos_weight": [scale_pos_weight], + "max_depth": args.max_depth_grid, + "min_child_weight": mcw.tolist(), + "colsample_bytree": args.colsample_bytree_grid, + "gamma": args.gamma_grid, + "learning_rate": args.learning_rate_grid, + } + xgb_model = GridSearchCV( + _make_xgb(args, scale_pos_weight, n_jobs=args.inner_jobs), + param_grid=grid, + cv=5, + scoring="roc_auc", + verbose=2, + n_jobs=args.outer_jobs, + ) + else: + # no grid search -> give all threads to XGBoost + xgb_model = _make_xgb( + args, + scale_pos_weight, + n_jobs=args.outer_jobs * args.inner_jobs, + n_estimators=args.n_estimators_grid[0], + max_depth=args.max_depth_grid[0], + min_child_weight=int(len(train_df) * args.min_child_weight_fracs[0]), + gamma=args.gamma_grid[0], + colsample_bytree=args.colsample_bytree_grid[0], + learning_rate=args.learning_rate_grid[0], + ) + + train_psms = mokapot.read_pin(train_df) + model = mokapot.Model( + xgb_model, + train_fdr=args.train_fdr, + subset_max_train=args.subset_max_train, + direction=args.direction, + max_iter=args.mokapot_max_iter, + override=args.mokapot_override, + ) + model.fit(train_psms) + test_psms = mokapot.read_pin(test_df) + scores = 
model.predict(test_psms) + try: + test_conf = test_psms.assign_confidence(scores) + except IndexError: + # triqler's qvality spline (called inside assign_confidence) fails with + # `gamma[-3]` when there are <3 unique score bins. Weak models can + # collapse their output onto a handful of values. Break ties with + # sub-nanoscale deterministic jitter so the spline has enough bins. + if not args.mokapot_override: + raise + logging.warning( + "assign_confidence hit triqler IndexError; retrying with jittered scores" + ) + rng = np.random.default_rng(RANDOM_SEED) + scores = np.asarray(scores, dtype=float) + rng.uniform( + -1e-12, 1e-12, size=len(scores) + ) + test_conf = test_psms.assign_confidence(scores) + return model, test_conf + + +def balance_df(df): + min_count = df["Label"].value_counts().min() + return ( + df.groupby("Label", group_keys=False).sample(n=min_count).reset_index(drop=True) + ) + + +def read_features(infile, args): + df = pd.read_csv(infile, sep="\t") + df.insert(0, "SpecId", df.index) + df["Peptide"] = df.SpecId + df["Proteins"] = df.SpecId + df["scannr"] = df.SpecId + assert "Label" in df.columns + logging.info(f"Label counts raw: {df.Label.value_counts().to_dict()}") + + df = df[ + (df.msp_len >= args.min_msp_length_for_positive_fire_call) | (df.Label == -1) + ] + df = df[ + (df.msp_len >= args.min_msp_length_for_negative_fire_call) | (df.Label == 1) + ] + df = df.groupby(["fiber", "Label"]).sample(n=1).reset_index(drop=True) + + for col in df.columns: + if "AT" in col or "rle" in col: + df[col] = 1 + + df.drop(columns=["#chrom", "start", "end", "fiber"], inplace=True) + r = np.random.rand(len(df)) + train = df[r < 0.80] + test = df[r >= 0.80] + train = train[ + (train.msp_len >= args.min_msp_length_for_positive_fire_call) + | (train.Label == 1) + ] + train_out = ( + balance_df(train) if args.balance_train else train.reset_index(drop=True) + ) + return train_out, balance_df(test) + + +def main(): + ap = argparse.ArgumentParser() + 
ap.add_argument("training_data") + ap.add_argument("--outdir", required=True) + ap.add_argument( + "--outer-jobs", + type=int, + default=1, + help="joblib workers for GridSearchCV (one process per fold/grid-point fit).", + ) + ap.add_argument( + "--inner-jobs", + type=int, + default=8, + help="XGBoost n_jobs (OpenMP threads per single fit).", + ) + ap.add_argument("--train-fdr", type=float, default=0.05) + ap.add_argument("--test-fdr", type=float, default=0.05) + ap.add_argument("--subset-max-train", type=int, default=2_000_000) + ap.add_argument("--direction", default="msp_len_times_m6a_fc") + ap.add_argument("--min-msp-length-for-positive-fire-call", type=int, default=85) + ap.add_argument("--min-msp-length-for-negative-fire-call", type=int, default=85) + ap.add_argument("--grid-search", action="store_true") + ap.add_argument("--n-estimators-grid", default="[200, 300]") + ap.add_argument("--max-depth-grid", default="[9, 15]") + ap.add_argument("--min-child-weight-fracs", default="[0.001, 0.005]") + ap.add_argument("--colsample-bytree-grid", default="[0.5, 1.0]") + ap.add_argument("--gamma-grid", default="[1]") + ap.add_argument("--learning-rate-grid", default="[0.3]") + ap.add_argument( + "--early-stopping-rounds", + type=int, + default=0, + help="0 disables. >0 holds out --early-stopping-val-frac for early stopping.", + ) + ap.add_argument("--early-stopping-val-frac", type=float, default=0.15) + ap.add_argument("--mokapot-max-iter", type=int, default=15) + ap.add_argument( + "--balance-train", + action=argparse.BooleanOptionalAction, + default=True, + help="Downsample majority class in the training set. 
--no-balance-train keeps all rows.", + ) + ap.add_argument( + "--mokapot-override", + action="store_true", + help="Downgrade mokapot's 'model performs worse than direction' from error to warning.", + ) + args = ap.parse_args() + + args.n_estimators_grid = parse_list(args.n_estimators_grid, int) + args.max_depth_grid = parse_list(args.max_depth_grid, int) + args.min_child_weight_fracs = parse_list(args.min_child_weight_fracs, float) + args.colsample_bytree_grid = parse_list(args.colsample_bytree_grid, float) + args.gamma_grid = parse_list(args.gamma_grid, float) + args.learning_rate_grid = parse_list(args.learning_rate_grid, float) + + outdir = Path(args.outdir) + outdir.mkdir(parents=True, exist_ok=True) + + logging.basicConfig( + format="[%(levelname)s][%(relativeCreated)d ms]: %(message)s", + level=logging.INFO, + ) + + train_df, test_df = read_features(args.training_data, args) + n_pos = int((train_df.Label == 1).sum()) + n_neg = int((train_df.Label == -1).sum()) + scale_pos_weight = (n_neg / max(n_pos, 1)) if n_pos else 1.0 + + model, test_conf = train_classifier(train_df, test_df, args, scale_pos_weight) + + metrics = save_model(model, test_conf, outdir, max_fdr=args.test_fdr) + metrics.update( + dict( + experiment=outdir.name, + n_train=int(len(train_df)), + n_test=int(len(test_df)), + n_train_pos=n_pos, + n_train_neg=n_neg, + scale_pos_weight=float(scale_pos_weight), + train_fdr=args.train_fdr, + test_fdr=args.test_fdr, + direction=args.direction, + grid_search=bool(args.grid_search), + outer_jobs=int(args.outer_jobs), + inner_jobs=int(args.inner_jobs), + early_stopping_rounds=int(args.early_stopping_rounds), + early_stopping_val_frac=float(args.early_stopping_val_frac), + mokapot_max_iter=int(args.mokapot_max_iter), + balance_train=bool(args.balance_train), + mokapot_override=bool(args.mokapot_override), + learning_rate_grid=args.learning_rate_grid, + ) + ) + if args.grid_search and hasattr(model.estimator, "best_params_"): + metrics["best_params"] = 
model.estimator.best_params_ + (outdir / "metrics.json").write_text(json.dumps(metrics, indent=2, default=str)) + logging.info(f"Wrote metrics: {metrics}") + + +if __name__ == "__main__": + main() diff --git a/Train-FIRE/workflow/templates/bed12_filter.as b/Train-FIRE/workflow/templates/bed12_filter.as new file mode 100644 index 00000000..ad075712 --- /dev/null +++ b/Train-FIRE/workflow/templates/bed12_filter.as @@ -0,0 +1,17 @@ +table decoration +"Browser extensible data (12 fields) plus information about what item this decorates and how." + ( + string chrom; "Chromosome (or contig, scaffold, etc.)" + uint chromStart; "Start position in chromosome" + uint chromEnd; "End position in chromosome" + string name; "Name of item" + uint score; "Score from 0-1000" + char[1] strand; "+ or -" + uint thickStart; "Start of where display should be thick (start codon)" + uint thickEnd; "End of where display should be thick (stop codon)" + uint color; "Primary RGB color for the decoration" + int blockCount; "Number of blocks" + int[blockCount] blockSizes; "Comma separated list of block sizes" + int[blockCount] chromStarts; "Start positions relative to chromStart" + lstring keywords; "hap associated with the record" + ) diff --git a/Train-FIRE/workflow/templates/decoration.as b/Train-FIRE/workflow/templates/decoration.as new file mode 100644 index 00000000..c5d6cca9 --- /dev/null +++ b/Train-FIRE/workflow/templates/decoration.as @@ -0,0 +1,21 @@ +table decoration +"Browser extensible data (12 fields) plus information about what item this decorates and how." 
+ ( + string chrom; "Chromosome (or contig, scaffold, etc.)" + uint chromStart; "Start position in chromosome" + uint chromEnd; "End position in chromosome" + string name; "Name of item" + uint score; "Score from 0-1000" + char[1] strand; "+ or -" + uint thickStart; "Start of where display should be thick (start codon)" + uint thickEnd; "End of where display should be thick (stop codon)" + uint color; "Primary RGB color for the decoration" + int blockCount; "Number of blocks" + int[blockCount] blockSizes; "Comma separated list of block sizes" + int[blockCount] chromStarts; "Start positions relative to chromStart" + string decoratedItem; "Identity of the decorated item in chr:start-end:item_name format" + string style; "Draw style for the decoration (e.g. block, glyph)" + string fillColor; "Secondary color to use for filling decoration, blocks, supports RGBA" + string glyph; "The glyph to draw in glyph mode; ignored for other styles" + lstring keywords; "Keywords associated with the decoration" + ) diff --git a/build.rs b/build.rs index 6c57d3b6..caae525d 100644 --- a/build.rs +++ b/build.rs @@ -45,12 +45,16 @@ fn main() { // Generate the model code and state file from the ONNX file. 
use burn_import::onnx::ModelGen; use burn_import::onnx::RecordType; - for x in &[ + let onnx_models = [ "src/m6a_burn/two_zero.onnx", "src/m6a_burn/two_two.onnx", "src/m6a_burn/three_two.onnx", "src/m6a_burn/revio.onnx", - ] { + ]; + for x in &onnx_models { + println!("cargo:rerun-if-changed={}", x); + } + for x in &onnx_models { ModelGen::new() .input(x) // Path to the ONNX model .out_dir("m6a_burn/") // Directory for the generated Rust source file (under target/) diff --git a/src/cli.rs b/src/cli.rs index 1959be7a..0ec83ee6 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -4,6 +4,8 @@ use clap_complete::{generate, Generator, Shell}; use std::{fmt::Debug, io}; // Reference the modules for the subcommands +mod benchmark_opts; +mod call_peaks_opts; mod center_opts; mod clear_kinetics_opts; mod ddda_to_m6a_opts; @@ -12,6 +14,7 @@ mod extract_opts; mod fiber_hmm; mod fire_opts; mod footprint_opts; +mod mock_fire_opts; mod nucleosome_opts; mod pg_inject_opts; mod pg_lift_opts; @@ -23,6 +26,8 @@ mod strip_basemods_opts; mod validate_opts; // include the subcommand modules as top level functions and structs in the cli module +pub use benchmark_opts::*; +pub use call_peaks_opts::*; pub use center_opts::*; pub use clear_kinetics_opts::*; pub use ddda_to_m6a_opts::*; @@ -31,6 +36,7 @@ pub use extract_opts::*; pub use fiber_hmm::*; pub use fire_opts::*; pub use footprint_opts::*; +pub use mock_fire_opts::*; pub use nucleosome_opts::*; pub use pg_inject_opts::*; pub use pg_lift_opts::*; @@ -152,6 +158,16 @@ pub enum Commands { /// Add or strip panSN-spec prefixes from BAM contig names #[clap(name = "pg-pansn")] PgPansn(PgPansnOptions), + /// Call FIRE peaks using FDR-based peak calling on pileup data + #[clap(name = "call-peaks", visible_aliases = &["peaks", "call"])] + CallPeaks(CallPeaksOptions), + /// Create a mock BAM file with FIRE elements from a BED file. + /// Each interval in the BED becomes a FIRE element. The 4th column groups intervals into the same mock read. 
+ #[clap(name = "mock-fire")] + MockFire(MockFireOptions), + /// Benchmark fiber iterator performance (hidden command for testing) + #[clap(hide = true)] + Benchmark(BenchmarkOptions), /// Make command line completions #[clap(hide = true)] Completions(CompletionOptions), diff --git a/src/cli/benchmark_opts.rs b/src/cli/benchmark_opts.rs new file mode 100644 index 00000000..b7a3da87 --- /dev/null +++ b/src/cli/benchmark_opts.rs @@ -0,0 +1,12 @@ +use crate::utils::input_bam::InputBam; +use clap::Args; + +#[derive(Args, Debug)] +pub struct BenchmarkOptions { + #[clap(flatten)] + pub input: InputBam, + + /// Region to benchmark (format: chr:start-end) + #[clap(short, long)] + pub region: Option, +} diff --git a/src/cli/call_peaks_opts.rs b/src/cli/call_peaks_opts.rs new file mode 100644 index 00000000..a40f9885 --- /dev/null +++ b/src/cli/call_peaks_opts.rs @@ -0,0 +1,92 @@ +use crate::utils::input_bam::InputBam; +use clap::Args; +use std::fmt::Debug; + +#[derive(Args, Debug)] +pub struct CallPeaksOptions { + #[clap(flatten)] + pub input: InputBam, + + /// BED file with shuffled fiber positions (from bedtools shuffle) + /// If not provided, will use all positions as real data (no FDR calculation) + #[clap(short, long)] + pub shuffled: Option, + + /// Output BED file with called peaks + #[clap(short, long, default_value = "-")] + pub out: String, + + /// Maximum coverage threshold for filtering (optional) + #[clap(long)] + pub max_cov: Option, + + /// Minimum coverage threshold for filtering (optional) + #[clap(long)] + pub min_cov: Option, + + /// Number of standard deviations from median coverage to use for filtering (default: 5) + /// If set, will calculate median +/- (sd_cov * std_dev) and use those as min/max coverage + /// This overrides --max-cov and --min-cov if those are not explicitly set + #[clap(long, default_value = "5.0")] + pub sd_cov: f64, + + /// Maximum FDR threshold for peak calling (ignored if --min-fire-frac is set) + #[clap(long, default_value = 
"0.05")] + pub max_fdr: f64, + + /// Minimum fraction of fibers with FIREs required to call a peak + /// If set, skips FDR calculation and uses this threshold instead + /// For example, 0.5 means at least 50% of fibers must have a FIRE at the peak position + #[clap(long)] + pub min_fire_frac: Option, + + /// Minimum fraction of fibers with FIREs required as an additional filter (applied WITH FDR) + /// Unlike --min-fire-frac, this is applied in addition to FDR filtering, not instead of it + /// For example, 0.3 means at least 30% of fibers must have a FIRE AND FDR must be <= max_fdr + #[clap(long, default_value = "0.1")] + pub min_fire_frac_filter: f64, + + /// Minimum fraction of accessible bases in peak + #[clap(long, default_value = "0.0", hide = true)] + pub min_frac_accessible: f64, + + /// Rolling window size for finding local maxima (in base pairs) + #[clap(long, default_value = "200")] + pub window_size: usize, + + /// Minimum fraction of overlapping FIRE elements for merging peaks (Phase 2) + #[clap(long, default_value = "0.5")] + pub min_frac_overlap: f64, + + /// Minimum reciprocal overlap for merging peaks (Phase 3) + #[clap(long, default_value = "0.75")] + pub min_reciprocal_overlap: f64, + + /// High reciprocal overlap threshold for initial merging (Phase 1) + #[clap(long, default_value = "0.90")] + pub high_reciprocal_overlap: f64, + + /// Maximum number of grouping iterations for merging + #[clap(long, default_value = "10")] + pub max_grouping_iterations: usize, + + /// Skip the FDR table generation and use existing table + #[clap(long)] + pub fdr_table: Option, + + /// Output the FDR table to this file + #[clap(long)] + pub fdr_table_out: Option, + + /// Include nucleosome and MSP coverage in pileup (default: only FIRE coverage) + #[clap(long)] + pub include_nuc_msp: bool, + + /// Include haplotype-specific calls + #[clap(long)] + pub haps: bool, + + /// Minimum FIRE coverage required to calculate a score (default: 4) + #[clap(long, default_value = 
"4", hide = true)] + pub min_fire_coverage: i32, +} diff --git a/src/cli/fire_opts.rs b/src/cli/fire_opts.rs index 7d95fd75..63160359 100644 --- a/src/cli/fire_opts.rs +++ b/src/cli/fire_opts.rs @@ -28,15 +28,6 @@ pub struct FireOptions { /// Output FIREs features for training in a table format #[clap(short, long)] pub feats_to_text: bool, - /// Don't write reads with no m6A calls to the output bam - #[clap(short, long)] - pub skip_no_m6a: bool, - /// Skip reads without at least `N` MSP calls - #[clap(long, default_value = "0", env = "MIN_MSP")] - pub min_msp: usize, - /// Skip reads without an average MSP size greater than `N` - #[clap(long, default_value = "0", env)] - pub min_ave_msp_size: i64, /// Width of bin for feature collection #[clap(short, long, default_value = "40", env, default_value_ifs([ @@ -53,7 +44,7 @@ pub struct FireOptions { #[clap(long, default_value = "100", env)] pub best_window_size: i64, /// Use 5mC data in FIREs - #[clap(short, long, hide = true)] + #[clap(long, hide = true)] pub use_5mc: bool, /// Minium length of msp to call a FIRE #[clap(long, default_value = "85", env)] @@ -78,9 +69,6 @@ impl Default for FireOptions { extract: false, all: false, feats_to_text: false, - skip_no_m6a: false, - min_msp: 0, - min_ave_msp_size: 0, width_bin: 40, bin_num: 9, best_window_size: 100, diff --git a/src/cli/mock_fire_opts.rs b/src/cli/mock_fire_opts.rs new file mode 100644 index 00000000..a0cae9b2 --- /dev/null +++ b/src/cli/mock_fire_opts.rs @@ -0,0 +1,25 @@ +use crate::cli::GlobalOpts; +use clap::Args; +use std::fmt::Debug; + +#[derive(Args, Debug)] +pub struct MockFireOptions { + /// Input BED file where intervals become FIRE elements. + /// The 4th column (name) groups intervals into the same mock read. 
+ #[clap()] + pub bed: String, + /// Output BAM file + #[clap(short, long, default_value = "-")] + pub out: String, + /// Length of mock reads (default: auto-calculated from BED intervals) + #[clap(short, long)] + pub read_length: Option, + /// FIRE quality score to assign (0-255, higher = more confident FIRE call) + #[clap(short, long, default_value_t = 255)] + pub quality: u8, + /// Uncompressed BAM output (default: compressed) + #[clap(short, long)] + pub uncompressed: bool, + #[clap(flatten)] + pub global: GlobalOpts, +} diff --git a/src/cli/pileup_opts.rs b/src/cli/pileup_opts.rs index 7a6e1174..89dcbdbb 100644 --- a/src/cli/pileup_opts.rs +++ b/src/cli/pileup_opts.rs @@ -6,10 +6,15 @@ use std::fmt::Debug; pub struct PileupOptions { #[clap(flatten)] pub input: InputBam, - /// Region string to make a pileup of. e.g. chr1:1-1000 or chr1:1-1,000 + /// Region string(s) to make a pileup of. e.g. chr1:1-1000 or chr1:1-1,000 + /// Can be specified multiple times for multiple regions. /// If not provided will make a pileup of the whole genome - #[clap(default_value = None)] - pub rgn: Option, + #[clap(short, long)] + pub rgn: Vec, + /// BED file with regions to query. If the BED file has a name column (4th column), + /// the name will be added to each output line for that region. + #[clap(short, long, conflicts_with = "rgn")] + pub bed: Option, /// Output file #[clap(short, long, default_value = "-")] pub out: String, @@ -49,3 +54,12 @@ pub struct PileupOptions { #[clap(long)] pub no_nuc: bool, } + +impl PileupOptions { + /// `--fire-filter` bundles `--fiber-coverage` in addition to the three + /// filters. Callers should use this in place of reading `fiber_coverage` + /// directly so the bundle stays coherent. 
+ pub fn effective_fiber_coverage(&self) -> bool { + self.fiber_coverage || self.input.filters.fire_filter + } +} diff --git a/src/fiber.rs b/src/fiber.rs index 68b4cbf6..b7d3d7dd 100644 --- a/src/fiber.rs +++ b/src/fiber.rs @@ -498,19 +498,22 @@ impl FiberseqData { } } -pub struct FiberseqRecords<'a> { - bam_chunk: BamChunk<'a>, +pub struct FiberseqRecords<'a, R = bam::Reader> +where + R: bam::Read, +{ + bam_chunk: BamChunk<'a, R>, header: HeaderView, filters: FiberFilters, cur_chunk: Vec, } -impl<'a> FiberseqRecords<'a> { +impl<'a> FiberseqRecords<'a, bam::Reader> { pub fn new(bam: &'a mut bam::Reader, filters: FiberFilters) -> Self { let header = bam.header().clone(); let bam_recs = bam.records(); let mut bam_chunk = BamChunk::new(bam_recs, None); - bam_chunk.set_bit_flag_filter(filters.bit_flag); + bam_chunk.set_bit_flag_filter(filters.get_bit_flag()); let cur_chunk: Vec = vec![]; FiberseqRecords { bam_chunk, @@ -521,21 +524,48 @@ impl<'a> FiberseqRecords<'a> { } } -impl Iterator for FiberseqRecords<'_> { +impl<'a> FiberseqRecords<'a, bam::IndexedReader> { + pub fn from_rec_iterator( + bam_recs: bam::Records<'a, bam::IndexedReader>, + header: HeaderView, + filters: FiberFilters, + ) -> Self { + let mut bam_chunk = BamChunk::new(bam_recs, None); + bam_chunk.set_bit_flag_filter(filters.get_bit_flag()); + let cur_chunk: Vec = vec![]; + FiberseqRecords { + bam_chunk, + header, + filters, + cur_chunk, + } + } +} + +impl Iterator for FiberseqRecords<'_, R> +where + R: bam::Read, +{ type Item = FiberseqData; fn next(&mut self) -> Option { - // if we are out of data check for another chunk in the bam - if self.cur_chunk.is_empty() { - match self.bam_chunk.next() { - Some(recs) => { - self.cur_chunk = FiberseqData::from_records(recs, &self.header, &self.filters); - // we will be popping from this list so we want to remove the first element first, not the last - self.cur_chunk.reverse(); + loop { + // if we are out of data check for another chunk in the bam + if 
self.cur_chunk.is_empty() { + match self.bam_chunk.next() { + Some(recs) => { + self.cur_chunk = + FiberseqData::from_records(recs, &self.header, &self.filters); + // we will be popping from this list so we want to remove the first element first, not the last + self.cur_chunk.reverse(); + } + None => return None, } - None => return None, + } + let rec = self.cur_chunk.pop()?; + if self.filters.passes_fire_filter(&rec) { + return Some(rec); } } - self.cur_chunk.pop() } } diff --git a/src/lib.rs b/src/lib.rs index c2ba4f86..d1ef29f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,9 +21,9 @@ use std::io::Write; pub const VERSION: &str = env!("CARGO_PKG_VERSION"); lazy_static! { pub static ref FULL_VERSION: String = format!( - "v{}\tgit-details {}", + "v{}\tgit-commit {}", env!("CARGO_PKG_VERSION"), - env!("VERGEN_GIT_DESCRIBE") + &env!("VERGEN_GIT_SHA")[..7.min(env!("VERGEN_GIT_SHA").len())] ); } // if this string (bar)gets too long it displays weird when writing to stdout diff --git a/src/main.rs b/src/main.rs index 11db9d43..55bd25cc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -132,6 +132,15 @@ pub fn main() -> Result<(), Error> { Some(Commands::PgPansn(pg_pansn_opts)) => { subcommands::pg_pansn::run_pg_pansn(pg_pansn_opts)?; } + Some(Commands::CallPeaks(ref mut call_peaks_opts)) => { + subcommands::call_peaks::run_call_peaks(call_peaks_opts)?; + } + Some(Commands::MockFire(mock_fire_opts)) => { + subcommands::mock_fire::run_mock_fire(mock_fire_opts)?; + } + Some(Commands::Benchmark(benchmark_opts)) => { + subcommands::benchmark::run_benchmark(benchmark_opts)?; + } None => {} }; let duration = pg_start.elapsed(); diff --git a/src/subcommands.rs b/src/subcommands.rs index 1c1a6eaa..b2df9a8d 100644 --- a/src/subcommands.rs +++ b/src/subcommands.rs @@ -1,5 +1,9 @@ /// Add nucleosomes to a bam file pub mod add_nucleosomes; +/// Benchmark fiber iterator performance +pub mod benchmark; +/// Call FIRE peaks using FDR-based peak calling +pub mod call_peaks; /// Center 
fiberseq information around a reference position pub mod center; /// Clear HiFi kinetics tags from a bam file @@ -12,6 +16,8 @@ pub mod extract; /// add fire data pub mod fire; pub mod footprint; +/// create mock BAM with FIRE elements +pub mod mock_fire; /// make a fire track from a bam file pub mod pileup; /// m6A prediction diff --git a/src/subcommands/benchmark.rs b/src/subcommands/benchmark.rs new file mode 100644 index 00000000..56712ab3 --- /dev/null +++ b/src/subcommands/benchmark.rs @@ -0,0 +1,33 @@ +use crate::cli::BenchmarkOptions; +use anyhow::Result; +use std::time::Instant; + +pub fn run_benchmark(opts: &mut BenchmarkOptions) -> Result<()> { + log::info!("Starting fiber iterator benchmark"); + + let mut bam = opts.input.bam_reader(); + + // Benchmark input.fibers() + let start_time = Instant::now(); + let mut fiber_count = 0; + + for fiber in opts.input.fibers(&mut bam) { + fiber_count += 1; + // Force materialization but do nothing with the fiber + std::hint::black_box(&fiber); + } + + let elapsed = start_time.elapsed(); + + log::info!( + "Processed {} fibers in {:.3} seconds", + fiber_count, + elapsed.as_secs_f64() + ); + log::info!( + "Rate: {:.0} fibers/second", + fiber_count as f64 / elapsed.as_secs_f64() + ); + + Ok(()) +} diff --git a/src/subcommands/call_peaks/fdr.rs b/src/subcommands/call_peaks/fdr.rs new file mode 100644 index 00000000..ba723c6a --- /dev/null +++ b/src/subcommands/call_peaks/fdr.rs @@ -0,0 +1,449 @@ +use anyhow::{Context, Result}; +use std::collections::HashMap; +use std::io::Write; + +use crate::cli::CallPeaksOptions; + +/// FDR table entry mapping FIRE scores to FDR values +#[derive(Debug, Clone)] +pub struct FdrEntry { + pub threshold: f64, + pub fdr: f64, + pub shuffled_bp: f64, + pub real_bp: f64, +} + +/// Look up the FDR value for a given score in an ascending-by-threshold FDR table. +/// Returns 1.0 if the table is empty (FIRE fraction mode). 
+/// For scores below all thresholds, returns the lowest-threshold entry's FDR. +/// Otherwise returns the FDR of the largest threshold <= score. +pub fn lookup_fdr(score: f32, fdr_table: &[FdrEntry]) -> f64 { + if fdr_table.is_empty() { + return 1.0; + } + + let idx = fdr_table.binary_search_by(|entry| { + entry + .threshold + .partial_cmp(&(score as f64)) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + match idx { + Ok(i) => fdr_table[i].fdr, + Err(i) => { + if i == 0 { + fdr_table[0].fdr + } else { + fdr_table[i - 1].fdr + } + } + } +} + +/// Pileup record structure (simplified for FDR calculation) +#[derive(Debug, Clone)] +pub struct PileupRecord { + pub start: u64, + pub end: u64, + pub coverage: u32, + pub fire_coverage: u32, + pub score: f64, +} + +/// Calculate FDR from aggregated FIRE scores +/// This follows the Python logic in fdr_from_fire_scores() +fn fdr_from_fire_scores( + fire_scores: &[(f64, bool, u64)], +) -> (Vec, Vec, Vec, Vec) { + let mut vs = Vec::new(); // shuffled bp + let mut rs = Vec::new(); // real bp + let mut ts = Vec::new(); // thresholds + let mut cur_r = 0.0; + let mut cur_v = 0.0; + let mut pre_score = -1.0; + let mut first = true; + + let mut processed_count = 0; + let mut skipped_negative = 0; + + for &(score, is_real, bp) in fire_scores { + // don't add negative scores to the fdr data, since they have no coverage + if score < 0.0 { + skipped_negative += 1; + continue; + } + + // save the counts and thresholds as long as we have counts + if score != pre_score && cur_r > 0.0 && !first { + rs.push(cur_r); + vs.push(cur_v); + ts.push(pre_score); + } + + processed_count += 1; + + // update the counts + if is_real { + cur_r += bp as f64; + } else { + cur_v += bp as f64; + } + + // prepare for next iteration + pre_score = score; + first = false; + } + + log::debug!( + "fdr_from_fire_scores: processed={}, skipped_negative={}, cur_r={:.0}, cur_v={:.0}", + processed_count, + skipped_negative, + cur_r, + cur_v + ); + + // add the last 
threshold with an FDR of 1 + rs.push(1.0); + vs.push(1.0); + ts.push(-1.0); + + // calculate FDRs + let fdrs: Vec = vs + .iter() + .zip(rs.iter()) + .map(|(&v, &r)| { + let fdr = v / r; + if fdr > 1.0 { + 1.0 + } else { + fdr + } + }) + .collect(); + + (ts, fdrs, vs, rs) +} + +/// Create FDR table from fire scores +/// This follows the Python logic in fdr_table_from_scores() +fn fdr_table_from_scores(fire_scores: &[(f64, bool, u64)]) -> Vec { + let (thresholds, fdrs, shuffled_bps, real_bps) = fdr_from_fire_scores(fire_scores); + + let mut entries: Vec = thresholds + .iter() + .zip(fdrs.iter()) + .zip(shuffled_bps.iter()) + .zip(real_bps.iter()) + .map(|(((&threshold, &fdr), &shuffled_bp), &real_bp)| FdrEntry { + threshold, + fdr, + shuffled_bp, + real_bp, + }) + .collect(); + + // simplify the results - group by FDR and keep last + entries.sort_by(|a, b| a.fdr.partial_cmp(&b.fdr).unwrap()); + deduplicate_by_key(&mut entries, |e| (e.fdr * 1000000.0) as i64); + + // group by shuffled_bp and keep last + entries.sort_by(|a, b| a.shuffled_bp.partial_cmp(&b.shuffled_bp).unwrap()); + deduplicate_by_key(&mut entries, |e| e.shuffled_bp as i64); + + // group by real_bp and keep last + entries.sort_by(|a, b| a.real_bp.partial_cmp(&b.real_bp).unwrap()); + deduplicate_by_key(&mut entries, |e| e.real_bp as i64); + + // round thresholds to 2 decimal places + for entry in &mut entries { + entry.threshold = (entry.threshold * 100.0).round() / 100.0; + } + + // group by threshold and keep last + entries.sort_by(|a, b| a.threshold.partial_cmp(&b.threshold).unwrap()); + deduplicate_by_key(&mut entries, |e| (e.threshold * 100.0) as i64); + + // sort by threshold ascending (needed for binary search later) + entries.sort_by(|a, b| a.threshold.partial_cmp(&b.threshold).unwrap()); + + log::info!("FDR table has {} entries", entries.len()); + if !entries.is_empty() { + log::debug!( + "First FDR entry: threshold={:.2}, FDR={:.4}", + entries[0].threshold, + entries[0].fdr + ); + log::debug!( + 
"Last FDR entry: threshold={:.2}, FDR={:.4}", + entries.last().unwrap().threshold, + entries.last().unwrap().fdr + ); + } + + entries +} + +/// Helper function to deduplicate entries by key, keeping the last occurrence +fn deduplicate_by_key(entries: &mut Vec, key_fn: impl Fn(&T) -> K) +where + T: Clone, +{ + let mut seen = HashMap::new(); + for (idx, entry) in entries.iter().enumerate() { + seen.insert(key_fn(entry), idx); + } + let mut keep_indices: Vec<_> = seen.values().copied().collect(); + keep_indices.sort_unstable(); + + let mut result = Vec::with_capacity(keep_indices.len()); + for &idx in &keep_indices { + result.push(entries[idx].clone()); + } + *entries = result; +} + +/// Incremental FDR builder that aggregates scores without keeping full pileup data +pub struct IncrementalFdrBuilder { + real_scores: HashMap, u64>, + shuffled_scores: HashMap, u64>, + real_record_count: usize, + shuffled_record_count: usize, +} + +impl Default for IncrementalFdrBuilder { + fn default() -> Self { + Self::new() + } +} + +impl IncrementalFdrBuilder { + /// Create a new incremental FDR builder + pub fn new() -> Self { + Self { + real_scores: HashMap::new(), + shuffled_scores: HashMap::new(), + real_record_count: 0, + shuffled_record_count: 0, + } + } + + /// Add pileup data from one chromosome + pub fn add_chromosome_data( + &mut self, + real_pileup: &[PileupRecord], + shuffled_pileup: &[PileupRecord], + ) { + use ordered_float::NotNan; + + // Update counts for logging + self.real_record_count += real_pileup.len(); + self.shuffled_record_count += shuffled_pileup.len(); + + // Aggregate real scores + for record in real_pileup { + let bp = record.end - record.start; + if let Ok(score_key) = NotNan::new(record.score) { + *self.real_scores.entry(score_key).or_insert(0) += bp; + } + } + + // Aggregate shuffled scores + for record in shuffled_pileup { + let bp = record.end - record.start; + if let Ok(score_key) = NotNan::new(record.score) { + 
*self.shuffled_scores.entry(score_key).or_insert(0) += bp; + } + } + } + + /// Finalize and build the FDR table + pub fn build(self, max_fdr: f64) -> Result> { + log::info!("Generating FDR table from accumulated score data"); + log::debug!("Real pileup: {} total records", self.real_record_count); + log::debug!( + "Shuffled pileup: {} total records", + self.shuffled_record_count + ); + + // Combine and sort by score descending + let mut fire_scores: Vec<(f64, bool, u64)> = Vec::new(); + for (score_notnan, bp) in self.real_scores { + fire_scores.push((score_notnan.into_inner(), true, bp)); + } + for (score_notnan, bp) in self.shuffled_scores { + fire_scores.push((score_notnan.into_inner(), false, bp)); + } + fire_scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap()); // descending by score + + // Calculate sums for logging + let real_mbp: f64 = fire_scores + .iter() + .filter(|(_, is_real, _)| *is_real) + .map(|(_, _, bp)| *bp as f64) + .sum::() + / 1_000_000.0; + let shuffled_mbp: f64 = fire_scores + .iter() + .filter(|(_, is_real, _)| !*is_real) + .map(|(_, _, bp)| *bp as f64) + .sum::() + / 1_000_000.0; + log::debug!("Real data: {:.2} Mbp", real_mbp); + log::debug!("Shuffled data: {:.2} Mbp", shuffled_mbp); + + // Debug: Count how many entries have negative scores + let neg_score_count = fire_scores + .iter() + .filter(|(score, _, _)| *score < 0.0) + .count(); + let neg_score_real_bp: u64 = fire_scores + .iter() + .filter(|(score, is_real, _)| *score < 0.0 && *is_real) + .map(|(_, _, bp)| *bp) + .sum(); + let neg_score_shuffled_bp: u64 = fire_scores + .iter() + .filter(|(score, is_real, _)| *score < 0.0 && !*is_real) + .map(|(_, _, bp)| *bp) + .sum(); + log::debug!( + "Negative scores: {} entries, real_bp={}, shuffled_bp={}", + neg_score_count, + neg_score_real_bp, + neg_score_shuffled_bp + ); + + // Create FDR table + let fdr_table = fdr_table_from_scores(&fire_scores); + + // Check if we have any thresholds below max_fdr + if let Some(min_fdr_entry) = fdr_table 
+ .iter() + .min_by(|a, b| a.fdr.partial_cmp(&b.fdr).unwrap()) + { + if min_fdr_entry.fdr > max_fdr { + anyhow::bail!( + "No FIRE score threshold has an FDR < {}. Check the input Fiber-seq data with the QC pipeline and make sure you are using WGS Fiber-seq data.", + max_fdr + ); + } + } + + Ok(fdr_table) + } +} + +/// Generate FDR table incrementally, processing one chromosome at a time +/// This avoids keeping all pileup records in memory at once +pub fn fdr_table( + opts: &mut CallPeaksOptions, + bam: &mut rust_htslib::bam::IndexedReader, + header: &rust_htslib::bam::HeaderView, +) -> Result> { + use super::{chrom_names_and_lengths, chromosome_has_fibers, process_chromosome_pileup_both}; + + let mut fdr_builder = IncrementalFdrBuilder::new(); + + // Process each chromosome and add to the builder + for (chrom_str, chrom_len) in chrom_names_and_lengths(header)? { + // Skip chromosomes with no fibers + if !chromosome_has_fibers(&chrom_str, bam, opts)? { + log::debug!("Skipping chromosome {} (no fibers)", chrom_str); + continue; + } + + // Process the fibers to generate pileup records (streaming - no all_fibers Vec!) 
+ let (real_chrom, shuffled_chrom) = + process_chromosome_pileup_both(&chrom_str, chrom_len, bam, opts)?; + + // Add to builder (this aggregates scores and drops the full records) + fdr_builder.add_chromosome_data(&real_chrom, &shuffled_chrom); + + // Summary info statement per chromosome + log::info!( + "FDR: {} ({} Mbp) - real: {} records, shuffled: {} records", + chrom_str, + chrom_len / 1_000_000, + real_chrom.len(), + shuffled_chrom.len(), + ); + } + + // Build the final FDR table + fdr_builder.build(opts.max_fdr) +} + +/// Write FDR table to TSV file +pub fn write_fdr_table(fdr_table: &[FdrEntry], path: &str) -> Result<()> { + let mut writer = + crate::utils::bio_io::writer(path).context("Failed to create FDR table output file")?; + + writeln!(writer, "threshold\tFDR\tshuffled_bp\treal_bp")?; + for entry in fdr_table { + writeln!( + writer, + "{:.2}\t{:.6}\t{:.0}\t{:.0}", + entry.threshold, entry.fdr, entry.shuffled_bp, entry.real_bp + )?; + } + + Ok(()) +} + +/// Read FDR table from TSV file +pub fn read_fdr_table(path: &str) -> Result> { + use std::fs::File; + use std::io::{BufRead, BufReader}; + + let file = File::open(path).context("Failed to open FDR table file")?; + let reader = BufReader::new(file); + let mut entries = Vec::new(); + + for (line_num, line) in reader.lines().enumerate() { + let line = line.context("Failed to read line from FDR table")?; + + // Skip header line + if line_num == 0 { + continue; + } + + let parts: Vec<&str> = line.split('\t').collect(); + if parts.len() != 4 { + anyhow::bail!( + "Invalid FDR table format at line {}: expected 4 columns, found {}", + line_num + 1, + parts.len() + ); + } + + let threshold = parts[0].parse::().context(format!( + "Failed to parse threshold at line {}", + line_num + 1 + ))?; + let fdr = parts[1] + .parse::() + .context(format!("Failed to parse FDR at line {}", line_num + 1))?; + let shuffled_bp = parts[2].parse::().context(format!( + "Failed to parse shuffled_bp at line {}", + line_num + 1 + ))?; 
+ let real_bp = parts[3] + .parse::() + .context(format!("Failed to parse real_bp at line {}", line_num + 1))?; + + entries.push(FdrEntry { + threshold, + fdr, + shuffled_bp, + real_bp, + }); + } + + log::info!("Loaded {} FDR table entries from {}", entries.len(), path); + + // Sort by threshold (should already be sorted, but ensure it) + entries.sort_by(|a, b| a.threshold.partial_cmp(&b.threshold).unwrap()); + + Ok(entries) +} diff --git a/src/subcommands/call_peaks/mod.rs b/src/subcommands/call_peaks/mod.rs new file mode 100644 index 00000000..392b9811 --- /dev/null +++ b/src/subcommands/call_peaks/mod.rs @@ -0,0 +1,290 @@ +mod fdr; +mod peaks; + +pub use fdr::{ + fdr_table, lookup_fdr, read_fdr_table, write_fdr_table, FdrEntry, IncrementalFdrBuilder, + PileupRecord, +}; +pub use peaks::{call_peaks, reciprocal_overlap_raw}; + +use crate::cli::CallPeaksOptions; +use crate::subcommands::pileup::{FireTrack, FireTrackOptions}; +use anyhow::{Context, Result}; + +pub fn run_call_peaks(opts: &mut CallPeaksOptions) -> Result<()> { + log::info!("Starting FIRE peak calling"); + log::info!(" Input BAM: {}", opts.input.bam); + log::info!(" Output: {}", opts.out); + + if let Some(min_frac) = opts.min_fire_frac { + log::info!(" Using FIRE fraction mode: min_fire_frac = {}", min_frac); + } else { + log::info!(" Max FDR: {}", opts.max_fdr); + } + log::info!(" Window size: {}", opts.window_size); + + let mut bam = opts.input.indexed_bam_reader(); + let header = opts.input.header_view(); + + // Generate or load FDR table (skip if using FIRE fraction mode) + let fdr_table = if opts.min_fire_frac.is_some() { + // FIRE fraction mode: use empty FDR table (won't be used for filtering) + log::info!(" Skipping FDR calculation (using FIRE fraction threshold)"); + Vec::new() + } else { + // FDR mode: generate or load FDR table + if let Some(ref shuffled) = opts.shuffled { + log::info!(" Shuffled positions file: {}", shuffled); + } else { + log::info!(" Using random shuffling for FDR 
calculation"); + } + + if let Some(ref fdr_table_path) = opts.fdr_table { + log::info!("Loading FDR table from: {}", fdr_table_path); + read_fdr_table(fdr_table_path)? + } else { + // Generate pileup incrementally to avoid OOM + log::info!("Running pileup for real and shuffled data (incremental mode)..."); + fdr_table(opts, &mut bam, &header)? + } + }; + + // Write FDR table if requested (and if we generated one) + if let Some(ref fdr_out) = opts.fdr_table_out { + if !fdr_table.is_empty() { + log::info!("Writing FDR table to: {}", fdr_out); + write_fdr_table(&fdr_table, fdr_out)?; + } + } + + // Call peaks using the FDR table (or FIRE fraction) + call_peaks(opts, &mut bam, &header, &fdr_table)?; + + log::info!("FIRE peak calling completed"); + Ok(()) +} + +/// Default minimum coverage threshold (matching Python's MIN_COVERAGE default) +const DEFAULT_MIN_COVERAGE: i32 = 4; + +/// Get chromosome names and lengths from BAM header +/// +/// # Returns +/// Vector of tuples (chromosome_name, chromosome_length) +pub fn chrom_names_and_lengths( + header: &rust_htslib::bam::HeaderView, +) -> Result> { + let mut chroms = Vec::new(); + for chrom in header.target_names() { + let chrom_str = String::from_utf8_lossy(chrom).to_string(); + let tid = header.tid(chrom).context("Failed to get target ID")?; + let chrom_len = header + .target_len(tid) + .context("Failed to get target length")? as i64; + chroms.push((chrom_str, chrom_len)); + } + Ok(chroms) +} + +/// Check if a chromosome has any fibers +/// +/// # Arguments +/// * `chrom` - Chromosome name +/// * `bam` - Indexed BAM reader +/// * `opts` - Call peaks options (for filtering) +/// +/// # Returns +/// True if the chromosome has at least one fiber, false otherwise +fn chromosome_has_fibers( + chrom: &str, + bam: &mut rust_htslib::bam::IndexedReader, + opts: &CallPeaksOptions, +) -> Result { + // Check if there's at least one fiber + let has_fibers = opts + .input + .fetch_fibers(bam, chrom, None, None)? 
+ .next() + .is_some(); + Ok(has_fibers) +} + +/// Process a single chromosome and return PileupRecords for both real and shuffled +/// Returns (real_records, shuffled_records) processed at the same positions +/// +/// # Arguments +/// * `chrom` - Chromosome name +/// * `chrom_len` - Chromosome length +/// * `bam` - Indexed BAM reader +/// * `opts` - Call peaks options +fn process_chromosome_pileup_both( + chrom: &str, + chrom_len: i64, + bam: &mut rust_htslib::bam::IndexedReader, + opts: &CallPeaksOptions, +) -> Result<(Vec, Vec)> { + log::debug!("Processing chromosome {} (length: {})", chrom, chrom_len); + + // Create fire track options - to calculate FIRE scores + // Note: Python uses --no-msp --no-nuc flags in shuffled_pileup_chromosome rule + // This means FIRE scores are calculated using only fiber_coverage, not MSP/NUC + let real_opts = FireTrackOptions { + no_nuc: true, // Match Python: --no-nuc + no_msp: true, // Match Python: --no-msp + m6a: false, + cpg: false, + fiber_coverage: true, + shuffle: false, + random_shuffle: false, + shuffle_seed: None, + rolling_max: None, + track_fire_elements: false, // No tracking needed for FDR calculation + }; + let mut real_track = FireTrack::new(chrom.to_string(), 0, chrom_len as usize, real_opts, &None); + + // Process all fibers and update real track (first pass - streaming) + // Note: FireTrack will automatically store fiber info in fibers_seen + log::debug!(" First pass: building real track and collecting fiber positions"); + for fiber in opts.input.fetch_fibers(bam, chrom, None, None)? 
{ + real_track.update_with_fiber(&fiber); + } + + // Calculate scores for real fibers + real_track.calculate_scores(Some(-1)); + + // Calculate coverage thresholds based on median and standard deviation + let (median, std_dev, pos_cov) = real_track.median_and_std_coverage(); + + // Apply sd_cov thresholds if max_cov/min_cov are not explicitly set + // Match Python behavior: minimum coverage defaults to 4 + let min_cov_threshold = opts.min_cov.unwrap_or_else(|| { + let calculated_min = (median - opts.sd_cov * std_dev).round() as i32; + calculated_min.max(DEFAULT_MIN_COVERAGE) + }); + let max_cov_threshold = opts + .max_cov + .unwrap_or_else(|| (median + opts.sd_cov * std_dev).round() as i32); + + log::debug!( + " Coverage: median={:.1}, std_dev={:.1} ({:.1} SDs), range=[{}, {}]", + median, + std_dev, + opts.sd_cov, + min_cov_threshold, + max_cov_threshold + ); + log::debug!(" Real: pos_with_cov={}", pos_cov); + + // Generate shuffled positions if not using a file-based shuffle + // Pass coverage thresholds to avoid placing shuffled fibers in extreme coverage regions + let generated_shuffle = Some(real_track.generate_shuffled_positions( + Some(42), + Some(min_cov_threshold), + Some(max_cov_threshold), + )); + + // Create shuffled track with shuffle enabled + // Note: Must match real_opts to ensure consistent score calculation + let shuffled_opts = FireTrackOptions { + no_nuc: true, // Match Python: --no-nuc + no_msp: true, // Match Python: --no-msp + m6a: false, + cpg: false, + fiber_coverage: true, + shuffle: true, + random_shuffle: false, // We have explicit shuffle positions now + shuffle_seed: None, + rolling_max: None, + track_fire_elements: false, // No tracking needed for FDR calculation + }; + let mut shuffled_track = FireTrack::new( + chrom.to_string(), + 0, + chrom_len as usize, + shuffled_opts, + &generated_shuffle, + ); + + // Process the same fibers for shuffled track (second pass - streaming) + log::debug!(" Second pass: building shuffled track"); + for 
fiber in opts.input.fetch_fibers(bam, chrom, None, None)? { + shuffled_track.update_with_fiber(&fiber); + } + + // Calculate scores for shuffled track + shuffled_track.calculate_scores(Some(-1)); + + // Log median coverage statistics for shuffled + let (shuffled_median, shuffled_pos_cov) = shuffled_track.median_coverage(); + log::debug!( + " Shuffled: median_cov={:.1}, pos_with_cov={}", + shuffled_median, + shuffled_pos_cov + ); + + // Extract pileup records at positions where EITHER real or shuffled has coverage + // This ensures both tracks use the same position set + let mut real_records = Vec::new(); + let mut shuffled_records = Vec::new(); + + for i in 0..real_track.track_len { + let real_cov = real_track.coverage[i]; + let shuffled_cov = shuffled_track.coverage[i]; + let real_score = real_track.scores[i]; + let shuffled_score = shuffled_track.scores[i]; + + // Skip positions with: + // - No coverage (cov <= 0) + // - Invalid scores (score < 0), which indicates low FIRE coverage + // - Coverage outside min/max thresholds (matching Python filtering) + let skip_real = real_cov <= 0 + || real_score < 0.0 + || real_cov < min_cov_threshold + || real_cov > max_cov_threshold; + let skip_shuffled = shuffled_cov <= 0 + || shuffled_score < 0.0 + || shuffled_cov < min_cov_threshold + || shuffled_cov > max_cov_threshold; + + // Add real record at positions with valid coverage and score + if !skip_real { + real_records.push(PileupRecord { + start: i as u64, + end: (i + 1) as u64, + coverage: real_cov as u32, + fire_coverage: real_track.fire_coverage[i] as u32, + score: real_score as f64, + }); + } + + // Add shuffled record at positions with valid coverage and score + if !skip_shuffled { + shuffled_records.push(PileupRecord { + start: i as u64, + end: (i + 1) as u64, + coverage: shuffled_cov as u32, + fire_coverage: shuffled_track.fire_coverage[i] as u32, + score: shuffled_score as f64, + }); + } + } + + // log the lowest and highest scores in real and shuffled tracks + 
for (label, track) in &[("real", &real_records), ("shuffle", &shuffled_records)] { + let n_neg_one = track.iter().filter(|x| x.score < 0.0).count(); + let min = track + .iter() + .filter(|x| x.score > 0.0) + .fold(f64::INFINITY, |a, b| a.min(b.score)); + let max = track.iter().fold(f64::NEG_INFINITY, |a, b| a.max(b.score)); + log::debug!( + " Track {}: score range = [{:.3}, {:.3}], n_scores<0={}", + label, + min, + max, + n_neg_one + ); + } + + Ok((real_records, shuffled_records)) +} diff --git a/src/subcommands/call_peaks/peaks.rs b/src/subcommands/call_peaks/peaks.rs new file mode 100644 index 00000000..574bff60 --- /dev/null +++ b/src/subcommands/call_peaks/peaks.rs @@ -0,0 +1,720 @@ +use super::chrom_names_and_lengths; +use super::fdr::{lookup_fdr, FdrEntry}; +use crate::cli::CallPeaksOptions; +use crate::subcommands::pileup::{ + FiberseqPileup, FiberseqPileupOptions, FireTrack, FireTrackOptions, +}; +use crate::utils::bio_io; +use anyhow::Result; +use std::collections::{HashMap, HashSet}; +use std::io::Write; + +/// Calculate reciprocal overlap between two genomic intervals. +/// Returns 0.0 if the intervals are on different chromosomes or do not overlap. +/// Otherwise returns `min(overlap_len / a_len, overlap_len / b_len)`. 
+pub fn reciprocal_overlap_raw( + a_chrom: &str, + a_start: usize, + a_end: usize, + b_chrom: &str, + b_start: usize, + b_end: usize, +) -> f64 { + if a_chrom != b_chrom { + return 0.0; + } + let overlap_start = a_start.max(b_start); + let overlap_end = a_end.min(b_end); + if overlap_start >= overlap_end { + return 0.0; + } + let overlap_len = (overlap_end - overlap_start) as f64; + let a_len = (a_end - a_start) as f64; + let b_len = (b_end - b_start) as f64; + (overlap_len / a_len).min(overlap_len / b_len) +} + +/// Filtering thresholds for peak calling +#[derive(Debug, Clone, Copy)] +struct PeakThresholds { + max_fdr: f64, + min_fire_frac: Option, + min_fire_frac_filter: f64, + min_cov: i32, + max_cov: i32, +} + +/// A peak representing a local maximum in FIRE scores +#[derive(Debug)] +pub struct Peak<'a> { + pub chrom: String, + /// Start position of the local maximum region (0-based, inclusive) + /// This is the median start position of underlying FIRE elements + pub start: usize, + /// End position of the local maximum region (0-based, exclusive) + /// This is the median end position of underlying FIRE elements + pub end: usize, + /// FIRE score at this position + pub score: f32, + /// FDR value at this position + pub fdr: f64, + /// Whether this peak passes coverage filters (coverage within normal range) + pub pass_coverage: bool, + /// Index in the pileup track where the peak was called (for retrieving FIRE elements) + pub peak_index: usize, + /// Reference to the pileup track containing FIRE elements + pub pileup: &'a FiberseqPileup<'a>, +} + +impl<'a> Peak<'a> { + pub fn header() -> String { + let mut header = String::from("#chrom\tpeak_start\tpeak_end\tpeak_max\tFDR"); + for suffix in &["", "_H1", "_H2"] { + header.push_str(&format!( + "\tcoverage{suffix}\tfire_coverage{suffix}\tscore{suffix}\tnuc_coverage{suffix}\tmsp_coverage{suffix}" + )); + } + header.push_str("\tpass_coverage"); + header + } + + /// Format fire track data for output (coverage, 
fire_coverage, score, nuc_coverage, msp_coverage) + fn format_fire_track(&self, track: &FireTrack) -> String { + format!( + "{}\t{}\t{:.5}\t{}\t{}\t", + track.coverage[self.peak_index], + track.fire_coverage[self.peak_index], + track.scores[self.peak_index], + track.nuc_coverage[self.peak_index], + track.msp_coverage[self.peak_index], + ) + } +} + +impl<'a> std::fmt::Display for Peak<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Local max region boundaries + let local_max = self.pileup.chrom_start + self.peak_index; + + // Start with basic peak info + let mut output = format!( + "{}\t{}\t{}\t{}\t{:.10}\t", + self.chrom, // #chrom + self.start, // peak_start (merged peak start) + self.end, // peak_end (merged peak end) + local_max, // start (local max start) + self.fdr, // FDR + ); + + // Add all_data fire track + output.push_str(&self.format_fire_track(&self.pileup.all_data)); + + for t in [&self.pileup.hap1_data, &self.pileup.hap2_data] { + if let Some(ref track) = t { + output.push_str(&self.format_fire_track(track)); + } else { + output.push_str("0\t0\t-1.0\t0\t0\t"); + } + } + + // Add pass_coverage column + let pass_cov_str = if self.pass_coverage { "true" } else { "false" }; + output.push_str(pass_cov_str); + + write!(f, "{}", output) + } +} + +impl<'a> Peak<'a> { + /// Find local maxima from a pileup + /// For consecutive positions with identical rolling max scores, save the entire region + /// Filtering can be done by FDR threshold or by minimum FIRE fraction + pub fn from_pileup( + pileup: &'a FiberseqPileup<'a>, + fdr_table: &[FdrEntry], + max_fdr: f64, + min_fire_frac: Option, + min_fire_frac_filter: f64, + min_cov: i32, + max_cov: i32, + ) -> Vec { + let mut peaks = Vec::new(); + + let scores = &pileup.all_data.scores; + let rolling_max_scores = pileup + .rolling_max + .as_ref() + .expect("Rolling max scores should be calculated"); + + // Create thresholds struct + let thresholds = PeakThresholds { + max_fdr, + 
min_fire_frac, + min_fire_frac_filter, + min_cov, + max_cov, + }; + + let mut consecutive_maxima = Vec::new(); + + for i in 0..scores.len() { + // Skip positions with no score or negative scores + if scores[i] < 0.0 { + // If we have consecutive maxima, create a peak for the region + if !consecutive_maxima.is_empty() { + if let Some(peak) = Self::create_peak_if_significant( + pileup, + consecutive_maxima.as_slice(), + fdr_table, + &thresholds, + ) { + peaks.push(peak); + } + consecutive_maxima.clear(); + } + continue; + } + + // Check if this position is a local maximum + // (score equals the rolling max at this position) + if (scores[i] - rolling_max_scores[i]).abs() < 1e-6 { + consecutive_maxima.push(i); + } else { + // If we have consecutive maxima, create a peak for the region + if !consecutive_maxima.is_empty() { + if let Some(peak) = Self::create_peak_if_significant( + pileup, + consecutive_maxima.as_slice(), + fdr_table, + &thresholds, + ) { + peaks.push(peak); + } + consecutive_maxima.clear(); + } + } + } + + // Handle any remaining consecutive maxima at the end + if !consecutive_maxima.is_empty() { + if let Some(peak) = Self::create_peak_if_significant( + pileup, + consecutive_maxima.as_slice(), + fdr_table, + &thresholds, + ) { + peaks.push(peak); + } + } + + peaks + } + + /// Create a peak from a region of consecutive maxima if it meets filtering criteria + /// Filtering can be by FDR threshold or by minimum FIRE fraction + /// Returns None if the peak doesn't meet the threshold + /// Peak boundaries are determined by the median start/end of underlying FIRE elements + fn create_peak_if_significant( + pileup: &'a FiberseqPileup<'a>, + positions: &[usize], + fdr_table: &[FdrEntry], + thresholds: &PeakThresholds, + ) -> Option { + if positions.is_empty() { + return None; + } + + // Use the middle position to get the score and calculate threshold metrics + let middle_idx = positions.len() / 2; + let middle_pos = positions[middle_idx]; + let score = 
pileup.all_data.scores[middle_pos]; + + // Calculate coverage metrics + let coverage = pileup.all_data.coverage[middle_pos] as f64; + let fire_cov = pileup.all_data.fire_coverage[middle_pos] as f64; + let fire_frac = fire_cov / coverage; + + // Determine if peak passes threshold based on filtering mode + let passes_threshold = if let Some(min_frac) = thresholds.min_fire_frac { + // FIRE fraction mode: check if fraction of fibers with FIREs >= threshold + // (skips FDR calculation entirely) + fire_frac >= min_frac + } else { + // FDR mode: check if FDR <= threshold AND fire_frac >= min_fire_frac_filter + let fdr = lookup_fdr(score, fdr_table); + fdr <= thresholds.max_fdr && fire_frac >= thresholds.min_fire_frac_filter + }; + + // Calculate FDR for display purposes (even in FIRE fraction mode) + let fdr = lookup_fdr(score, fdr_table); + + // Check if coverage passes thresholds + let coverage = pileup.all_data.coverage[middle_pos]; + let pass_coverage = coverage >= thresholds.min_cov && coverage <= thresholds.max_cov; + + // Only keep peaks that pass the threshold + if passes_threshold { + // Collect all FIRE elements from the local max region + let (start, end) = if let Some(ref fire_elements_vec) = pileup.all_data.fire_elements { + let mut starts = Vec::new(); + let mut ends = Vec::new(); + + // Collect FIRE elements from all positions in the local max region + for &pos in positions { + for fire_elem in &fire_elements_vec[pos] { + starts.push(fire_elem.start); + ends.push(fire_elem.end); + } + } + + if starts.is_empty() { + // Fallback: no FIRE elements found, use the local max region boundaries + let start = pileup.chrom_start + positions[0]; + let end = pileup.chrom_start + positions[positions.len() - 1] + 1; + (start, end) + } else { + // Calculate median start and end from FIRE elements + starts.sort_unstable(); + ends.sort_unstable(); + let median_start = starts[starts.len() / 2] as usize; + let median_end = ends[ends.len() / 2] as usize; + (median_start, 
median_end) + } + } else { + // Fallback: FIRE element tracking not enabled + let start = pileup.chrom_start + positions[0]; + let end = pileup.chrom_start + positions[positions.len() - 1] + 1; + (start, end) + }; + + Some(Self { + chrom: pileup.chrom.clone(), + start, + end, + score, + fdr, + pass_coverage, + peak_index: middle_pos, + pileup, + }) + } else { + None + } + } + + /// Get the set of FIRE element IDs at the peak index position + pub fn get_fire_ids(&self) -> HashSet { + let mut fire_ids = HashSet::new(); + + if let Some(ref fire_elements_vec) = self.pileup.all_data.fire_elements { + // Use the peak_index position to get FIRE elements + // Error if index is out of bounds + assert!( + self.peak_index < fire_elements_vec.len(), + "peak_index {} out of bounds (len: {})", + self.peak_index, + fire_elements_vec.len() + ); + + for fire_elem in &fire_elements_vec[self.peak_index] { + fire_ids.insert(fire_elem.id); + } + } + + fire_ids + } + + /// Get FIRE elements at the peak index position as a map of id -> (start, end) + pub fn get_fire_elements(&self) -> HashMap { + let mut fire_elements = HashMap::new(); + + if let Some(ref fire_elements_vec) = self.pileup.all_data.fire_elements { + if self.peak_index < fire_elements_vec.len() { + for fire_elem in &fire_elements_vec[self.peak_index] { + fire_elements.insert(fire_elem.id, (fire_elem.start, fire_elem.end)); + } + } + } + + fire_elements + } + + /// Calculate the fraction of shared FIRE elements between two peaks + /// Returns the fraction relative to the smaller peak's FIRE element count + pub fn fire_overlap_fraction(&self, other: &Peak) -> f64 { + let self_ids = self.get_fire_ids(); + let other_ids = other.get_fire_ids(); + + if self_ids.is_empty() || other_ids.is_empty() { + return 0.0; + } + + let intersection_count = self_ids.intersection(&other_ids).count(); + let min_count = self_ids.len().min(other_ids.len()); + + intersection_count as f64 / min_count as f64 + } + + /// Calculate reciprocal overlap 
between two peaks + /// Returns the minimum of (overlap / self_length, overlap / other_length) + pub fn reciprocal_overlap(&self, other: &Peak) -> f64 { + reciprocal_overlap_raw( + &self.chrom, + self.start, + self.end, + &other.chrom, + other.start, + other.end, + ) + } + + /// Determine if this peak should merge with another peak + /// Uses both FIRE element overlap and reciprocal genomic overlap thresholds + pub fn should_merge_with( + &self, + other: &Peak, + min_fire_overlap: f64, + min_reciprocal_overlap: f64, + ) -> bool { + // Peaks must be on same chromosome + if self.chrom != other.chrom { + return false; + } + + // Check reciprocal genomic overlap (only if threshold > 0) + if min_reciprocal_overlap > 0.0 { + let recip_overlap = self.reciprocal_overlap(other); + if recip_overlap >= min_reciprocal_overlap { + return true; + } + } + + // Check FIRE element overlap (only if threshold > 0) + if min_fire_overlap > 0.0 { + let fire_overlap = self.fire_overlap_fraction(other); + if fire_overlap >= min_fire_overlap { + return true; + } + } + + false + } +} + +/// Merge a group of peaks into a single peak +/// Takes the peak with the highest score as representative (higher score = better) +/// and calculates boundaries as the median of all unique FIRE elements across merged peaks +fn merge_peak_group<'a>(peaks: &[&Peak<'a>]) -> Peak<'a> { + assert!(!peaks.is_empty(), "Cannot merge empty peak group"); + + // Find the peak with the highest score (which corresponds to the best/lowest FDR) + let best_peak = peaks + .iter() + .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap()) + .unwrap(); + + // Collect all unique FIRE elements from all peaks being merged + let mut all_fire_elements: HashMap = HashMap::new(); + for peak in peaks { + all_fire_elements.extend(peak.get_fire_elements()); + } + + // Calculate median start and end from all unique FIRE elements + let (merged_start, merged_end) = if all_fire_elements.is_empty() { + // Fallback to best peak boundaries if no 
FIRE elements + (best_peak.start, best_peak.end) + } else { + let mut starts: Vec<i64> = all_fire_elements.values().map(|(s, _)| *s).collect(); + let mut ends: Vec<i64> = all_fire_elements.values().map(|(_, e)| *e).collect(); + starts.sort_unstable(); + ends.sort_unstable(); + let median_start = starts[starts.len() / 2] as usize; + let median_end = ends[ends.len() / 2] as usize; + (median_start, median_end) + }; + + Peak { + chrom: best_peak.chrom.clone(), + start: merged_start, + end: merged_end, + score: best_peak.score, + fdr: best_peak.fdr, + pass_coverage: best_peak.pass_coverage, + peak_index: best_peak.peak_index, + pileup: best_peak.pileup, + } +} + +/// Group and merge peaks in a single pass +/// Returns the merged peak set (includes both merged and solo peaks) +fn merge_peaks_single_iteration<'a>( + peaks: Vec<Peak<'a>>, + min_fire_overlap: f64, + min_reciprocal_overlap: f64, +) -> Vec<Peak<'a>> { + if peaks.is_empty() { + return Vec::new(); + } + + let mut result = Vec::new(); + let mut current_group: Vec<usize> = vec![0]; + + for i in 1..peaks.len() { + // Check if current peak should merge with any peak in the current group + let should_merge = current_group.iter().any(|&group_idx| { + peaks[i].should_merge_with(&peaks[group_idx], min_fire_overlap, min_reciprocal_overlap) + }); + + if should_merge { + current_group.push(i); + } else { + // Finalize current group + if current_group.len() > 1 { + // Merge the group + let group_peaks: Vec<&Peak> = + current_group.iter().map(|&idx| &peaks[idx]).collect(); + result.push(merge_peak_group(&group_peaks)); + } else { + // Solo peak - keep as is + let idx = current_group[0]; + result.push(Peak { + chrom: peaks[idx].chrom.clone(), + start: peaks[idx].start, + end: peaks[idx].end, + score: peaks[idx].score, + fdr: peaks[idx].fdr, + pass_coverage: peaks[idx].pass_coverage, + peak_index: peaks[idx].peak_index, + pileup: peaks[idx].pileup, + }); + } + // Start new group + current_group = vec![i]; + } + } + + // Handle the last group + if 
current_group.len() > 1 { + let group_peaks: Vec<&Peak> = current_group.iter().map(|&idx| &peaks[idx]).collect(); + result.push(merge_peak_group(&group_peaks)); + } else { + let idx = current_group[0]; + result.push(Peak { + chrom: peaks[idx].chrom.clone(), + start: peaks[idx].start, + end: peaks[idx].end, + score: peaks[idx].score, + fdr: peaks[idx].fdr, + pass_coverage: peaks[idx].pass_coverage, + peak_index: peaks[idx].peak_index, + pileup: peaks[idx].pileup, + }); + } + + result +} + +/// Merge peaks iteratively using three-phase approach +/// Phase 1: High reciprocal overlap (default 90%, configurable via --high-reciprocal-overlap) +/// Phase 2: FIRE element overlap (default 50%, configurable via --min-frac-overlap) +/// Phase 3: Reciprocal overlap (default 75%, configurable via --min-reciprocal-overlap) +fn merge_peaks_iterative<'a>(mut peaks: Vec>, opts: &CallPeaksOptions) -> Vec> { + let initial_count = peaks.len(); + + // Phase 1: High reciprocal overlap + log::debug!( + " Phase 1: Merging peaks with reciprocal overlap >= {}", + opts.high_reciprocal_overlap + ); + for iteration in 0..opts.max_grouping_iterations { + let prev_count = peaks.len(); + peaks = merge_peaks_single_iteration(peaks, 0.0, opts.high_reciprocal_overlap); + log::debug!( + " Iteration {}: {} -> {} peaks", + iteration + 1, + prev_count, + peaks.len() + ); + if peaks.len() == prev_count { + log::debug!(" Phase 1 converged after {} iterations", iteration + 1); + break; + } + } + + // Phase 2: FIRE element overlap + log::debug!( + " Phase 2: Merging peaks with FIRE element overlap >= {}", + opts.min_frac_overlap + ); + for iteration in 0..opts.max_grouping_iterations { + let prev_count = peaks.len(); + peaks = merge_peaks_single_iteration(peaks, opts.min_frac_overlap, 0.0); + log::debug!( + " Iteration {}: {} -> {} peaks", + iteration + 1, + prev_count, + peaks.len() + ); + if peaks.len() == prev_count { + log::debug!(" Phase 2 converged after {} iterations", iteration + 1); + break; + } + 
} + + // Phase 3: High reciprocal overlap again + log::debug!( + " Phase 3: Merging peaks with reciprocal overlap >= {}", + opts.min_reciprocal_overlap + ); + for iteration in 0..opts.max_grouping_iterations { + let prev_count = peaks.len(); + peaks = merge_peaks_single_iteration(peaks, 0.0, opts.min_reciprocal_overlap); + log::debug!( + " Iteration {}: {} -> {} peaks", + iteration + 1, + prev_count, + peaks.len() + ); + if peaks.len() == prev_count { + log::debug!(" Phase 3 converged after {} iterations", iteration + 1); + break; + } + } + + let final_count = peaks.len(); + log::debug!( + " Merged {} peaks into {} peaks", + initial_count, + final_count + ); + + peaks +} + +/// Call peaks using FDR table or FIRE fraction filtering +/// +/// This will: +/// 1. Run pileup on real data +/// 2. Find local maxima +/// 3. Call peaks above threshold (FDR or FIRE fraction) +/// 4. Merge overlapping peaks +/// 5. Output results to BED file +pub fn call_peaks( + opts: &mut CallPeaksOptions, + bam: &mut rust_htslib::bam::IndexedReader, + header: &rust_htslib::bam::HeaderView, + fdr_table: &[FdrEntry], +) -> Result<()> { + if opts.min_fire_frac.is_some() { + log::info!("Calling peaks using FIRE fraction threshold"); + } else { + log::info!( + "Calling peaks using FDR table with {} entries", + fdr_table.len() + ); + + // Check if FDR table is empty (only when not using FIRE fraction mode) + if fdr_table.is_empty() { + anyhow::bail!( + "FDR table is empty. Cannot call peaks without an FDR table. \ + Please generate an FDR table first or provide a valid FDR table file." + ); + } + } + + // Open output file and write header + let mut writer = bio_io::writer(&opts.out)?; + writeln!(writer, "{}", Peak::header())?; + + let mut total_peaks_before_merge = 0; + let mut total_peaks_after_merge = 0; + + for (chrom, chrom_len) in chrom_names_and_lengths(header)? { + // Skip chromosomes with no fibers + if !super::chromosome_has_fibers(&chrom, bam, opts)? 
{ + log::debug!("Skipping chromosome {} (no fibers)", chrom); + continue; + } + + log::debug!( + "Finding peaks on chromosome {} with length {}", + chrom, + chrom_len + ); + + // Create FiberseqPileupOptions for peak calling + let pileup_opts = FiberseqPileupOptions { + fire_track_opts: FireTrackOptions { + no_nuc: false, + no_msp: false, + m6a: false, + cpg: false, + fiber_coverage: true, + shuffle: false, + random_shuffle: false, + shuffle_seed: None, + rolling_max: Some(opts.window_size), + track_fire_elements: true, // Enable FIRE element tracking for peak calling + }, + rolling_max: Some(opts.window_size), + haps: false, + per_base: false, + keep_zeros: false, + min_fire_coverage: Some(opts.min_fire_coverage), + }; + + // Process fibers to build the track + let mut pileup = + FiberseqPileup::new(&chrom, 0, chrom_len as usize, pileup_opts, &None, None); + let fibers = opts.input.fetch_fibers(bam, &chrom, None, None)?; + pileup.add_fibers(fibers); + + // Calculate coverage thresholds + let (median, std_dev, _) = pileup.all_data.median_and_std_coverage(); + let min_cov = opts.min_cov.unwrap_or_else(|| { + let calculated_min = (median - opts.sd_cov * std_dev).round() as i32; + calculated_min.max(4) // DEFAULT_MIN_COVERAGE = 4 + }); + let max_cov = opts + .max_cov + .unwrap_or_else(|| (median + opts.sd_cov * std_dev).round() as i32); + + // Find local maxima and filter by threshold (FDR or FIRE fraction) + let peaks = Peak::from_pileup( + &pileup, + fdr_table, + opts.max_fdr, + opts.min_fire_frac, + opts.min_fire_frac_filter, + min_cov, + max_cov, + ); + + let peaks_before = peaks.len(); + total_peaks_before_merge += peaks_before; + + // Merge peaks for this chromosome + let merged_peaks = merge_peaks_iterative(peaks, opts); + total_peaks_after_merge += merged_peaks.len(); + + // Summary info statement per chromosome + log::info!( + "Peaks: {} ({} Mbp) - found: {}, merged: {}", + chrom, + chrom_len / 1_000_000, + peaks_before, + merged_peaks.len(), + ); + + // 
Write merged peaks for this chromosome + for peak in &merged_peaks { + writeln!(writer, "{}", peak)?; + } + } + + log::info!("Total peaks before merging: {}", total_peaks_before_merge); + log::info!("Total peaks after merging: {}", total_peaks_after_merge); + log::info!("Peaks written to {}", opts.out); + + Ok(()) +} diff --git a/src/subcommands/fire.rs b/src/subcommands/fire.rs index f4999c31..99aab79c 100644 --- a/src/subcommands/fire.rs +++ b/src/subcommands/fire.rs @@ -62,45 +62,15 @@ pub fn add_fire_to_bam(fire_opts: &mut FireOptions) -> Result<(), anyhow::Error> else { let mut out = fire_opts.input.bam_writer(&fire_opts.out); let fibers = fire_opts.input.fibers(&mut bam); - let mut skip_because_no_m6a = 0; - let mut skip_because_num_msp = 0; - let mut skip_because_ave_msp_length = 0; for recs in &fibers.chunks(2_000) { let mut recs: Vec = recs.collect(); recs.par_iter_mut().for_each(|r| { add_fire_to_rec(r, fire_opts, &model, &precision_table); }); for rec in recs { - let n_msps = rec.msp.annotations.len(); - if fire_opts.skip_no_m6a || fire_opts.min_msp > 0 || fire_opts.min_ave_msp_size > 0 - { - // skip no calls - if rec.m6a.annotations.is_empty() || n_msps == 0 { - skip_because_no_m6a += 1; - continue; - } - //let max_msp_len = *rec.msp.lengths.iter().flatten().max().unwrap_or(&0); - if n_msps < fire_opts.min_msp { - skip_because_num_msp += 1; - continue; - } - let ave_msp_size = rec.msp.lengths().iter().sum::() / n_msps as i64; - if ave_msp_size < fire_opts.min_ave_msp_size { - skip_because_ave_msp_length += 1; - continue; - } - } out.write(&rec.record)?; } } - log::info!( - "Skipped {} records because they had an average MSP length less than {}; {} records because they had fewer than {} MSPs; and {} records because they had no m6A sites", - skip_because_ave_msp_length, - fire_opts.min_ave_msp_size, - skip_because_num_msp, - fire_opts.min_msp, - skip_because_no_m6a, - ); } Ok(()) } diff --git a/src/subcommands/mock_fire.rs b/src/subcommands/mock_fire.rs 
new file mode 100644 index 00000000..af14512c --- /dev/null +++ b/src/subcommands/mock_fire.rs @@ -0,0 +1,202 @@ +use crate::cli::MockFireOptions; +use crate::utils::bio_io::{self, read_bed_regions, BedRecord}; +use anyhow::{Context, Result}; +use rust_htslib::bam::header::HeaderRecord; +use rust_htslib::bam::record::{Aux, Cigar, CigarString}; +use rust_htslib::bam::{Header, HeaderView, Record}; +use std::collections::HashMap; + +/// Group BED records by their name (4th column) to create mock reads +fn group_bed_by_name(bed_records: Vec<BedRecord>) -> HashMap<String, Vec<BedRecord>> { + let mut groups: HashMap<String, Vec<BedRecord>> = HashMap::new(); + + for record in bed_records { + let name = record.get_name_or_default(); + groups.entry(name).or_default().push(record); + } + + // Sort each group by start position + for intervals in groups.values_mut() { + intervals.sort_by_key(|r| (r.chrom.clone(), r.start)); + } + + groups +} + +/// Create a BAM header from BED records +fn create_header_from_bed(bed_records: &[BedRecord]) -> Header { + let mut header = Header::new(); + + // Collect unique chromosomes and their max positions + let mut chrom_lengths: HashMap<String, i64> = HashMap::new(); + for record in bed_records { + let entry = chrom_lengths.entry(record.chrom.clone()).or_insert(0); + *entry = (*entry).max(record.end); + } + + // Sort chromosomes for consistent output + let mut chroms: Vec<_> = chrom_lengths.into_iter().collect(); + chroms.sort_by(|a, b| a.0.cmp(&b.0)); + + // Add SQ records + for (chrom, max_pos) in chroms { + let mut sq_record = HeaderRecord::new(b"SQ"); + sq_record.push_tag(b"SN", &chrom); + // Use a reasonable sequence length (max position + buffer) + let len_str = (max_pos + 10000).to_string(); + sq_record.push_tag(b"LN", &len_str); + header.push_record(&sq_record); + } + + header +} + +/// Create a mock BAM record with FIRE elements +fn create_mock_fire_record( + read_name: &str, + intervals: &[BedRecord], + header_view: &HeaderView, + quality: u8, + read_length: Option<i64>, +) -> Result<Record> { + if 
intervals.is_empty() { + return Err(anyhow::anyhow!("No intervals for read {}", read_name)); + } + + // All intervals should be on the same chromosome + let chrom = &intervals[0].chrom; + for interval in intervals { + if &interval.chrom != chrom { + return Err(anyhow::anyhow!( + "Read {} has intervals on multiple chromosomes ({} and {}). All FIRE elements for a read must be on the same chromosome.", + read_name, + chrom, + interval.chrom + )); + } + } + + let tid = header_view + .tid(chrom.as_bytes()) + .with_context(|| format!("Chromosome '{}' not found in header", chrom))?; + + // Calculate read boundaries + let read_start = intervals.iter().map(|i| i.start).min().unwrap(); + let read_end = intervals.iter().map(|i| i.end).max().unwrap(); + + // Use provided read length or calculate from intervals + let seq_len = read_length.unwrap_or(read_end - read_start) as usize; + + // Create sequence (all N's for mock data) + let seq = vec![b'N'; seq_len]; + let qual = vec![255u8; seq_len]; + + // Create CIGAR - simple match for the entire read + let cigar = CigarString(vec![Cigar::Equal(seq_len as u32)]); + + // Create the BAM record + let mut record = Record::new(); + record.set(read_name.as_bytes(), Some(&cigar), &seq, &qual); + record.set_tid(tid as i32); + record.set_pos(read_start); + record.set_mapq(60); + record.unset_paired(); + record.set_mtid(-1); + record.set_mpos(-1); + + // Add FIRE elements as MSP-like tags (as/al for starts/lengths, aq for quality) + // FIRE elements are stored on MSPs with quality scores in the aq tag + let mut starts: Vec = Vec::new(); + let mut lengths: Vec = Vec::new(); + let mut quals: Vec = Vec::new(); + + for interval in intervals { + // Convert to read-relative coordinates + let rel_start = (interval.start - read_start) as u32; + let length = (interval.end - interval.start) as u32; + + starts.push(rel_start); + lengths.push(length); + quals.push(quality); + } + + // Add the MSP start positions (as tag) + record + .push_aux(b"as", 
Aux::ArrayU32((&starts).into())) + .context("Failed to add 'as' tag (MSP starts)")?; + + // Add the MSP lengths (al tag) + record + .push_aux(b"al", Aux::ArrayU32((&lengths).into())) + .context("Failed to add 'al' tag (MSP lengths)")?; + + // Add the FIRE quality scores (aq tag) - this is what makes them FIRE elements + record + .push_aux(b"aq", Aux::ArrayU8((&quals).into())) + .context("Failed to add 'aq' tag (FIRE quality scores)")?; + + Ok(record) +} + +pub fn run_mock_fire(opts: &MockFireOptions) -> Result<()> { + log::info!("Reading BED file: {}", opts.bed); + + // Read BED file + let bed_records = read_bed_regions(&opts.bed).context("Failed to read BED file")?; + + if bed_records.is_empty() { + return Err(anyhow::anyhow!("BED file is empty")); + } + + log::info!("Read {} intervals from BED file", bed_records.len()); + + // Group intervals by read name (4th column) + let grouped = group_bed_by_name(bed_records.clone()); + log::info!( + "Grouped into {} mock reads based on 4th column", + grouped.len() + ); + + // Create BAM header + let header = create_header_from_bed(&bed_records); + let header_view = HeaderView::from_header(&header); + + // Create BAM writer + let mut writer = bio_io::program_bam_writer_from_header( + &opts.out, + header, + "fibertools-rs", + "ft", + crate::VERSION, + ); + + writer + .set_threads(opts.global.threads) + .context("Failed to set threads for BAM writer")?; + + if opts.uncompressed { + writer + .set_compression_level(rust_htslib::bam::CompressionLevel::Uncompressed) + .context("Failed to set uncompressed BAM")?; + } + + // Create and write mock records + let mut read_names: Vec<_> = grouped.keys().collect(); + read_names.sort(); + + for read_name in read_names { + let intervals = &grouped[read_name]; + let record = create_mock_fire_record( + read_name, + intervals, + &header_view, + opts.quality, + opts.read_length, + )?; + + bio_io::write_record(&mut writer, &record)?; + } + + log::info!("Mock BAM written to: {}", opts.out); + 
Ok(()) +} diff --git a/src/subcommands/pileup.rs b/src/subcommands/pileup.rs index 7aa96c30..52bb957a 100644 --- a/src/subcommands/pileup.rs +++ b/src/subcommands/pileup.rs @@ -7,17 +7,49 @@ use crate::utils::bamannotations; use crate::utils::bio_io; use crate::*; use anyhow::{anyhow, Ok}; -use std::collections::HashMap; -use std::io::BufRead; -//use polars::prelude::*; use ordered_float::NotNan; use rust_htslib::bam::ext::BamRecordExtensions; use rust_htslib::bam::{FetchDefinition, IndexedReader}; +use std::collections::HashMap; +use std::io::BufRead; const MIN_FIRE_COVERAGE: i32 = 4; const MIN_FIRE_QUAL: u8 = 229; // floor(255*0.9) static WINDOW_SIZE: usize = 1_000_000; +/// Options for FireTrack that don't require the full PileupOptions +/// This allows FireTrack to be used independently +#[derive(Debug, Clone, Default)] +pub struct FireTrackOptions { + pub no_nuc: bool, + pub no_msp: bool, + pub m6a: bool, + pub cpg: bool, + pub fiber_coverage: bool, + pub shuffle: bool, // Track if shuffling is enabled + pub random_shuffle: bool, // If true, generate random positions instead of using ShuffledFibers + pub shuffle_seed: Option, // Optional seed for reproducible random shuffling + pub rolling_max: Option, + pub track_fire_elements: bool, // If true, store individual FIRE element positions per base +} + +impl From<&PileupOptions> for FireTrackOptions { + fn from(opts: &PileupOptions) -> Self { + Self { + no_nuc: opts.no_nuc, + no_msp: opts.no_msp, + m6a: opts.m6a, + cpg: opts.cpg, + fiber_coverage: opts.effective_fiber_coverage(), + shuffle: opts.shuffle.is_some(), + random_shuffle: false, // PileupOptions doesn't have this yet + shuffle_seed: None, + rolling_max: opts.rolling_max, + track_fire_elements: false, // Default to false for pileup command + } + } +} + #[derive(Debug)] pub struct FireRow<'a> { pub coverage: &'a i32, @@ -27,17 +59,17 @@ pub struct FireRow<'a> { pub msp_coverage: &'a i32, pub cpg_coverage: &'a i32, pub m6a_coverage: &'a i32, - pileup_opts: 
&'a PileupOptions, + fire_track_opts: &'a FireTrackOptions, } impl PartialEq for FireRow<'_> { fn eq(&self, other: &Self) -> bool { - let m6a = if self.pileup_opts.m6a { + let m6a = if self.fire_track_opts.m6a { self.m6a_coverage == other.m6a_coverage } else { true }; - let cpg = if self.pileup_opts.cpg { + let cpg = if self.fire_track_opts.cpg { self.cpg_coverage == other.cpg_coverage } else { true @@ -59,22 +91,23 @@ impl std::fmt::Display for FireRow<'_> { "\t{}\t{}\t{}", self.coverage, self.fire_coverage, self.score ); - if !self.pileup_opts.no_nuc { + if !self.fire_track_opts.no_nuc { rtn += &format!("\t{}", self.nuc_coverage); } - if !self.pileup_opts.no_msp { + if !self.fire_track_opts.no_msp { rtn += &format!("\t{}", self.msp_coverage); } - if self.pileup_opts.m6a { + if self.fire_track_opts.m6a { rtn += &format!("\t{}", self.m6a_coverage); } - if self.pileup_opts.cpg { + if self.fire_track_opts.cpg { rtn += &format!("\t{}", self.cpg_coverage); } write!(f, "{rtn}") } } +#[derive(Debug)] pub struct ShuffledFibers { pub shuffled_fiber_starts: HashMap<(String, String, i64), i64>, } @@ -142,7 +175,54 @@ impl ShuffledFibers { } } +/// Generate a random shuffle offset for a fiber using uniform distribution +/// Uses deterministic PRNG seeded from fiber name + seed for reproducibility +fn generate_random_shuffle_offset( + fiber: &FiberseqData, + chrom_len: usize, + seed: Option, +) -> Option { + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let fiber_len = fiber.record.reference_end() - fiber.record.reference_start(); + let original_start = fiber.record.reference_start(); + + // Check if fiber can fit in chromosome + let max_start = (chrom_len as i64 - fiber_len).max(0); + if max_start <= 0 { + return Some(0); // Fiber too long, keep at position 0 + } + + // Create deterministic seed from fiber name + optional seed + let mut hasher = DefaultHasher::new(); + 
fiber.get_qname().hash(&mut hasher); + if let Some(s) = seed { + s.hash(&mut hasher); + } + let fiber_seed = hasher.finish(); + + // Use StdRng for uniform distribution + let mut rng = StdRng::seed_from_u64(fiber_seed); + let shuffled_start = rng.gen_range(0..=max_start); + + // Return offset (shuffled_start - original_start) + Some(shuffled_start - original_start) +} + +/// Represents a single FIRE element (MSP) with its genomic coordinates and unique ID +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct FireElement { + pub start: i64, + pub end: i64, + pub id: usize, // Unique ID for this FIRE element (for tracking in merging) +} + +#[derive(Debug)] pub struct FireTrack<'a> { + pub chrom: String, pub chrom_start: usize, pub chrom_end: usize, pub track_len: usize, @@ -154,22 +234,40 @@ pub struct FireTrack<'a> { pub nuc_coverage: Vec, pub cpg_coverage: Vec, pub m6a_coverage: Vec, - pileup_opts: &'a PileupOptions, + fire_track_opts: FireTrackOptions, // Now owned, not borrowed shuffled_fibers: &'a Option, cur_offset: i64, + // Store fiber information for later shuffle generation + // Key: (fiber_name, original_start), Value: fiber_length + fibers_seen: HashMap<(String, i64), i64>, + // Optional: Store individual FIRE elements per position + // fire_elements[position] = Vec of FireElements overlapping that position + pub fire_elements: Option>>, + // Counter for assigning unique IDs to FIRE elements + next_fire_id: usize, } impl<'a> FireTrack<'a> { pub fn new( + chrom: String, chrom_start: usize, chrom_end: usize, - pileup_opts: &'a PileupOptions, + fire_track_opts: FireTrackOptions, // Take ownership shuffled_fibers: &'a Option, ) -> Self { let track_len = chrom_end - chrom_start + 1; let raw_scores = vec![-1.0; track_len]; let scores = vec![-1.0; track_len]; + + // Initialize fire_elements only if tracking is enabled + let fire_elements = if fire_track_opts.track_fire_elements { + Some(vec![Vec::new(); track_len]) + } else { + None + }; + Self { + 
chrom, chrom_start, chrom_end, track_len, @@ -181,13 +279,16 @@ impl<'a> FireTrack<'a> { nuc_coverage: vec![0; track_len], cpg_coverage: vec![0; track_len], m6a_coverage: vec![0; track_len], - pileup_opts, + fire_track_opts, shuffled_fibers, cur_offset: 0, + fibers_seen: HashMap::new(), + fire_elements, + next_fire_id: 0, } } - #[inline] + //#[inline] fn add_range_set( array: &mut [i32], ranges: &bamannotations::Ranges, @@ -217,7 +318,7 @@ impl<'a> FireTrack<'a> { } fn fiber_start_and_end(&self, fiber: &FiberseqData) -> (i64, i64) { - if !self.pileup_opts.fiber_coverage { + if !self.fire_track_opts.fiber_coverage { return ( fiber.record.reference_start() + self.cur_offset, fiber.record.reference_end() + self.cur_offset, @@ -260,25 +361,45 @@ impl<'a> FireTrack<'a> { (start + self.cur_offset, end + self.cur_offset) } - /// inline this function - #[inline] - fn update_with_fiber(&mut self, fiber: &FiberseqData) { + pub fn update_with_fiber(&mut self, fiber: &FiberseqData) { // skip this fiber if it has no MSP/NUC information // and we are looking at fiber_coverage - if self.pileup_opts.fiber_coverage + if self.fire_track_opts.fiber_coverage && fiber.msp.reference_starts().is_empty() && fiber.nuc.reference_starts().is_empty() { return; } + // Store fiber information for later shuffle generation (only for real, not shuffled) + if self.cur_offset == 0 && !self.fire_track_opts.shuffle { + let fiber_name = fiber.get_qname(); + let original_start = fiber.record.reference_start(); + let fiber_len = fiber.record.reference_end() - original_start; + self.fibers_seen + .insert((fiber_name, original_start), fiber_len); + } + // find the offset if we are shuffling data + // Priority: 1) shuffled_fibers from file, 2) random shuffle, 3) no shuffle self.cur_offset = match self.shuffled_fibers { - Some(shuffled_fibers) => match shuffled_fibers.get_shuffle_offset(fiber) { - Some(offset) => offset, - None => return, // skip missing fiber if it is not in the shuffle - }, - None => 0, + 
Some(shuffled_fibers) => { + // Use pre-computed shuffle from file + match shuffled_fibers.get_shuffle_offset(fiber) { + Some(offset) => offset, + None => return, // skip missing fiber if it is not in the shuffle + } + } + None if self.fire_track_opts.random_shuffle => { + // Generate random shuffle offset + generate_random_shuffle_offset( + fiber, + self.chrom_end, + self.fire_track_opts.shuffle_seed, + ) + .unwrap_or(0) + } + None => 0, // No shuffling }; if self.cur_offset != 0 && self.chrom_start != 0 { @@ -306,7 +427,25 @@ impl<'a> FireTrack<'a> { if annotation.qual < MIN_FIRE_QUAL { continue; } - let score_update = (1.0 - annotation.qual as f32 / 255.0).log10() * -50.0; + // Cap quality at 253 to avoid log10(0) issues, and cap score at 100 + let capped_qual = annotation.qual.min(253) as f32; + let score_update = ((1.0 - capped_qual / 255.0).log10() * -50.0).min(100.0); + + // If tracking FIRE elements, create a FireElement for this MSP + let fire_element = if self.fire_track_opts.track_fire_elements { + let elem_start = rs + self.cur_offset; + let elem_end = re + self.cur_offset; + let fire_id = self.next_fire_id; + self.next_fire_id += 1; + Some(FireElement { + start: elem_start, + end: elem_end, + id: fire_id, + }) + } else { + None + }; + for i in rs..re { let pos = i + self.cur_offset - self.chrom_start as i64; if pos < 0 || pos >= self.track_len as i64 { @@ -314,6 +453,13 @@ impl<'a> FireTrack<'a> { } self.fire_coverage[pos as usize] += 1; self.raw_scores[pos as usize] += score_update; + + // Store the FIRE element at this position if tracking is enabled + if let (Some(fire_elements), Some(elem)) = + (&mut self.fire_elements, fire_element) + { + fire_elements[pos as usize].push(elem); + } } } _ => continue, @@ -322,16 +468,16 @@ impl<'a> FireTrack<'a> { // add other sets of data to the FireTrack depending on CLI opts let mut pairs = vec![]; - if !self.pileup_opts.no_nuc { + if !self.fire_track_opts.no_nuc { pairs.push((&mut self.nuc_coverage, 
&fiber.nuc)); } - if !self.pileup_opts.no_msp { + if !self.fire_track_opts.no_msp { pairs.push((&mut self.msp_coverage, &fiber.msp)); } - if self.pileup_opts.m6a { + if self.fire_track_opts.m6a { pairs.push((&mut self.m6a_coverage, &fiber.m6a)); } - if self.pileup_opts.cpg { + if self.fire_track_opts.cpg { pairs.push((&mut self.cpg_coverage, &fiber.cpg)); } @@ -340,13 +486,12 @@ impl<'a> FireTrack<'a> { } } - pub fn calculate_scores(&mut self) { + pub fn calculate_scores(&mut self, min_fire_coverage: Option) { + let min_fire_coverage = min_fire_coverage.unwrap_or(MIN_FIRE_COVERAGE); for i in 0..self.track_len { if self.fire_coverage[i] <= 0 { self.scores[i] = -1.0; - } else if self.fire_coverage[i] < MIN_FIRE_COVERAGE - && self.pileup_opts.shuffle.is_none() - { + } else if self.fire_coverage[i] < min_fire_coverage && !self.fire_track_opts.shuffle { // there is no minimum fire coverage if we are shuffling self.scores[i] = -1.0; } else { @@ -357,7 +502,7 @@ impl<'a> FireTrack<'a> { pub fn calculate_rolling_max_score(&mut self) -> Vec { let mut rolling_max = vec![-1.0; self.track_len]; - let window_size = self.pileup_opts.rolling_max.unwrap(); + let window_size = self.fire_track_opts.rolling_max.unwrap(); let look_back = window_size / 2; for (i, cur_roll_max) in rolling_max.iter_mut().enumerate().take(self.track_len) { let start = i.saturating_sub(look_back); @@ -383,11 +528,166 @@ impl<'a> FireTrack<'a> { nuc_coverage: &self.nuc_coverage[i], cpg_coverage: &self.cpg_coverage[i], m6a_coverage: &self.m6a_coverage[i], - pileup_opts: self.pileup_opts, + fire_track_opts: &self.fire_track_opts, + } + } + + /// Calculate median coverage across the track + /// Returns (median_coverage, positions_with_coverage, positions_with_fire) + pub fn median_coverage(&self) -> (f64, usize) { + let mut coverages: Vec = self.coverage.iter().filter(|&&c| c > 0).copied().collect(); + + let positions_with_coverage = coverages.len(); + + if coverages.is_empty() { + return (0.0, 0); + } + + 
coverages.sort_unstable(); + + let median = if coverages.len() % 2 == 0 { + let mid = coverages.len() / 2; + (coverages[mid - 1] as f64 + coverages[mid] as f64) / 2.0 + } else { + coverages[coverages.len() / 2] as f64 + }; + + (median, positions_with_coverage) + } + + /// Calculate median and estimated standard deviation (sqrt of median for Poisson) + /// Returns (median, std_dev, positions_with_coverage) + pub fn median_and_std_coverage(&self) -> (f64, f64, usize) { + let (median, positions_with_coverage) = self.median_coverage(); + + // For sequencing data, we assume Poisson distribution where std_dev โ‰ˆ sqrt(median) + let std_dev = median.sqrt(); + + (median, std_dev, positions_with_coverage) + } + + /// Generate a ShuffledFibers HashMap from a list of fibers + /// This creates random shuffled positions for each fiber within the chromosome. + /// Uses the FireTrack's coverage to avoid placing shuffled fibers in regions with: + /// - Zero coverage (always avoided) + /// - Coverage below min_cov (if specified) + /// - Coverage above max_cov (if specified) + /// + /// Will retry up to 1000 times to find a valid position. 
+ pub fn generate_shuffled_positions( + &self, + seed: Option, + min_cov: Option, + max_cov: Option, + ) -> ShuffledFibers { + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut shuffled_fiber_starts = HashMap::new(); + let mut regenerated_count = 0; + + for ((fiber, original_start), fiber_len) in &self.fibers_seen { + let max_start = (self.track_len as i64 - fiber_len).max(0); + let coverage = self.coverage[*original_start as usize]; + + // filter by coverage constraints before attempting to shuffle + let has_valid_coverage = coverage > 0 + && min_cov.is_none_or(|min| coverage >= min) + && max_cov.is_none_or(|max| coverage <= max); + if !has_valid_coverage { + continue; + } + + // Create deterministic seed from fiber name + let mut hasher = DefaultHasher::new(); + fiber.hash(&mut hasher); + if let Some(s) = seed { + s.hash(&mut hasher); + } + let fiber_seed = hasher.finish(); + + // Generate random position, retrying up to 1000 times if coverage is invalid + let mut rng = StdRng::seed_from_u64(fiber_seed); + let mut shuffled_start = rng.gen_range(0..=max_start); + + // Try up to 1000 times to find a position with valid coverage + let mut attempts = 0; + while attempts < 1000 { + // Check if this position has valid coverage + // No bounds check needed: shuffled_start is guaranteed to be in [0, max_start] + // where max_start = track_len - fiber_len, so it's always valid + let cov = self.coverage[shuffled_start as usize]; + + // Check coverage constraints + let has_valid_coverage = cov > 0 + && min_cov.is_none_or(|min| cov >= min) + && max_cov.is_none_or(|max| cov <= max); + + if has_valid_coverage { + if attempts > 0 { + regenerated_count += 1; + } + break; + } + + // Regenerate position + shuffled_start = rng.gen_range(0..=max_start); + attempts += 1; + } + + // Store as (chrom, fiber_name, original_start) -> shuffled_start + let key = (self.chrom.clone(), 
fiber.clone(), *original_start); + shuffled_fiber_starts.insert(key, shuffled_start); + } + + log::debug!( + "Generated shuffle positions for {} fibers ({} regenerated for valid coverage [min={:?}, max={:?}])", + shuffled_fiber_starts.len(), + regenerated_count, + min_cov, + max_cov + ); + + ShuffledFibers { + shuffled_fiber_starts, + } + } +} + +/// Options needed for FiberseqPileup +/// This is a lightweight struct that only contains the options actually used by FiberseqPileup +#[derive(Debug, Clone)] +pub struct FiberseqPileupOptions { + /// Track options for FireTrack (coverage, marks, etc.) + pub fire_track_opts: FireTrackOptions, + /// Output rolling max of the score column over X bases + pub rolling_max: Option, + /// Include haplotype-specific tracks + pub haps: bool, + /// Write output one base at a time even if values don't change + pub per_base: bool, + /// Keep zero coverage regions + pub keep_zeros: bool, + /// Minimum FIRE coverage required to calculate a score (default: 4) + pub min_fire_coverage: Option, +} + +impl From<&PileupOptions> for FiberseqPileupOptions { + fn from(opts: &PileupOptions) -> Self { + Self { + fire_track_opts: FireTrackOptions::from(opts), + rolling_max: opts.rolling_max, + haps: opts.haps, + per_base: opts.per_base, + keep_zeros: opts.keep_zeros, + min_fire_coverage: None, // Use default } } } +#[derive(Debug)] pub struct FiberseqPileup<'a> { pub all_data: FireTrack<'a>, pub hap1_data: Option>, @@ -398,9 +698,11 @@ pub struct FiberseqPileup<'a> { pub chrom_end: usize, pub track_len: usize, has_data: bool, - pileup_opts: &'a PileupOptions, + pileup_opts: FiberseqPileupOptions, shuffled_fibers: &'a Option, - rolling_max: Option>, + pub rolling_max: Option>, + pub fdr_scores: Option>, + pub region_name: Option, } impl<'a> FiberseqPileup<'a> { @@ -408,25 +710,48 @@ impl<'a> FiberseqPileup<'a> { chrom: &str, chrom_start: usize, chrom_end: usize, - pileup_opts: &'a PileupOptions, + pileup_opts: FiberseqPileupOptions, 
shuffled_fibers: &'a Option, + region_name: Option, ) -> Self { let track_len = chrom_end - chrom_start + 1; - let all_data = FireTrack::new(chrom_start, chrom_end, pileup_opts, &None); + let fire_track_opts = pileup_opts.fire_track_opts.clone(); + let all_data = FireTrack::new( + chrom.to_string(), + chrom_start, + chrom_end, + fire_track_opts.clone(), + &None, + ); let (hap1_data, hap2_data) = if pileup_opts.haps { ( - Some(FireTrack::new(chrom_start, chrom_end, pileup_opts, &None)), - Some(FireTrack::new(chrom_start, chrom_end, pileup_opts, &None)), + Some(FireTrack::new( + chrom.to_string(), + chrom_start, + chrom_end, + fire_track_opts.clone(), + &None, + )), + Some(FireTrack::new( + chrom.to_string(), + chrom_start, + chrom_end, + fire_track_opts.clone(), + &None, + )), ) } else { (None, None) }; let shuffled_data = if shuffled_fibers.is_some() { + let mut shuffled_opts = fire_track_opts.clone(); + shuffled_opts.shuffle = true; Some(FireTrack::new( + chrom.to_string(), chrom_start, chrom_end, - pileup_opts, + shuffled_opts, shuffled_fibers, )) } else { @@ -446,6 +771,8 @@ impl<'a> FiberseqPileup<'a> { pileup_opts, shuffled_fibers, rolling_max: None, + fdr_scores: None, + region_name, } } @@ -453,58 +780,80 @@ impl<'a> FiberseqPileup<'a> { self.has_data } - pub fn add_records( - &mut self, - records: bam::Records<'a, IndexedReader>, - ) -> Result<(), anyhow::Error> { - self.pileup_opts - .input - .filters - .filter_on_bit_flags(records) - .chunks(1000) - .into_iter() - .map(|r| r.collect::>()) - .for_each(|r| { - let fibers: Vec = FiberseqData::from_records( - r, - &self.pileup_opts.input.header_view(), - &self.pileup_opts.input.filters, - ); - if !fibers.is_empty() { - self.has_data = true; - } - for fiber in fibers { - // skip if the fiber was unable to be shuffled - if self.shuffled_fibers.is_some() - && !self.shuffled_fibers.as_ref().unwrap().has_fiber(&fiber) - { - continue; - } + /// Calculate and store FDR scores for each position based on the FDR table 
+ pub fn calculate_fdr_scores(&mut self, fdr_table: &[crate::subcommands::call_peaks::FdrEntry]) { + let mut fdr_scores = vec![1.0; self.track_len]; - self.all_data.update_with_fiber(&fiber); - // add hap1 data - if let Some(hap1_data) = &mut self.hap1_data { - if fiber.get_hp() == "H1" { - hap1_data.update_with_fiber(&fiber); - } - } - // add hap2 data - if let Some(hap2_data) = &mut self.hap2_data { - if fiber.get_hp() == "H2" { - hap2_data.update_with_fiber(&fiber); - } - } - // add shuffled data - if let Some(shuffled_data) = &mut self.shuffled_data { - shuffled_data.update_with_fiber(&fiber); + if fdr_table.is_empty() { + self.fdr_scores = Some(fdr_scores); + return; + } + + for (i, &score) in self.all_data.scores.iter().enumerate() { + if score < 0.0 { + // No coverage, FDR = 1.0 (already set) + continue; + } + + // Binary search to find the FDR for this score + let search_result = fdr_table.binary_search_by(|entry| { + entry + .threshold + .partial_cmp(&(score as f64)) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + fdr_scores[i] = match search_result { + Result::Ok(idx) => fdr_table[idx].fdr, + Result::Err(idx) => { + if idx == 0 { + fdr_table[0].fdr + } else { + fdr_table[idx - 1].fdr } } - }); + }; + } + + self.fdr_scores = Some(fdr_scores); + } + + /// Add fibers from an iterator + /// This is more efficient than add_records as it works directly with FiberseqData + pub fn add_fibers(&mut self, fibers: impl Iterator) { + for fiber in fibers { + self.has_data = true; + + // skip if the fiber was unable to be shuffled + if self.shuffled_fibers.is_some() + && !self.shuffled_fibers.as_ref().unwrap().has_fiber(&fiber) + { + continue; + } + + self.all_data.update_with_fiber(&fiber); + // add hap1 data + if let Some(hap1_data) = &mut self.hap1_data { + if fiber.get_hp() == "H1" { + hap1_data.update_with_fiber(&fiber); + } + } + // add hap2 data + if let Some(hap2_data) = &mut self.hap2_data { + if fiber.get_hp() == "H2" { + hap2_data.update_with_fiber(&fiber); 
+ } + } + // add shuffled data + if let Some(shuffled_data) = &mut self.shuffled_data { + shuffled_data.update_with_fiber(&fiber); + } + } + self.calculate_scores(); - Ok(()) } - pub fn header(pileup_opts: &PileupOptions) -> String { + pub fn header(pileup_opts: &PileupOptions, include_name: bool) -> String { let mut header = format!("{}\t{}\t{}", "#chrom", "start", "end"); let mut suffixes = vec![""]; @@ -537,25 +886,30 @@ impl<'a> FiberseqPileup<'a> { if pileup_opts.rolling_max.is_some() { header += "\trolling_max"; } + // Add name column at the end to minimize breaking downstream tools + if include_name { + header += "\tname"; + } header += "\n"; header } fn calculate_scores(&mut self) { - self.all_data.calculate_scores(); + self.all_data + .calculate_scores(self.pileup_opts.min_fire_coverage); // calculate rolling max if self.pileup_opts.rolling_max.is_some() { self.rolling_max = Some(self.all_data.calculate_rolling_max_score()); } // scores for other tracks if let Some(hap1_data) = &mut self.hap1_data { - hap1_data.calculate_scores(); + hap1_data.calculate_scores(self.pileup_opts.min_fire_coverage); } if let Some(hap2_data) = &mut self.hap2_data { - hap2_data.calculate_scores(); + hap2_data.calculate_scores(self.pileup_opts.min_fire_coverage); } if let Some(shuffled_data) = &mut self.shuffled_data { - shuffled_data.calculate_scores(); + shuffled_data.calculate_scores(self.pileup_opts.min_fire_coverage); } } @@ -652,6 +1006,10 @@ impl<'a> FiberseqPileup<'a> { self.rolling_max.as_ref().unwrap()[write_start_index] ); } + // Add name column at the end to minimize breaking downstream tools + if let Some(name) = &self.region_name { + line += &format!("\t{}", name); + } // don't write empty lines unless keep_zeros is set let mut cov = self.all_data.coverage[write_start_index]; if let Some(shuffled_data) = &self.shuffled_data { @@ -699,9 +1057,16 @@ fn run_rgn( out: &mut Box, pileup_opts: &PileupOptions, shuffled_fibers: &Option, + region_name: Option, ) -> Result<(), 
anyhow::Error> { - let tid = bam.header().tid(chrom.as_bytes()).unwrap(); - let chrom_len = bam.header().target_len(tid).unwrap() as i64; + let tid = bam.header().tid(chrom.as_bytes()).ok_or(anyhow::anyhow!( + "Chromosome {} not found in BAM header", + chrom + ))?; + let chrom_len = bam.header().target_len(tid).ok_or(anyhow::anyhow!( + "Chromosome {} length not found in BAM header", + chrom + ))? as i64; let window_size = if shuffled_fibers.is_some() { (chrom_len + 1) as usize @@ -719,26 +1084,28 @@ fn run_rgn( chrom_end = chrom_len; } - // check if region has data - bam.fetch((chrom, chrom_start, chrom_end))?; - let mut tmp_records = bam.records(); - if tmp_records.next().is_none() { - continue; - } - // fetch the data - bam.fetch((chrom, chrom_start, chrom_end))?; - let records = bam.records(); + // Fetch fibers from the region using the new iterator-based approach + let fiber_iter = + pileup_opts + .input + .fetch_fibers(bam, chrom, Some(chrom_start), Some(chrom_end))?; + // make the pileup log::debug!("Initializing pileup for {chrom}:{chrom_start}-{chrom_end}"); let mut pileup = FiberseqPileup::new( chrom, chrom_start as usize, chrom_end as usize, - pileup_opts, + pileup_opts.into(), shuffled_fibers, + region_name.clone(), ); - pileup.add_records(records)?; - pileup.write(out)?; + pileup.add_fibers(fiber_iter); + + // Only write if we have data + if pileup.has_data() { + pileup.write(out)?; + } } Ok(()) @@ -751,18 +1118,47 @@ pub fn pileup_track(pileup_opts: &mut PileupOptions) -> Result<(), anyhow::Error let header = pileup_opts.input.header_view(); let mut out = bio_io::writer(&pileup_opts.out)?; - // add the header - out.write_all(FiberseqPileup::header(pileup_opts).as_bytes())?; let shuffled_fibers = match &pileup_opts.shuffle { Some(file_path) => Some(ShuffledFibers::new(file_path)?), None => None, }; - match &pileup_opts.rgn { - // if a region is specified, only process that region - Some(rgn) => { - let (rgn, chrom) = region_parser(rgn); + // Handle 
regions based on source (BED file, command-line args, or all chromosomes) + // We process regions immediately rather than collecting them because FetchDefinition has lifetime constraints + if let Some(bed_path) = &pileup_opts.bed { + // Parse BED file + let bed_records = bio_io::read_bed_regions(bed_path)?; + let include_name = bed_records.iter().any(|r| r.name.is_some()); + + // add the header + out.write_all(FiberseqPileup::header(pileup_opts, include_name).as_bytes())?; + + // Process each BED record immediately + for rec in bed_records { + let fetch_def = FetchDefinition::RegionString(rec.chrom.as_bytes(), rec.start, rec.end); + // If any record has a name, use "." for records without names to keep column count consistent + let region_name = if include_name { + Some(rec.name.unwrap_or_else(|| ".".to_string())) + } else { + None + }; + run_rgn( + &rec.chrom, + fetch_def, + &mut bam, + &mut out, + pileup_opts, + &shuffled_fibers, + region_name, + )?; + } + } else if !pileup_opts.rgn.is_empty() { + // Use command-line regions + out.write_all(FiberseqPileup::header(pileup_opts, false).as_bytes())?; + + for rgn_str in &pileup_opts.rgn { + let (rgn, chrom) = region_parser(rgn_str); run_rgn( &chrom, rgn, @@ -770,22 +1166,27 @@ pub fn pileup_track(pileup_opts: &mut PileupOptions) -> Result<(), anyhow::Error &mut out, pileup_opts, &shuffled_fibers, + None, )?; } - // if no region is specified, process all regions - None => { - for chrom in header.target_names() { - let rgn = FetchDefinition::String(chrom); - run_rgn( - &String::from_utf8_lossy(chrom), - rgn, - &mut bam, - &mut out, - pileup_opts, - &shuffled_fibers, - )?; - } + } else { + // Process all chromosomes + out.write_all(FiberseqPileup::header(pileup_opts, false).as_bytes())?; + + for chrom in header.target_names() { + let chrom_str = String::from_utf8_lossy(chrom).to_string(); + let rgn = FetchDefinition::String(chrom); + run_rgn( + &chrom_str, + rgn, + &mut bam, + &mut out, + pileup_opts, + &shuffled_fibers, 
+ None, + )?; } } + Ok(()) } diff --git a/src/utils/bamannotations.rs b/src/utils/bamannotations.rs index 57a55f22..c154bfd1 100644 --- a/src/utils/bamannotations.rs +++ b/src/utils/bamannotations.rs @@ -43,10 +43,23 @@ impl FiberAnnotations { /// starts and ends are [) intervals. pub fn new( + record: &bam::Record, + forward_starts: Vec, + forward_ends: Option>, + lengths: Option>, + ) -> Self { + Self::new_with_extras(record, forward_starts, forward_ends, lengths, None) + } + + /// Same as [`FiberAnnotations::new`], but also accepts forward-order + /// `extras` that get carried through the flip and sort so each + /// annotation keeps its original extra_columns. + pub fn new_with_extras( record: &bam::Record, mut forward_starts: Vec, forward_ends: Option>, mut lengths: Option>, + mut forward_extras: Option>>>, ) -> Self { let mut single_bp_liftover = false; // assume ends == starts if not provided @@ -72,6 +85,13 @@ impl FiberAnnotations { // get positions and lengths in reference orientation Self::positions_on_aligned_sequence(&mut forward_starts, is_reverse, seq_len); Self::positions_on_aligned_sequence(&mut forward_ends_inclusive, is_reverse, seq_len); + // Mirror the array reversal that positions_on_aligned_sequence applied + // so extras stay paired with their original peak. + if is_reverse { + if let Some(ref mut e) = forward_extras { + e.reverse(); + } + } let mut starts = forward_starts; let mut ends = forward_ends_inclusive; @@ -106,6 +126,22 @@ impl FiberAnnotations { panic!("Failed to lift query range: {}", e); }); + // Normalize extras into a Vec matching the other parallel vectors so + // it can be zipped in and move with the annotation through the sort. 
+ let extras_iter: Vec>> = match forward_extras { + Some(e) => { + assert_eq!( + e.len(), + starts.len(), + "extras length ({}) must match annotation count ({})", + e.len(), + starts.len(), + ); + e + } + None => vec![None; starts.len()], + }; + // create annotations from parallel vectors let mut annotations: Vec = starts .into_iter() @@ -114,21 +150,26 @@ impl FiberAnnotations { .zip(reference_starts) .zip(reference_ends) .zip(reference_lengths) + .zip(extras_iter) .map( - |(((((start, end), length), ref_start), ref_end), ref_length)| FiberAnnotation { - start, - end, - length: length.unwrap_or(end - start), - qual: 0, - reference_start: ref_start, - reference_end: ref_end, - reference_length: ref_length, - extra_columns: None, + |((((((start, end), length), ref_start), ref_end), ref_length), extra)| { + FiberAnnotation { + start, + end, + length: length.unwrap_or(end - start), + qual: 0, + reference_start: ref_start, + reference_end: ref_end, + reference_length: ref_length, + extra_columns: extra, + } }, ) .collect(); - // Sort annotations by start position to ensure they are always in order + // Sort annotations by start position to ensure they are always in + // order. `sort_by_key` is stable, so extras stay paired with their + // annotation even when multiple peaks share a start. annotations.sort_by_key(|a| a.start); // return object @@ -418,43 +459,46 @@ impl FiberAnnotations { let forward_starts: Vec = start_values.iter().map(|&x| x as i64).collect(); let lengths: Vec = length_values.iter().map(|&x| x as i64).collect(); - // Get annotation tag for extra columns if specified - let annotation_values = if let Some(ann_tag) = annotation_tag { - if let Ok(rust_htslib::bam::record::Aux::String(ann_string)) = - record.aux(ann_tag) - { - Some(ann_string.split('|').collect::>()) - } else { - None - } - } else { - None + // Get annotation tag for extra columns if specified. 
Convert + // directly to the forward-order shape `FiberAnnotations` expects + // so extras travel through the flip+sort paired with their peak. + let forward_extras: Option>>> = match annotation_tag { + Some(ann_tag) => match record.aux(ann_tag) { + Ok(rust_htslib::bam::record::Aux::String(ann_string)) => { + let parts: Vec>> = ann_string + .split('|') + .map(|s| { + if s.is_empty() { + None + } else { + Some(s.split(';').map(|t| t.to_string()).collect()) + } + }) + .collect(); + if parts.len() != forward_starts.len() { + return Err(anyhow::anyhow!( + "Mismatched {} ({}) and {} ({}) array lengths", + String::from_utf8_lossy(ann_tag), + parts.len(), + String::from_utf8_lossy(start_tag), + forward_starts.len(), + )); + } + Some(parts) + } + _ => None, + }, + None => None, }; - // Use the existing FiberAnnotations::new method - let mut fiber_annotations = FiberAnnotations::new( + let fiber_annotations = FiberAnnotations::new_with_extras( record, forward_starts, None, // no ends provided Some(lengths), // use lengths from fl tag + forward_extras, ); - // Add extra columns to the annotations if present. - // FiberAnnotations::new already reversed starts/lengths on reverse-strand reads - // (fs/fl are flipped+array-reversed into aligned order), so fa must be reversed - // too to keep extras paired with their original peak. 
- if let Some(mut ann_vals) = annotation_values { - if record.is_reverse() { - ann_vals.reverse(); - } - for (i, annotation) in fiber_annotations.annotations.iter_mut().enumerate() { - if i < ann_vals.len() && !ann_vals[i].is_empty() { - annotation.extra_columns = - Some(ann_vals[i].split(';').map(|s| s.to_string()).collect()); - } - } - } - Ok(Some(fiber_annotations)) } else { Ok(None) // Tags exist but wrong type diff --git a/src/utils/bio_io.rs b/src/utils/bio_io.rs index 46f99748..d8f440e8 100644 --- a/src/utils/bio_io.rs +++ b/src/utils/bio_io.rs @@ -2,7 +2,7 @@ use anyhow::Result; use colored::Colorize; use gzp::deflate::Bgzf; //, Gzip, Mgzip, RawDeflate}; use gzp::{Compression, ZBuilder}; -use indicatif::{ProgressBar, ProgressStyle}; +use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle}; use itertools::Itertools; use lazy_static::lazy_static; use linear_map::LinearMap; @@ -18,10 +18,9 @@ use std::collections::HashMap; use std::env; use std::ffi::OsStr; use std::fs::File; -use std::io::{self, stdout, BufReader, BufWriter, Write}; +use std::io::{self, stdout, BufRead, BufReader, BufWriter, Write}; use std::path::{Path, PathBuf}; use std::process::exit; -use std::time::Instant; const BUFFER_SIZE: usize = 32 * 1024; const COMPRESSION_THREADS: usize = 8; @@ -45,6 +44,10 @@ pub fn no_length_progress_bar() -> ProgressBar { .unwrap(); let bar = ProgressBar::new(0); bar.set_style(style); + + // Start hidden - will be shown on first tick/update + bar.set_draw_target(ProgressDrawTarget::hidden()); + let finish = indicatif::ProgressFinish::AndLeave; bar.with_finish(finish) } @@ -223,16 +226,22 @@ pub fn bam_reader(bam: &str) -> bam::Reader { } } // This is a bam chunk reader -pub struct BamChunk<'a> { - pub bam: bam::Records<'a, bam::Reader>, +pub struct BamChunk<'a, R> +where + R: bam::Read, +{ + pub bam: bam::Records<'a, R>, pub chunk_size: usize, pub pre_chunk_done: u64, pub bar: ProgressBar, pub bit_flag_filter: u16, } -impl<'a> BamChunk<'a> { - pub fn 
new(bam: bam::Records<'a, bam::Reader>, chunk_size: Option) -> Self { +impl<'a, R> BamChunk<'a, R> +where + R: bam::Read, +{ + pub fn new(bam: bam::Records<'a, R>, chunk_size: Option) -> Self { let chunk_size = std::cmp::min( chunk_size.unwrap_or_else(|| current_num_threads() * 100), 2_500, @@ -253,7 +262,10 @@ impl<'a> BamChunk<'a> { } // The `Iterator` trait only requires a method to be defined for the `next` element. -impl Iterator for BamChunk<'_> { +impl Iterator for BamChunk<'_, R> +where + R: bam::Read, +{ // We can refer to this type using Self::Item type Item = Vec; @@ -264,9 +276,14 @@ impl Iterator for BamChunk<'_> { // the type without having to update the function signatures. fn next(&mut self) -> Option { // update progress bar with results from previous iteration - self.bar.inc(self.pre_chunk_done); + if self.pre_chunk_done > 0 { + // Make progress bar visible on first actual update + if self.bar.is_hidden() { + self.bar.set_draw_target(ProgressDrawTarget::stderr()); + } + self.bar.inc(self.pre_chunk_done); + } - let start = Instant::now(); let mut cur_vec = vec![]; for r in self.bam.by_ref().take(self.chunk_size) { let r = r.unwrap(); @@ -295,14 +312,6 @@ impl Iterator for BamChunk<'_> { if cur_vec.is_empty() { None } else { - let duration = start.elapsed().as_secs_f64(); - log::debug!( - "Read {} bam records at {}.", - format!("{:}", cur_vec.len()).bright_magenta().bold(), - format!("{:.2?} reads/s", cur_vec.len() as f64 / duration) - .bright_cyan() - .bold(), - ); Some(cur_vec) } } @@ -533,3 +542,86 @@ pub fn convert_seq_uppercase(mut seq: Vec) -> Vec { } seq } + +/// A BED record with all fields preserved +#[derive(Debug, Clone)] +pub struct BedRecord { + pub chrom: String, + pub start: i64, + pub end: i64, + pub name: Option, + pub extra_fields: Vec, // Store all remaining fields (from column 4 onwards, or from 5 if name exists) +} + +impl BedRecord { + /// Get the name or a default region string + pub fn get_name_or_default(&self) -> String { + 
self.name + .clone() + .unwrap_or_else(|| format!("{}:{}-{}", self.chrom, self.start, self.end)) + } + + /// Reconstruct the original BED line (without the name column if it's separate) + pub fn to_bed_line(&self) -> String { + let mut fields = vec![ + self.chrom.clone(), + self.start.to_string(), + self.end.to_string(), + ]; + if let Some(name) = &self.name { + fields.push(name.clone()); + } + fields.extend(self.extra_fields.clone()); + fields.join("\t") + } +} + +/// Read BED file and parse regions with optional name column +/// Returns a vector of BedRecord structs containing chrom, start, end, name, and extra fields +pub fn read_bed_regions(bed_path: &str) -> Result> { + let reader = buffer_from(bed_path)?; + let mut regions = Vec::new(); + + for line in reader.lines() { + let line = line?; + if line.starts_with('#') || line.trim().is_empty() { + continue; + } + let tokens: Vec<&str> = line.split('\t').collect(); + if tokens.len() < 3 { + return Err(anyhow::anyhow!( + "BED file must have at least 3 columns (chrom, start, end)" + )); + } + + let chrom = tokens[0].to_string(); + let start = tokens[1] + .parse::() + .map_err(|_| anyhow::anyhow!("Invalid start position: {}", tokens[1]))?; + let end = tokens[2] + .parse::() + .map_err(|_| anyhow::anyhow!("Invalid end position: {}", tokens[2]))?; + + let (name, extra_start_idx) = if tokens.len() >= 4 && !tokens[3].is_empty() { + (Some(tokens[3].to_string()), 4) + } else { + (None, 3) + }; + + let extra_fields: Vec = tokens[extra_start_idx..] 
+ .iter() + .map(|s| s.to_string()) + .collect(); + + regions.push(BedRecord { + chrom, + start, + end, + name, + extra_fields, + }); + } + + log::debug!("Read {} regions from BED file: {}", regions.len(), bed_path); + Ok(regions) +} diff --git a/src/utils/input_bam.rs b/src/utils/input_bam.rs index c7ab30c4..b8afa8bb 100644 --- a/src/utils/input_bam.rs +++ b/src/utils/input_bam.rs @@ -13,14 +13,14 @@ pub static MIN_ML_SCORE: &str = "125"; #[derive(Debug, Args, Clone)] pub struct FiberFilters { /// BAM bit flags to filter on, equivalent to `-F` in samtools view + /// Defaults to 0 (no filtering) #[clap( global = true, short = 'F', long = "filter", - default_value = "0", help_heading = "BAM-Options" )] - pub bit_flag: u16, + pub bit_flag: Option, /// Filtering expression to use for filtering records /// Example: filter to nucleosomes with lengths greater than 150 bp /// -x "len(nuc)>150" @@ -52,21 +52,118 @@ pub struct FiberFilters { hide = true )] pub strip_starting_basemods: i64, + /// Convenience: apply the FIRE peak-calling pipeline's fiber-level filters + /// (`--skip-no-m6a`, `--min-msp 10`, `--min-ave-msp-size 10`). Individual + /// filter flags still override when both are set. Requires MSP/m6A + /// annotations on the input BAM, so it is a no-op for commands that run + /// before those annotations exist. + #[clap(global = true, long, help_heading = "FIRE-Filter")] + pub fire_filter: bool, + /// Drop fibers with no m6A calls. Off by default; + /// `--fire-filter` turns this on unless explicitly set to `false`. + /// Use `--skip-no-m6a=false` to override when `--fire-filter` is set. + #[clap( + global = true, + long, + num_args = 0..=1, + default_missing_value = "true", + require_equals = true, + help_heading = "FIRE-Filter" + )] + pub skip_no_m6a: Option, + /// Drop fibers with fewer than `N` MSP calls. + /// Off (0) by default; `--fire-filter` sets this to 10 unless overridden. 
+ #[clap(global = true, long, env = "MIN_MSP", help_heading = "FIRE-Filter")] + pub min_msp: Option, + /// Drop fibers whose average MSP size is below `N`. + /// Off (0) by default; `--fire-filter` sets this to 10 unless overridden. + #[clap( + global = true, + long, + env = "MIN_AVE_MSP_SIZE", + help_heading = "FIRE-Filter" + )] + pub min_ave_msp_size: Option, } impl std::default::Default for FiberFilters { fn default() -> Self { Self { - bit_flag: 0, + bit_flag: Some(0), min_ml_score: MIN_ML_SCORE.parse().unwrap(), filter_expression: None, uncompressed: false, strip_starting_basemods: 0, + fire_filter: false, + skip_no_m6a: None, + min_msp: None, + min_ave_msp_size: None, } } } impl FiberFilters { + /// Get the bit flag value, using a default if not explicitly set + pub fn get_bit_flag(&self) -> u16 { + self.bit_flag.unwrap_or(0) + } + + /// Resolved `--skip-no-m6a`, using `--fire-filter` as the fallback. + pub fn resolved_skip_no_m6a(&self) -> bool { + self.skip_no_m6a.unwrap_or(self.fire_filter) + } + + /// Resolved `--min-msp`, using `--fire-filter` (10) as the fallback. + pub fn resolved_min_msp(&self) -> usize { + self.min_msp + .unwrap_or(if self.fire_filter { 10 } else { 0 }) + } + + /// Resolved `--min-ave-msp-size`, using `--fire-filter` (10) as the fallback. + pub fn resolved_min_ave_msp_size(&self) -> i64 { + self.min_ave_msp_size + .unwrap_or(if self.fire_filter { 10 } else { 0 }) + } + + /// True if any FIRE fiber-level filter is active. + pub fn fire_filter_active(&self) -> bool { + self.resolved_skip_no_m6a() + || self.resolved_min_msp() > 0 + || self.resolved_min_ave_msp_size() > 0 + } + + /// True if `rec` passes the FIRE fiber-level filters (skip_no_m6a, + /// min_msp, min_ave_msp_size). Called by `FiberseqRecords::next` so + /// every downstream consumer sees only filtered fibers. + /// + /// Edge cases worth preserving: + /// - Fibers with zero MSPs are always rejected once any filter is + /// active. 
The check also guards the divide-by-zero in the average + /// MSP size below, so don't drop it when refactoring. + /// - The no-m6a rejection is gated on `resolved_skip_no_m6a()` so + /// `--skip-no-m6a=false` actually disables it (e.g. when combined + /// with `--fire-filter`). + pub fn passes_fire_filter(&self, rec: &crate::fiber::FiberseqData) -> bool { + if !self.fire_filter_active() { + return true; + } + let n_msps = rec.msp.annotations.len(); + if n_msps == 0 { + return false; + } + if self.resolved_skip_no_m6a() && rec.m6a.annotations.is_empty() { + return false; + } + if n_msps < self.resolved_min_msp() { + return false; + } + let ave_msp_size = rec.msp.lengths().iter().sum::() / n_msps as i64; + if ave_msp_size < self.resolved_min_ave_msp_size() { + return false; + } + true + } + /// This function accepts an iterator over bam records and filters them based on the bit flag. pub fn filter_on_bit_flags<'a, I>( &'a self, @@ -75,12 +172,14 @@ impl FiberFilters { where I: IntoIterator> + 'a, { + let bit_flag = self.get_bit_flag(); records .into_iter() .map(|r| r.expect("htslib is unable to read a record in the input.")) - .filter(|r| { + .filter(move |r| { // filter by bit flag - (r.flags() & self.bit_flag) == 0 + // `move` is needed to capture bit_flag value in the closure + (r.flags() & bit_flag) == 0 }) } } @@ -173,6 +272,37 @@ impl InputBam { *header = crate::utils::panspec::strip_pan_spec_header(header, &delimiter); } } + + /// Fetch fibers from a specific region with filters applied + /// Returns an iterator of FiberseqData records from the specified region + /// + /// # Arguments + /// * `bam` - Mutable reference to an IndexedReader + /// * `chrom` - Chromosome/contig name + /// * `start` - Optional start position (0-based) + /// * `end` - Optional end position (0-based, exclusive) + /// ``` + pub fn fetch_fibers<'a>( + &'a self, + bam: &'a mut bam::IndexedReader, + chrom: &str, + start: Option, + end: Option, + ) -> Result, rust_htslib::errors::Error> { 
+ // Fetch the region + match (start, end) { + (Some(s), Some(e)) => bam.fetch((chrom, s, e))?, + (None, None) => bam.fetch(chrom.as_bytes())?, + _ => panic!("Both start and end must be specified, or neither"), + } + + // Create FiberseqRecords iterator from the fetched records + let records = bam.records(); + let header = self.header_view(); + let fiber_iter = FiberseqRecords::from_rec_iterator(records, header, self.filters.clone()); + + Ok(fiber_iter) + } } impl std::default::Default for InputBam { @@ -185,3 +315,151 @@ impl std::default::Default for InputBam { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::fiber::FiberseqData; + use crate::utils::bamannotations::{FiberAnnotation, FiberAnnotations}; + use crate::utils::basemods::BaseMods; + + /// Build a minimal `FiberseqData` shaped only for `passes_fire_filter`. + /// `msp_lengths` populates `rec.msp.annotations` (each entry contributes + /// to count and average size). `m6a_count` controls whether `m6a` + /// annotations are empty or present (only emptiness matters for the + /// filter). 
+ fn make_fsd(msp_lengths: &[i64], m6a_count: usize) -> FiberseqData { + let msp_anns: Vec = msp_lengths + .iter() + .map(|&len| FiberAnnotation { + start: 0, + end: len, + length: len, + qual: 0, + reference_start: None, + reference_end: None, + reference_length: None, + extra_columns: None, + }) + .collect(); + let m6a_anns: Vec = (0..m6a_count) + .map(|i| FiberAnnotation { + start: i as i64, + end: i as i64 + 1, + length: 1, + qual: 0, + reference_start: None, + reference_end: None, + reference_length: None, + extra_columns: None, + }) + .collect(); + FiberseqData { + record: rust_htslib::bam::Record::new(), + msp: FiberAnnotations::from_annotations(msp_anns, 1000, false), + nuc: FiberAnnotations::from_annotations(vec![], 1000, false), + m6a: FiberAnnotations::from_annotations(m6a_anns, 1000, false), + cpg: FiberAnnotations::from_annotations(vec![], 1000, false), + base_mods: BaseMods { base_mods: vec![] }, + ec: 0.0, + target_name: ".".to_string(), + rg: ".".to_string(), + center_position: None, + } + } + + fn filters() -> FiberFilters { + FiberFilters::default() + } + + #[test] + fn passes_when_no_filter_active() { + // Default config: nothing rejected, even fibers with no m6a / no msps. + let f = filters(); + assert!(!f.fire_filter_active()); + assert!(f.passes_fire_filter(&make_fsd(&[], 0))); + assert!(f.passes_fire_filter(&make_fsd(&[100, 100, 100], 5))); + } + + #[test] + fn skip_no_m6a_rejects_only_no_m6a_fibers() { + let mut f = filters(); + f.skip_no_m6a = Some(true); + assert!(f.fire_filter_active()); + // No m6a โ†’ reject (and no-msp also rejected as a side effect). + assert!(!f.passes_fire_filter(&make_fsd(&[100], 0))); + // Has m6a, even with zero msps, still rejected by the empty-msp guard. + assert!(!f.passes_fire_filter(&make_fsd(&[], 3))); + // Has both โ†’ passes (other thresholds default to 0). 
+ assert!(f.passes_fire_filter(&make_fsd(&[5], 1))); + } + + #[test] + fn min_msp_rejects_fibers_with_too_few_msps() { + let mut f = filters(); + f.min_msp = Some(3); + assert!(f.fire_filter_active()); + assert!(!f.passes_fire_filter(&make_fsd(&[100, 100], 5))); + assert!(f.passes_fire_filter(&make_fsd(&[100, 100, 100], 5))); + } + + #[test] + fn min_ave_msp_size_rejects_low_average() { + let mut f = filters(); + f.min_ave_msp_size = Some(50); + // Average = 30 โ†’ reject. + assert!(!f.passes_fire_filter(&make_fsd(&[10, 20, 60], 5))); + // Average = 60 โ†’ pass. + assert!(f.passes_fire_filter(&make_fsd(&[40, 60, 80], 5))); + } + + #[test] + fn fire_filter_combo_applies_all_three_defaults() { + // `--fire-filter` alone should imply skip_no_m6a + min_msp=10 + min_ave_msp_size=10. + let mut f = filters(); + f.fire_filter = true; + assert!(f.resolved_skip_no_m6a()); + assert_eq!(f.resolved_min_msp(), 10); + assert_eq!(f.resolved_min_ave_msp_size(), 10); + // <10 msps โ†’ reject. + let nine_long_msps: Vec = vec![100; 9]; + assert!(!f.passes_fire_filter(&make_fsd(&nine_long_msps, 5))); + // 10 msps but ave_size = 5 < 10 โ†’ reject. + let ten_short_msps: Vec = vec![5; 10]; + assert!(!f.passes_fire_filter(&make_fsd(&ten_short_msps, 5))); + // 10 msps, ave_size = 100, has m6a โ†’ pass. + let ten_long_msps: Vec = vec![100; 10]; + assert!(f.passes_fire_filter(&make_fsd(&ten_long_msps, 5))); + // Same fiber but no m6a โ†’ reject (skip_no_m6a is implied). + assert!(!f.passes_fire_filter(&make_fsd(&ten_long_msps, 0))); + } + + #[test] + fn explicit_flag_overrides_fire_filter_default() { + // `--fire-filter --min-msp=5` should use 5, not 10. + let mut f = filters(); + f.fire_filter = true; + f.min_msp = Some(5); + assert_eq!(f.resolved_min_msp(), 5); + // Other two still take fire-filter defaults. + assert!(f.resolved_skip_no_m6a()); + assert_eq!(f.resolved_min_ave_msp_size(), 10); + // 5 msps would have failed under the 10 default, but passes under 5. 
+ let five_long_msps: Vec<i64> = vec![100; 5];
+ assert!(f.passes_fire_filter(&make_fsd(&five_long_msps, 5)));
+ }
+
+ #[test]
+ fn explicit_skip_no_m6a_false_overrides_fire_filter_default() {
+ // `--fire-filter --skip-no-m6a=false` keeps the size/count thresholds
+ // but should turn the no-m6a guard off. The empty-msp guard still
+ // applies (it's a divide-by-zero protection), so we use a record
+ // with msps but no m6a.
+ let mut f = filters();
+ f.fire_filter = true;
+ f.skip_no_m6a = Some(false);
+ assert!(!f.resolved_skip_no_m6a());
+ let ten_long_msps: Vec<i64> = vec![100; 10];
+ assert!(f.passes_fire_filter(&make_fsd(&ten_long_msps, 0)));
+ }
+}
diff --git a/tests/call_peaks_test.rs b/tests/call_peaks_test.rs
new file mode 100644
index 00000000..7f9e64a5
--- /dev/null
+++ b/tests/call_peaks_test.rs
@@ -0,0 +1,359 @@
+//! Unit tests for the `call_peaks` module's pure helpers and FDR machinery.
+//!
+//! These tests exercise public surface only — `lookup_fdr`,
+//! `reciprocal_overlap_raw`, `IncrementalFdrBuilder`, and the FDR-table
+//! read/write helpers. The geometry-only `reciprocal_overlap_raw` and the
+//! score-only `lookup_fdr` were extracted from `Peak` methods so they can
+//! be tested without constructing a `FiberseqPileup`.
+
+use fibertools_rs::subcommands::call_peaks::{
+ lookup_fdr, read_fdr_table, reciprocal_overlap_raw, write_fdr_table, FdrEntry,
+ IncrementalFdrBuilder, PileupRecord,
+};
+use std::io::Write;
+use tempfile::NamedTempFile;
+
+fn approx_eq(a: f64, b: f64) -> bool {
+ (a - b).abs() < 1e-9
+}
+
+fn fdr_entry(threshold: f64, fdr: f64, shuffled_bp: f64, real_bp: f64) -> FdrEntry {
+ FdrEntry {
+ threshold,
+ fdr,
+ shuffled_bp,
+ real_bp,
+ }
+}
+
+fn pileup_record(start: u64, end: u64, score: f64) -> PileupRecord {
+ PileupRecord {
+ start,
+ end,
+ coverage: 1,
+ fire_coverage: 1,
+ score,
+ }
+}
+
+// ---------------------------------------------------------------------------
+// lookup_fdr
+// ---------------------------------------------------------------------------
+
+#[test]
+fn lookup_fdr_empty_table_returns_one() {
+ let table: Vec<FdrEntry> = vec![];
+ assert_eq!(lookup_fdr(0.0, &table), 1.0);
+ assert_eq!(lookup_fdr(99.0, &table), 1.0);
+ assert_eq!(lookup_fdr(-5.0, &table), 1.0);
+}
+
+#[test]
+fn lookup_fdr_score_below_all_thresholds_returns_first_entry_fdr() {
+ let table = vec![
+ fdr_entry(1.0, 0.5, 0.0, 0.0),
+ fdr_entry(5.0, 0.1, 0.0, 0.0),
+ fdr_entry(10.0, 0.01, 0.0, 0.0),
+ ];
+ // Score below the smallest threshold returns the first entry's FDR.
+ assert!(approx_eq(lookup_fdr(0.5, &table), 0.5));
+ assert!(approx_eq(lookup_fdr(-100.0, &table), 0.5));
+}
+
+#[test]
+fn lookup_fdr_score_above_all_thresholds_returns_last_entry_fdr() {
+ let table = vec![
+ fdr_entry(1.0, 0.5, 0.0, 0.0),
+ fdr_entry(5.0, 0.1, 0.0, 0.0),
+ fdr_entry(10.0, 0.01, 0.0, 0.0),
+ ];
+ assert!(approx_eq(lookup_fdr(100.0, &table), 0.01));
+}
+
+#[test]
+fn lookup_fdr_exact_threshold_match_returns_that_fdr() {
+ let table = vec![
+ fdr_entry(1.0, 0.5, 0.0, 0.0),
+ fdr_entry(5.0, 0.1, 0.0, 0.0),
+ fdr_entry(10.0, 0.01, 0.0, 0.0),
+ ];
+ assert!(approx_eq(lookup_fdr(1.0, &table), 0.5));
+ assert!(approx_eq(lookup_fdr(5.0, &table), 0.1));
+ assert!(approx_eq(lookup_fdr(10.0, &table), 0.01));
+}
+
+#[test]
+fn lookup_fdr_score_between_thresholds_returns_largest_le_score() {
+ let table = vec![
+ fdr_entry(1.0, 0.5, 0.0, 0.0),
+ fdr_entry(5.0, 0.1, 0.0, 0.0),
+ fdr_entry(10.0, 0.01, 0.0, 0.0),
+ ];
+ // Between 1.0 and 5.0 → uses the 1.0 entry.
+ assert!(approx_eq(lookup_fdr(3.0, &table), 0.5));
+ assert!(approx_eq(lookup_fdr(4.99, &table), 0.5));
+ // Between 5.0 and 10.0 → uses the 5.0 entry.
+ assert!(approx_eq(lookup_fdr(7.0, &table), 0.1));
+ assert!(approx_eq(lookup_fdr(9.99, &table), 0.1));
+}
+
+#[test]
+fn lookup_fdr_is_monotone_nonincreasing_for_decreasing_fdr_table() {
+ // Real FDR tables are sorted ascending by threshold and FDR is generally
+ // non-increasing as threshold rises. lookup_fdr should preserve this:
+ // higher score → equal-or-lower FDR.
+ let table = vec![
+ fdr_entry(0.0, 0.9, 0.0, 0.0),
+ fdr_entry(1.0, 0.5, 0.0, 0.0),
+ fdr_entry(2.5, 0.2, 0.0, 0.0),
+ fdr_entry(5.0, 0.05, 0.0, 0.0),
+ fdr_entry(10.0, 0.001, 0.0, 0.0),
+ ];
+ let mut prev = f64::INFINITY;
+ for s in [-1.0, 0.0, 0.5, 1.0, 2.0, 2.5, 3.7, 5.0, 8.0, 10.0, 50.0] {
+ let f = lookup_fdr(s as f32, &table);
+ assert!(
+ f <= prev + 1e-12,
+ "non-monotonic at score {s}: {f} > {prev}"
+ );
+ prev = f;
+ }
+}
+
+// ---------------------------------------------------------------------------
+// reciprocal_overlap_raw
+// ---------------------------------------------------------------------------
+
+#[test]
+fn reciprocal_overlap_different_chromosomes_is_zero() {
+ assert_eq!(reciprocal_overlap_raw("chr1", 0, 100, "chr2", 0, 100), 0.0);
+}
+
+#[test]
+fn reciprocal_overlap_disjoint_intervals_is_zero() {
+ assert_eq!(
+ reciprocal_overlap_raw("chr1", 0, 100, "chr1", 200, 300),
+ 0.0
+ );
+}
+
+#[test]
+fn reciprocal_overlap_touching_boundaries_is_zero() {
+ // BED-style half-open: end == other.start means no overlapping bases.
+ assert_eq!(
+ reciprocal_overlap_raw("chr1", 0, 100, "chr1", 100, 200),
+ 0.0
+ );
+}
+
+#[test]
+fn reciprocal_overlap_identical_intervals_is_one() {
+ assert!(approx_eq(
+ reciprocal_overlap_raw("chr1", 100, 200, "chr1", 100, 200),
+ 1.0
+ ));
+}
+
+#[test]
+fn reciprocal_overlap_full_containment_returns_smaller_over_larger() {
+ // [100, 200] contains [120, 180]: overlap = 60, smaller_len = 60, larger_len = 100.
+ // Reciprocal overlap = min(60/100, 60/60) = 0.6.
+ assert!(approx_eq(
+ reciprocal_overlap_raw("chr1", 100, 200, "chr1", 120, 180),
+ 0.6
+ ));
+}
+
+#[test]
+fn reciprocal_overlap_partial_overlap_returns_min_fraction() {
+ // [0, 100] and [50, 200]: overlap = 50.
+ // a_frac = 50/100 = 0.5, b_frac = 50/150 ≈ 0.333; min ≈ 0.333.
+ let v = reciprocal_overlap_raw("chr1", 0, 100, "chr1", 50, 200);
+ assert!(approx_eq(v, 50.0 / 150.0));
+}
+
+#[test]
+fn reciprocal_overlap_is_symmetric() {
+ let a = reciprocal_overlap_raw("chr1", 0, 100, "chr1", 50, 200);
+ let b = reciprocal_overlap_raw("chr1", 50, 200, "chr1", 0, 100);
+ assert!(approx_eq(a, b));
+}
+
+#[test]
+fn reciprocal_overlap_fifty_percent_each_side() {
+ // [0, 100] and [50, 150]: overlap = 50, both lens = 100.
+ assert!(approx_eq(
+ reciprocal_overlap_raw("chr1", 0, 100, "chr1", 50, 150),
+ 0.5
+ ));
+}
+
+// ---------------------------------------------------------------------------
+// IncrementalFdrBuilder
+// ---------------------------------------------------------------------------
+
+#[test]
+fn fdr_builder_single_chromosome_yields_expected_table() {
+ // One real record (score=1.0, bp=200) and one shuffled record (score=0.5, bp=100).
+ // After fdr_from_fire_scores in descending-by-score order:
+ // First we push when score changes from 1.0 → 0.5 with cur_r=200, cur_v=0.
+ // Then we accumulate cur_v=100, then push the trailing sentinel
+ // (1.0 shuffled, 1.0 real, threshold=-1.0).
+ // FDRs: 0/200 = 0.0 and 1/1 = 1.0.
+ let real = vec![pileup_record(0, 200, 1.0)];
+ let shuffled = vec![pileup_record(0, 100, 0.5)];
+
+ let mut builder = IncrementalFdrBuilder::new();
+ builder.add_chromosome_data(&real, &shuffled);
+ let table = builder.build(0.5).expect("FDR build should succeed");
+
+ // Expect two entries sorted ascending by threshold: -1.00 and 1.00.
+ assert_eq!(table.len(), 2, "table = {:?}", table);
+ assert!(approx_eq(table[0].threshold, -1.0));
+ assert!(approx_eq(table[0].fdr, 1.0));
+ assert!(approx_eq(table[1].threshold, 1.0));
+ assert!(approx_eq(table[1].fdr, 0.0));
+}
+
+#[test]
+fn fdr_builder_bails_when_no_threshold_below_max_fdr() {
+ // Only shuffled data → every FDR is 1.0, so any max_fdr < 1.0 must error.
+ let real: Vec<PileupRecord> = vec![];
+ let shuffled = vec![pileup_record(0, 100, 1.0)];
+
+ let mut builder = IncrementalFdrBuilder::new();
+ builder.add_chromosome_data(&real, &shuffled);
+ let err = builder
+ .build(0.5)
+ .expect_err("should bail with no real data");
+ let msg = format!("{err}");
+ assert!(
+ msg.contains("FDR"),
+ "expected error message to mention FDR, got: {msg}"
+ );
+}
+
+#[test]
+fn fdr_builder_multi_chromosome_additivity_matches_single_call() {
+ // Splitting the same input across two add_chromosome_data calls must
+ // produce the same final table as feeding it all at once. The builder
+ // aggregates by score key so per-chromosome boundaries should not matter.
+ let real_a = vec![pileup_record(0, 200, 1.0), pileup_record(0, 50, 2.0)];
+ let shuffled_a = vec![pileup_record(0, 100, 0.5)];
+ let real_b = vec![pileup_record(0, 75, 1.0)];
+ let shuffled_b = vec![pileup_record(0, 25, 0.5), pileup_record(0, 60, 1.0)];
+
+ let mut combined = IncrementalFdrBuilder::new();
+ let real_all: Vec<_> = real_a.iter().chain(real_b.iter()).cloned().collect();
+ let shuffled_all: Vec<_> = shuffled_a
+ .iter()
+ .chain(shuffled_b.iter())
+ .cloned()
+ .collect();
+ combined.add_chromosome_data(&real_all, &shuffled_all);
+ let table_combined = combined.build(0.99).expect("build should succeed");
+
+ let mut split = IncrementalFdrBuilder::new();
+ split.add_chromosome_data(&real_a, &shuffled_a);
+ split.add_chromosome_data(&real_b, &shuffled_b);
+ let table_split = split.build(0.99).expect("build should succeed");
+
+ assert_eq!(
+ table_combined.len(),
+ table_split.len(),
+ "combined={:?}\nsplit={:?}",
+ table_combined,
+ table_split
+ );
+ for (c, s) in table_combined.iter().zip(table_split.iter()) {
+ assert!(approx_eq(c.threshold, s.threshold));
+ assert!(approx_eq(c.fdr, s.fdr));
+ assert!(approx_eq(c.shuffled_bp, s.shuffled_bp));
+ assert!(approx_eq(c.real_bp, s.real_bp));
+ }
+}
+
+#[test]
+fn fdr_builder_output_is_sorted_ascending_by_threshold() {
+ let real =
vec![ + pileup_record(0, 200, 5.0), + pileup_record(0, 100, 2.0), + pileup_record(0, 50, 1.0), + ]; + let shuffled = vec![pileup_record(0, 100, 4.0), pileup_record(0, 100, 0.5)]; + + let mut builder = IncrementalFdrBuilder::new(); + builder.add_chromosome_data(&real, &shuffled); + let table = builder.build(0.99).expect("build should succeed"); + + let mut prev = f64::NEG_INFINITY; + for entry in &table { + assert!( + entry.threshold > prev || approx_eq(entry.threshold, prev), + "thresholds not ascending: {} after {}", + entry.threshold, + prev + ); + prev = entry.threshold; + } +} + +// --------------------------------------------------------------------------- +// read_fdr_table / write_fdr_table round trip +// --------------------------------------------------------------------------- + +#[test] +fn fdr_table_round_trip_via_tempfile() { + let original = vec![ + fdr_entry(-1.0, 1.0, 1.0, 1.0), + fdr_entry(0.5, 0.25, 100.0, 400.0), + fdr_entry(1.0, 0.05, 50.0, 1000.0), + ]; + + let tmp = NamedTempFile::new().expect("tempfile"); + let path = tmp.path().to_str().unwrap().to_string(); + write_fdr_table(&original, &path).expect("write"); + let round_tripped = read_fdr_table(&path).expect("read"); + + assert_eq!(round_tripped.len(), original.len()); + for (a, b) in round_tripped.iter().zip(original.iter()) { + // write_fdr_table formats threshold/fdr/bp at 2/6/0 decimal places. 
+ assert!(approx_eq(a.threshold, b.threshold)); + assert!((a.fdr - b.fdr).abs() < 1e-6); + assert!(approx_eq(a.shuffled_bp, b.shuffled_bp)); + assert!(approx_eq(a.real_bp, b.real_bp)); + } +} + +#[test] +fn read_fdr_table_rejects_wrong_column_count() { + let mut tmp = NamedTempFile::new().expect("tempfile"); + writeln!(tmp, "threshold\tFDR\tshuffled_bp\treal_bp").unwrap(); + writeln!(tmp, "1.0\t0.5\t10").unwrap(); // only 3 columns + tmp.flush().unwrap(); + let path = tmp.path().to_str().unwrap(); + + let err = read_fdr_table(path).expect_err("should reject malformed line"); + let msg = format!("{err}"); + assert!( + msg.contains("4 columns") || msg.contains("Invalid FDR table format"), + "expected column-count error, got: {msg}" + ); +} + +#[test] +fn read_fdr_table_returns_entries_sorted_ascending_by_threshold() { + // write a table with entries in non-ascending order; read should return them sorted. + let mut tmp = NamedTempFile::new().expect("tempfile"); + writeln!(tmp, "threshold\tFDR\tshuffled_bp\treal_bp").unwrap(); + writeln!(tmp, "5.00\t0.010000\t10\t1000").unwrap(); + writeln!(tmp, "1.00\t0.500000\t100\t200").unwrap(); + writeln!(tmp, "-1.00\t1.000000\t1\t1").unwrap(); + tmp.flush().unwrap(); + let path = tmp.path().to_str().unwrap(); + + let table = read_fdr_table(path).expect("read"); + assert_eq!(table.len(), 3); + assert!(table[0].threshold < table[1].threshold); + assert!(table[1].threshold < table[2].threshold); +} diff --git a/tests/coordinate_lift_test.rs b/tests/coordinate_lift_test.rs index 586b7684..248482f9 100644 --- a/tests/coordinate_lift_test.rs +++ b/tests/coordinate_lift_test.rs @@ -6,7 +6,7 @@ use rust_htslib::bam::Read; fn test_coordinate_lift() -> anyhow::Result<()> { let mut bam = bam::Reader::from_path("tests/data/nuc_example.bam")?; - for result in bam.records() { + if let Some(result) = bam.records().next() { let record = result?; // Print basic record info @@ -67,8 +67,6 @@ fn test_coordinate_lift() -> anyhow::Result<()> { } 
 Err(e) => println!("Error lifting coordinates: {}", e),
 }
-
- break; // Only process first record
 }
 
 Ok(())
diff --git a/tests/fibertig_test.rs b/tests/fibertig_test.rs
index 9d505ec8..e811a24c 100644
--- a/tests/fibertig_test.rs
+++ b/tests/fibertig_test.rs
@@ -583,3 +583,63 @@ fn test_from_bam_tags_fa_pairing_reverse_strand_multi_peak() -> Result<()> {
 }
 Ok(())
}
+
+#[test]
+fn test_from_bam_tags_fa_pairing_reverse_strand_shared_start() -> Result<()> {
+ // Regression test for Anna's fibertig peak-length scrambling: two peaks share
+ // forward-strand start position 0 (fs=0), and all four peaks overlap heavily.
+ // After flipping to aligned coordinates the sort by aligned start is no longer
+ // equivalent to reversing the input order, so a naive `ann_vals.reverse()`
+ // breaks the extras pairing. The extras must be carried through the sort.
+ let seq_len: u32 = 1_000;
+ let fs: Vec<u32> = vec![0, 0, 12, 159];
+ let fl: Vec<u32> = vec![263, 124, 209, 305];
+ let fa = "peakA|peakB|peakC|peakD";
+
+ let record = build_tagged_record(seq_len, &fs, &fl, fa, true);
+
+ let anns = FiberAnnotations::from_bam_tags(&record, b"fs", b"fl", Some(b"fa"))?
+ .expect("tags should parse");
+
+ assert_eq!(anns.annotations.len(), 4);
+ assert!(anns.reverse);
+
+ // forward peak [s, s+l) -> ref peak [seq_len - (s+l), seq_len - s)
+ // peakA forward [0, 263) -> ref [737, 1000), len 263
+ // peakB forward [0, 124) -> ref [876, 1000), len 124
+ // peakC forward [12, 221) -> ref [779, 988), len 209
+ // peakD forward [159, 464) -> ref [536, 841), len 305
+ // Ref-start ascending: D, A, C, B.
+ let expected = [ + (536_i64, 841_i64, 305_i64, "peakD"), + (737, 1_000, 263, "peakA"), + (779, 988, 209, "peakC"), + (876, 1_000, 124, "peakB"), + ]; + for (i, (ann, (exp_start, exp_end, exp_len, exp_tag))) in + anns.annotations.iter().zip(expected.iter()).enumerate() + { + assert_eq!( + ann.reference_start, + Some(*exp_start), + "ref_start mismatch at index {i}" + ); + assert_eq!( + ann.reference_end, + Some(*exp_end), + "ref_end mismatch at index {i}" + ); + assert_eq!( + ann.reference_length, + Some(*exp_len), + "ref_length mismatch at index {i}" + ); + assert_eq!( + ann.extra_columns.as_ref().map(|c| c.join(";")), + Some((*exp_tag).to_string()), + "fa-extras pairing wrong at index {i} (size {exp_len}): got {:?}", + ann.extra_columns, + ); + } + Ok(()) +}