Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 102 additions & 6 deletions rust/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,29 @@ fn deduplicate_variants(variants: Vec<VariantCounts>, phased: bool) -> Vec<Varia
result
}

fn effective_phased(variants: &[VariantCounts], requested_phased: bool) -> bool {
if !requested_phased {
return false;
}

let mut saw_ref_alt = false;
let mut saw_alt_ref = false;
for variant in variants {
match variant.gt {
Some(0) => saw_ref_alt = true,
Some(1) => saw_alt_ref = true,
_ => {}
}
}

if saw_ref_alt && saw_alt_ref {
true
} else {
eprintln!("No informative numeric phased GT found: switching to unphased model");
false
}
}

/// Sigmoid function (inverse logit): expit(x) = 1 / (1 + e^-x)
///
/// Numerically stable implementation that avoids overflow for large |x|.
Expand Down Expand Up @@ -724,7 +747,7 @@ pub fn single_model(variants: Vec<VariantCounts>, phased: bool) -> Result<Vec<Im
// single_model uses one global scalar dispersion; pass it as a
// 1-element slice so the broadcast fallback keeps results identical.
optimize_prob_phased(&region_ref, &region_n, &region_gt, &[disp])?
} else if !phased && indices.len() > 1 {
} else if indices.len() > 1 {
// Unphased multi-SNP region: marginalize over phase via DP,
// minimizing opt_unphased_dp over prob in (0, 1)
// (first = index 0, phase = remaining), matching Python's
Expand Down Expand Up @@ -875,7 +898,7 @@ pub fn linear_model(variants: Vec<VariantCounts>, phased: bool) -> Result<Vec<Im
.map(|&i| variants[i].gt.expect("gt present (checked by all_gt)"))
.collect();
optimize_prob_phased(&region_ref, &region_n, &region_gt, &region_rho)?
} else if !phased && indices.len() > 1 {
} else if indices.len() > 1 {
// Unphased multi-SNP region: marginalize over phase via DP,
// minimizing opt_unphased_dp over prob in (0, 1)
// (first = index 0, phase = remaining), matching Python's
Expand Down Expand Up @@ -1165,12 +1188,14 @@ pub fn analyze_imbalance(

eprintln!("Processing {} variants after filtering", filtered.len());

let phased = effective_phased(&filtered, config.phased);

// Deduplicate variants for Python parity (single and linear models only).
// Python's get_imbalance() does: df[keep_cols].drop_duplicates()
// Per-donor model skips dedup as it's a Rust-only feature.
let deduped = if config.method != AnalysisMethod::PerDonor {
let before = filtered.len();
let after = deduplicate_variants(filtered, config.phased);
let after = deduplicate_variants(filtered, phased);
if after.len() < before {
eprintln!(
"Deduplicated {} → {} variants ({} duplicates removed)",
Expand All @@ -1186,9 +1211,9 @@ pub fn analyze_imbalance(

// Run analysis based on method
let mut results = match config.method {
AnalysisMethod::Single => single_model(deduped, config.phased)?,
AnalysisMethod::Linear => linear_model(deduped, config.phased)?,
AnalysisMethod::PerDonor => per_donor_model(deduped, config.phased)?,
AnalysisMethod::Single => single_model(deduped, phased)?,
AnalysisMethod::Linear => linear_model(deduped, phased)?,
AnalysisMethod::PerDonor => per_donor_model(deduped, phased)?,
};

// Remove pseudocounts from results
Expand Down Expand Up @@ -1342,6 +1367,77 @@ mod tests {
assert!(result_imbalanced.is_finite());
}

fn vc_peak(pos: u32, rc: u32, ac: u32, gt: Option<u8>) -> VariantCounts {
VariantCounts {
chrom: "chr1".to_string(),
pos,
ref_count: rc,
alt_count: ac,
region: "peak1".to_string(),
sample: Some("d1".to_string()),
gt,
}
}

fn assert_result_close(a: &ImbalanceResult, b: &ImbalanceResult) {
assert_eq!(a.region, b.region);
assert_eq!(a.ref_count, b.ref_count);
assert_eq!(a.alt_count, b.alt_count);
assert_eq!(a.n, b.n);
assert!((a.null_ll - b.null_ll).abs() < 1e-10);
assert!((a.alt_ll - b.alt_ll).abs() < 1e-10);
assert!((a.mu - b.mu).abs() < 1e-10);
assert!((a.lrt - b.lrt).abs() < 1e-10);
assert!((a.pval - b.pval).abs() < 1e-10);
}

#[test]
fn test_effective_phased_requires_both_numeric_orientations() {
let both = vec![vc_peak(100, 18, 4, Some(0)), vc_peak(200, 5, 15, Some(1))];
assert!(effective_phased(&both, true));
assert!(!effective_phased(&both, false));

let none = vec![vc_peak(100, 18, 4, None), vc_peak(200, 5, 15, None)];
assert!(!effective_phased(&none, true));

let same = vec![vc_peak(100, 18, 4, Some(0)), vc_peak(200, 12, 8, Some(0))];
assert!(!effective_phased(&same, true));

let partial = vec![vc_peak(100, 18, 4, Some(0)), vc_peak(200, 12, 8, None)];
assert!(!effective_phased(&partial, true));
}

#[test]
fn test_analyze_imbalance_all_same_numeric_gt_falls_back_to_unphased() {
let variants = vec![vc_peak(100, 18, 4, Some(0)), vc_peak(200, 12, 8, Some(0))];
let phased_cfg = AnalysisConfig {
min_count: 0,
pseudocount: 0,
method: AnalysisMethod::Single,
phased: true,
};
let unphased_cfg = AnalysisConfig {
phased: false,
..phased_cfg.clone()
};

let phased = analyze_imbalance(variants.clone(), &phased_cfg).unwrap();
let unphased = analyze_imbalance(variants, &unphased_cfg).unwrap();
assert_eq!(phased.len(), 1);
assert_eq!(unphased.len(), 1);
assert_result_close(&phased[0], &unphased[0]);
}

#[test]
fn test_single_model_missing_gt_uses_unphased_multi_snp_path() {
let variants = vec![vc_peak(100, 18, 4, None), vc_peak(200, 5, 15, None)];
let phased = single_model(variants.clone(), true).unwrap();
let unphased = single_model(variants, false).unwrap();
assert_eq!(phased.len(), 1);
assert_eq!(unphased.len(), 1);
assert_result_close(&phased[0], &unphased[0]);
}

// Helper: a per_donor VariantCounts for one donor "d1".
fn vc_d1(pos: u32, rc: u32, ac: u32, gt: u8) -> VariantCounts {
VariantCounts {
Expand Down
61 changes: 32 additions & 29 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ use bam_counter::BamCounter;
use bam_counter_sc::BamCounterSC;
use mapping_filter::filter_bam_wasp;

fn parse_numeric_phased_gt(gt: &str) -> Option<u8> {
match gt {
"0|1" => Some(0),
"1|0" => Some(1),
_ => None,
}
}

// ============================================================================
// PyO3 Bindings for BAM Remapping
// ============================================================================
Expand Down Expand Up @@ -307,8 +315,6 @@ fn analyze_imbalance(
let mut min_fields: usize = 7;
// GT column index (only present in 9-col format)
let mut gt_idx: Option<usize> = None;
// Ref allele column index (for deriving phase from GT alleles)
let mut ref_allele_idx: Option<usize> = None;
// Region column index (use input region names instead of chrom_pos format)
let mut region_idx: Option<usize> = None;
// Optional donor/sample column (used by the per_donor model)
Expand All @@ -330,7 +336,6 @@ fn analyze_imbalance(
alt_count_idx = 7;
min_fields = 9;
gt_idx = Some(5);
ref_allele_idx = Some(3);
}
// Detect an optional `sample` column by name (per_donor input)
sample_idx = headers.iter().position(|h| *h == "sample");
Expand Down Expand Up @@ -385,34 +390,12 @@ fn analyze_imbalance(
// homozygous, missing) is None. Emits ONLY Some(0)/Some(1)/None so
// downstream `1 - g` on u8 never underflows.
//
// Supports two GT formats:
// 1. Numeric: "0|1" or "1|0" (VCF standard)
// 2. Allele: "A|G" or "G|A" (allele strings) - derives phase by comparing to ref
// AHO parity: only numeric VCF-style phased heterozygotes are trusted.
// Allele-string GT values such as "A|G" are not interpreted as
// cross-SNV haplotype phase by the archived Python implementation.
let gt: Option<u8> = gt_idx.and_then(|gi| {
let gt_str = fields.get(gi)?;
// Try numeric format first (VCF standard)
match *gt_str {
"0|1" => return Some(0),
"1|0" => return Some(1),
_ => {}
}
// Try allele format: "X|Y" where X,Y are allele strings
// Compare to ref allele to derive phase
if let Some(ref_idx) = ref_allele_idx {
if let Some(&ref_allele) = fields.get(ref_idx) {
if let Some((allele1, allele2)) = gt_str.split_once('|') {
if allele1 == ref_allele && allele2 != ref_allele {
// ref on hap1 (e.g., GT="C|T" with ref="C")
return Some(0);
} else if allele2 == ref_allele && allele1 != ref_allele {
// ref on hap2 (e.g., GT="T|C" with ref="C")
return Some(1);
}
// Homozygous (both same) or neither matches ref -> None
}
}
}
None
parse_numeric_phased_gt(gt_str)
});

variants.push(analysis::VariantCounts {
Expand Down Expand Up @@ -1043,3 +1026,23 @@ fn filter_bam_wasp_with_sidecar(
expected_sidecar,
)
}

#[cfg(test)]
mod tests {
use super::parse_numeric_phased_gt;

#[test]
fn test_parse_numeric_phased_gt_accepts_only_aho_numeric_hets() {
assert_eq!(parse_numeric_phased_gt("0|1"), Some(0));
assert_eq!(parse_numeric_phased_gt("1|0"), Some(1));

for gt in ["A|G", "G|A", "0/1", "0|0", "1|1", "./.", ".|."] {
assert_eq!(
parse_numeric_phased_gt(gt),
None,
"{} should not be treated as phased GT",
gt
);
}
}
}
Loading