diff --git a/rust/src/analysis.rs b/rust/src/analysis.rs index f6f91f0..1e9963d 100644 --- a/rust/src/analysis.rs +++ b/rust/src/analysis.rs @@ -136,6 +136,29 @@ fn deduplicate_variants(variants: Vec, phased: bool) -> Vec bool { + if !requested_phased { + return false; + } + + let mut saw_ref_alt = false; + let mut saw_alt_ref = false; + for variant in variants { + match variant.gt { + Some(0) => saw_ref_alt = true, + Some(1) => saw_alt_ref = true, + _ => {} + } + } + + if saw_ref_alt && saw_alt_ref { + true + } else { + eprintln!("No informative numeric phased GT found: switching to unphased model"); + false + } +} + /// Sigmoid function (inverse logit): expit(x) = 1 / (1 + e^-x) /// /// Numerically stable implementation that avoids overflow for large |x|. @@ -724,7 +747,7 @@ pub fn single_model(variants: Vec, phased: bool) -> Result 1 { + } else if indices.len() > 1 { // Unphased multi-SNP region: marginalize over phase via DP, // minimizing opt_unphased_dp over prob in (0, 1) // (first = index 0, phase = remaining), matching Python's @@ -875,7 +898,7 @@ pub fn linear_model(variants: Vec, phased: bool) -> Result 1 { + } else if indices.len() > 1 { // Unphased multi-SNP region: marginalize over phase via DP, // minimizing opt_unphased_dp over prob in (0, 1) // (first = index 0, phase = remaining), matching Python's @@ -1165,12 +1188,14 @@ pub fn analyze_imbalance( eprintln!("Processing {} variants after filtering", filtered.len()); + let phased = effective_phased(&filtered, config.phased); + // Deduplicate variants for Python parity (single and linear models only). // Python's get_imbalance() does: df[keep_cols].drop_duplicates() // Per-donor model skips dedup as it's a Rust-only feature. 
let deduped = if config.method != AnalysisMethod::PerDonor { let before = filtered.len(); - let after = deduplicate_variants(filtered, config.phased); + let after = deduplicate_variants(filtered, phased); if after.len() < before { eprintln!( "Deduplicated {} → {} variants ({} duplicates removed)", @@ -1186,9 +1211,9 @@ pub fn analyze_imbalance( // Run analysis based on method let mut results = match config.method { - AnalysisMethod::Single => single_model(deduped, config.phased)?, - AnalysisMethod::Linear => linear_model(deduped, config.phased)?, - AnalysisMethod::PerDonor => per_donor_model(deduped, config.phased)?, + AnalysisMethod::Single => single_model(deduped, phased)?, + AnalysisMethod::Linear => linear_model(deduped, phased)?, + AnalysisMethod::PerDonor => per_donor_model(deduped, phased)?, }; // Remove pseudocounts from results @@ -1342,6 +1367,77 @@ mod tests { assert!(result_imbalanced.is_finite()); } + fn vc_peak(pos: u32, rc: u32, ac: u32, gt: Option<u8>) -> VariantCounts { + VariantCounts { + chrom: "chr1".to_string(), + pos, + ref_count: rc, + alt_count: ac, + region: "peak1".to_string(), + sample: Some("d1".to_string()), + gt, + } + } + + fn assert_result_close(a: &ImbalanceResult, b: &ImbalanceResult) { + assert_eq!(a.region, b.region); + assert_eq!(a.ref_count, b.ref_count); + assert_eq!(a.alt_count, b.alt_count); + assert_eq!(a.n, b.n); + assert!((a.null_ll - b.null_ll).abs() < 1e-10); + assert!((a.alt_ll - b.alt_ll).abs() < 1e-10); + assert!((a.mu - b.mu).abs() < 1e-10); + assert!((a.lrt - b.lrt).abs() < 1e-10); + assert!((a.pval - b.pval).abs() < 1e-10); + } + + #[test] + fn test_effective_phased_requires_both_numeric_orientations() { + let both = vec![vc_peak(100, 18, 4, Some(0)), vc_peak(200, 5, 15, Some(1))]; + assert!(effective_phased(&both, true)); + assert!(!effective_phased(&both, false)); + + let none = vec![vc_peak(100, 18, 4, None), vc_peak(200, 5, 15, None)]; + assert!(!effective_phased(&none, true)); + + let same = vec![vc_peak(100, 18, 
4, Some(0)), vc_peak(200, 12, 8, Some(0))]; + assert!(!effective_phased(&same, true)); + + let partial = vec![vc_peak(100, 18, 4, Some(0)), vc_peak(200, 12, 8, None)]; + assert!(!effective_phased(&partial, true)); + } + + #[test] + fn test_analyze_imbalance_all_same_numeric_gt_falls_back_to_unphased() { + let variants = vec![vc_peak(100, 18, 4, Some(0)), vc_peak(200, 12, 8, Some(0))]; + let phased_cfg = AnalysisConfig { + min_count: 0, + pseudocount: 0, + method: AnalysisMethod::Single, + phased: true, + }; + let unphased_cfg = AnalysisConfig { + phased: false, + ..phased_cfg.clone() + }; + + let phased = analyze_imbalance(variants.clone(), &phased_cfg).unwrap(); + let unphased = analyze_imbalance(variants, &unphased_cfg).unwrap(); + assert_eq!(phased.len(), 1); + assert_eq!(unphased.len(), 1); + assert_result_close(&phased[0], &unphased[0]); + } + + #[test] + fn test_single_model_missing_gt_uses_unphased_multi_snp_path() { + let variants = vec![vc_peak(100, 18, 4, None), vc_peak(200, 5, 15, None)]; + let phased = single_model(variants.clone(), true).unwrap(); + let unphased = single_model(variants, false).unwrap(); + assert_eq!(phased.len(), 1); + assert_eq!(unphased.len(), 1); + assert_result_close(&phased[0], &unphased[0]); + } + // Helper: a per_donor VariantCounts for one donor "d1". 
fn vc_d1(pos: u32, rc: u32, ac: u32, gt: u8) -> VariantCounts { VariantCounts { diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 6b2b116..7c44602 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -27,6 +27,14 @@ use bam_counter::BamCounter; use bam_counter_sc::BamCounterSC; use mapping_filter::filter_bam_wasp; +fn parse_numeric_phased_gt(gt: &str) -> Option<u8> { + match gt { + "0|1" => Some(0), + "1|0" => Some(1), + _ => None, + } +} + // ============================================================================ // PyO3 Bindings for BAM Remapping // ============================================================================ @@ -307,8 +315,6 @@ fn analyze_imbalance( let mut min_fields: usize = 7; // GT column index (only present in 9-col format) let mut gt_idx: Option<usize> = None; - // Ref allele column index (for deriving phase from GT alleles) - let mut ref_allele_idx: Option<usize> = None; // Region column index (use input region names instead of chrom_pos format) let mut region_idx: Option<usize> = None; // Optional donor/sample column (used by the per_donor model) @@ -330,7 +336,6 @@ fn analyze_imbalance( alt_count_idx = 7; min_fields = 9; gt_idx = Some(5); - ref_allele_idx = Some(3); } // Detect an optional `sample` column by name (per_donor input) sample_idx = headers.iter().position(|h| *h == "sample"); @@ -385,34 +390,12 @@ fn analyze_imbalance( // homozygous, missing) is None. Emits ONLY Some(0)/Some(1)/None so // downstream `1 - g` on u8 never underflows. // - // Supports two GT formats: - // 1. Numeric: "0|1" or "1|0" (VCF standard) - // 2. Allele: "A|G" or "G|A" (allele strings) - derives phase by comparing to ref + // AHO parity: only numeric VCF-style phased heterozygotes are trusted. + // Allele-string GT values such as "A|G" are not interpreted as + // cross-SNV haplotype phase by the archived Python implementation. 
let gt: Option<u8> = gt_idx.and_then(|gi| { let gt_str = fields.get(gi)?; - // Try numeric format first (VCF standard) - match *gt_str { - "0|1" => return Some(0), - "1|0" => return Some(1), - _ => {} - } - // Try allele format: "X|Y" where X,Y are allele strings - // Compare to ref allele to derive phase - if let Some(ref_idx) = ref_allele_idx { - if let Some(&ref_allele) = fields.get(ref_idx) { - if let Some((allele1, allele2)) = gt_str.split_once('|') { - if allele1 == ref_allele && allele2 != ref_allele { - // ref on hap1 (e.g., GT="C|T" with ref="C") - return Some(0); - } else if allele2 == ref_allele && allele1 != ref_allele { - // ref on hap2 (e.g., GT="T|C" with ref="C") - return Some(1); - } - // Homozygous (both same) or neither matches ref -> None - } - } - } - None + parse_numeric_phased_gt(gt_str) }); variants.push(analysis::VariantCounts { @@ -1043,3 +1026,23 @@ fn filter_bam_wasp_with_sidecar( expected_sidecar, ) } + +#[cfg(test)] +mod tests { + use super::parse_numeric_phased_gt; + + #[test] + fn test_parse_numeric_phased_gt_accepts_only_aho_numeric_hets() { + assert_eq!(parse_numeric_phased_gt("0|1"), Some(0)); + assert_eq!(parse_numeric_phased_gt("1|0"), Some(1)); + + for gt in ["A|G", "G|A", "0/1", "0|0", "1|1", "./.", ".|."] { + assert_eq!( + parse_numeric_phased_gt(gt), + None, + "{} should not be treated as phased GT", + gt + ); + } + } +}