diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index a141a29..c44894a 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -4,4 +4,11 @@ Please make sure that introduced changes are consistent with the testing api and ## Updating the versioning -Please add to `changelog.yaml` and then run `make changelog` before committing the results ONCE in this PR. +Add a towncrier fragment under `changelog.d/` with the format +`..md`, where `` is one of `added`, +`changed`, `fixed`, `removed`, or `breaking`. + +Do not edit `CHANGELOG.md` directly in feature PRs. After the PR is merged, +the versioning workflow runs `make changelog`, deletes the consumed fragments, +updates `CHANGELOG.md`, bumps `pyproject.toml`, and commits the result as +`Update package version`. diff --git a/.github/fetch_version.py b/.github/fetch_version.py new file mode 100644 index 0000000..975130f --- /dev/null +++ b/.github/fetch_version.py @@ -0,0 +1,12 @@ +# Note: Action must be run in Python 3.11 or later. +import tomllib + + +def fetch_version(): + with open("pyproject.toml", "rb") as f: + pyproject = tomllib.load(f) + return pyproject["project"]["version"] + + +if __name__ == "__main__": + print(fetch_version()) diff --git a/.github/workflows/changelog_entry.yaml b/.github/workflows/changelog_entry.yaml index 4cf6327..f51eeb2 100644 --- a/.github/workflows/changelog_entry.yaml +++ b/.github/workflows/changelog_entry.yaml @@ -10,12 +10,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 + with: + fetch-depth: 0 + - uses: actions/setup-python@v6 + with: + python-version: "3.14" + - name: Install towncrier + run: pip install towncrier - name: Check for changelog fragment - run: | - FRAGMENTS=$(find changelog.d -type f ! -name '.gitkeep' | wc -l) - if [ "$FRAGMENTS" -eq 0 ]; then - echo "::error::No changelog fragment found in changelog.d/" - echo "Add one with: echo 'Description.' > changelog.d/\$(git branch --show-current)..md" - echo "Types: added, changed, fixed, removed, breaking" - exit 1 - fi + run: towncrier check --compare-with origin/main diff --git a/.github/workflows/versioning.yaml b/.github/workflows/versioning.yaml index c1cab5b..bc207a1 100644 --- a/.github/workflows/versioning.yaml +++ b/.github/workflows/versioning.yaml @@ -1,15 +1,10 @@ -# Workflow that runs on versioning metadata updates. +# Workflow that bumps the package version after changes land on main. name: Versioning updates on: push: - branches: - - main - - paths: - - "changelog.d/**" - - "CHANGELOG.md" - - "!pyproject.toml" + branches: + - main jobs: Versioning: @@ -32,11 +27,15 @@ jobs: uses: actions/setup-python@v6 with: python-version: 3.14 + - name: Install uv + uses: astral-sh/setup-uv@v8.1.0 - name: Bump version and build changelog run: | pip install towncrier python .github/bump_version.py - towncrier build --yes --version $(python -c "import re; print(re.search(r'version = \"(.+?)\"', open('pyproject.toml').read()).group(1))") + towncrier build --yes --version $(python .github/fetch_version.py) + - name: Update lockfile + run: uv lock - name: Update changelog uses: EndBug/add-and-commit@v10 with: diff --git a/changelog.d/changed.1.md b/changelog.d/dashboard-paper-versioning.changed.md similarity index 100% rename from changelog.d/changed.1.md rename to changelog.d/dashboard-paper-versioning.changed.md diff --git a/changelog.d/release-dashboard-compatibility.fixed.md b/changelog.d/release-dashboard-compatibility.fixed.md new file mode 100644 index 0000000..85f5444 --- /dev/null +++ b/changelog.d/release-dashboard-compatibility.fixed.md @@ -0,0 +1 @@ +Fixed towncrier fragment validation and package versioning workflow triggers, and dashboard CSV compatibility with `metric_std` columns, benchmark-loss error bars, and donor/receiver distribution rows. diff --git a/examples/demo.ipynb b/examples/demo.ipynb index d53c154..d62e5e3 100644 --- a/examples/demo.ipynb +++ b/examples/demo.ipynb @@ -147,7 +147,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c0d546d0cd9e4e8fa9887e82834f01e5", + "model_id": "d674b691bcec45b0bd86c8c64944817b", "version_major": 2, "version_minor": 0 }, @@ -163,25 +163,21 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 19.0s finished\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 10.3s finished\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.4s finished\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.9s finished\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", "QuantReg does not support categorical variable 'risk_factor'. Skipping QuantReg for this fold.\n", - "[Parallel(n_jobs=-1)]: Batch computation too fast (0.02189326286315918s.) Setting batch_size=2.\n", "QuantReg does not support categorical variable 'risk_factor'. Skipping QuantReg for this fold.\n", + "[Parallel(n_jobs=-1)]: Batch computation too fast (0.018491029739379883s.) Setting batch_size=2.\n", "QuantReg does not support categorical variable 'risk_factor'. Skipping QuantReg for this fold.\n", "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.0s finished\n", "QuantReg cannot handle the provided variable types. Returning NaN results.\n", - "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 2.3s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.4s finished\n", + "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 2.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.1s finished\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "Using neural network for categorical variable 'risk_factor'. This may be computationally expensive for simple classification tasks.\n", - "Using neural network for categorical variable 'risk_factor'. This may be computationally expensive for simple classification tasks.\n", - "Using neural network for categorical variable 'risk_factor'. This may be computationally expensive for simple classification tasks.\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 2.0min finished\n", - "Model QuantReg cannot handle categorical variable 'risk_factor' (type: categorical). Skipping.\n", - "Using neural network for categorical variable 'risk_factor'. This may be computationally expensive for simple classification tasks.\n" + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 2.3s finished\n", + "Model QuantReg cannot handle categorical variable 'risk_factor' (type: categorical). Skipping.\n" ] }, { @@ -279,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -293,25 +289,25 @@ "alignmentgroup": "True", "error_y": { "array": [ - 0.00012760939421510863, - 0.0004562342768109314, - 0.0006101177084331859, - 0.0008908200932593258, - 0.0008790826663963316, - 0.0008918101604308018, - 0.0014014359931475851, - 0.0011450580119986484, - 0.0008347489157011024, - 0.0028336243053131253, - 0.0010015077062397809, - 0.0010483079782938722, - 0.0009641203180412462, - 0.0013634329233451873, - 0.0010973146976260759, - 0.0014533564186845864, - 0.0015844859370791663, - 0.0017969190479432422, - 0.001729443006175829 + 0.0002954494834297407, + 0.0004494725976682979, + 0.0006767372105484908, + 0.0006659311774118407, + 0.000797455897020639, + 0.0006892183293419389, + 0.0006110414541535603, + 0.0004648896396357678, + 0.0003488530234654882, + 0.00021606041908505509, + 0.0009485255438535998, + 0.0011921059039827593, + 0.0009797575621752596, + 0.0013427678136997366, + 0.0015090067957215783, + 0.0013181650583408751, + 0.0013966635660512038, + 0.0011512709649561482, + 0.000718524644778864 ] }, "hovertemplate": "Method=QRF
Quantiles=%{x}
Quantile loss=%{y}", @@ -351,25 +347,25 @@ ], "xaxis": "x", "y": [ - 0.004976762851613436, - 0.007958901707037789, - 0.011348256884255009, - 0.014479045607066829, - 0.01804729212868635, - 0.020713966247658862, - 0.022742931417508282, - 0.02258721277917371, - 0.023351581275292608, - 0.023625607730980425, - 0.023186470004863385, - 0.022905432991342744, - 0.022230731438156376, - 0.02057637886132135, - 0.01912773338390523, - 0.01739101077609373, - 0.013723503511700213, - 0.010967503180152355, - 0.0071899999204732045 + 0.004271342555413982, + 0.007428034547984062, + 0.010187107832753436, + 0.012221632376959864, + 0.014165233279944305, + 0.015597517958153539, + 0.017052829175001196, + 0.017946394869311873, + 0.019009879868497417, + 0.01921918586418314, + 0.019641689082866752, + 0.020026433113742494, + 0.019949010245087648, + 0.018709024525545386, + 0.017334992589638793, + 0.015680622528302796, + 0.013344364468131665, + 0.010172648876622744, + 0.006303629655361205 ], "yaxis": "y" }, @@ -377,25 +373,25 @@ "alignmentgroup": "True", "error_y": { "array": [ - 0.0002572352193697227, - 0.00048039464414164327, - 0.0006001743846790527, - 0.000650308596304022, - 0.0006978935830309139, - 0.0007422089956084065, - 0.0007773387618231861, - 0.0007411256769677033, - 0.0006987078147417422, - 0.0006534052072665708, - 0.0006444268979832439, - 0.0007126378391709421, - 0.0007802994397195336, - 0.0007929249621406207, - 0.0008151584886854951, - 0.0008494206242740293, - 0.0008143581430293128, - 0.0007222900955084952, - 0.0005814317154274587 + 0.00025613778102594154, + 0.0004932847354406371, + 0.0006178525384013924, + 0.0006680601100816272, + 0.0007179699587303018, + 0.0007563832159868537, + 0.0007924247103551689, + 0.0007524835090258524, + 0.0007044019728328267, + 0.0006534052072665718, + 0.0006392917138645394, + 0.0007047800184399004, + 0.0007688291042545453, + 0.0007736323445815647, + 0.0007944006158029734, + 0.0008246797845019394, + 0.0007844243245779494, + 0.0006932726633133631, + 0.0005483300517176163 ] }, "hovertemplate": "Method=OLS
Quantiles=%{x}
Quantile loss=%{y}", @@ -435,25 +431,25 @@ ], "xaxis": "x", "y": [ - 0.003875743024333408, - 0.006636892853346667, - 0.008964045456920146, - 0.010925268697485854, - 0.012586171018988979, - 0.013971078532560128, - 0.015078371048174203, - 0.015960673093506517, - 0.016611314804344936, + 0.003893152124267097, + 0.006644524109360325, + 0.00896349107737936, + 0.010921624680729358, + 0.012579225806088556, + 0.013962675026290615, + 0.015068898243069617, + 0.015952160041557438, + 0.016607551016941856, 0.017023427694261576, - 0.017162262823655853, - 0.016994331164241493, - 0.016563877768320707, - 0.01584349745574873, - 0.014830479499612956, - 0.013414418494530157, - 0.011604173888267462, - 0.009116693282835488, - 0.005653910575626517 + 0.017166369336370075, + 0.0170018974995418, + 0.016575028362549146, + 0.01585947135458786, + 0.01484794071481238, + 0.013431599655628922, + 0.011613855843838733, + 0.009106741902057723, + 0.005641799961590206 ], "yaxis": "y" }, @@ -545,25 +541,25 @@ "alignmentgroup": "True", "error_y": { "array": [ - 0.02276425224156018, - 0.02374984149826596, - 0.023084856392378967, - 0.022307702038882615, - 0.02104868709294386, - 0.019742768859700437, - 0.018880872020786737, - 0.017381494038930252, - 0.016374368504772197, - 0.015625242056816976, - 0.014519642274897656, - 0.01346810944783479, - 0.012128080639775956, - 0.009851861037881939, - 0.007636455453969302, - 0.005439968754166566, - 0.004359787415105129, - 0.006260954513825139, - 0.009076313552330278 + 0.023634466974836014, + 0.02406765247286279, + 0.023312476846387983, + 0.02246773054722902, + 0.021533579863950953, + 0.0201636007675073, + 0.019469145210195465, + 0.018271876616327277, + 0.01751128089779416, + 0.01688031093540707, + 0.015947278904868863, + 0.01487226715759355, + 0.013520743498898678, + 0.011400094127013597, + 0.00895173015168919, + 0.006476305134543658, + 0.004649374721034495, + 0.0057702207334363225, + 0.007970298230787725 ] }, "hovertemplate": "Method=MDN
Quantiles=%{x}
Quantile loss=%{y}", @@ -603,25 +599,25 @@ ], "xaxis": "x", "y": [ - 0.05582848269971269, - 0.08056099140545277, - 0.09775022131702664, - 0.10997597210140204, - 0.11835481941263067, - 0.12392234475637211, - 0.12720960013038832, - 0.1290198647237737, - 0.12931552057686446, - 0.1282395774202653, - 0.1255249036833294, - 0.12172155095625699, - 0.11708496977122927, - 0.11131352038546956, - 0.10380083044756254, - 0.09403407459425839, - 0.08197911966698565, - 0.06631822473122849, - 0.04608625541782032 + 0.056118017725221105, + 0.08092895706004721, + 0.09844626060872501, + 0.11080375568529617, + 0.11943998822234464, + 0.12474517993692051, + 0.1282511384465024, + 0.12997205081150148, + 0.13017048620008212, + 0.128751062237184, + 0.12619976875856675, + 0.12228579865349533, + 0.11760708775495192, + 0.11194306461770354, + 0.10452484110699278, + 0.09456211813800362, + 0.0824242018971309, + 0.06644217246886804, + 0.045680006163049104 ], "yaxis": "y" } @@ -648,8 +644,8 @@ "type": "line", "x0": -0.5, "x1": 18.5, - "y0": 0.017217385405120095, - "y1": 0.017217385405120095 + "y0": 0.014645345969131699, + "y1": 0.014645345969131699 }, { "line": { @@ -661,8 +657,8 @@ "type": "line", "x0": -0.5, "x1": 18.5, - "y0": 0.012779822693513777, - "y1": 0.012779822693513777 + "y0": 0.012782180760574876, + "y1": 0.012782180760574876 }, { "line": { @@ -687,8 +683,8 @@ "type": "line", "x0": -0.5, "x1": 18.5, - "y0": 0.10358109706305417, - "y1": 0.10358109706305417 + "y0": 0.10417347139434667, + "y1": 0.10417347139434667 } ], "template": { @@ -1662,7 +1658,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9cc406a366bd4c65be0d6c70047d472f", + "model_id": "69ae187f19074039bb37328100915dd1", "version_major": 2, "version_minor": 0 }, @@ -1679,10 +1675,10 @@ "text": [ "Predictor importance results:\n", " predictor_removed relative_impact\n", - "1 sex 30.680328\n", - "2 bmi 0.032479\n", - "0 age 0.004546\n", - "3 bp -0.004827\n" + "1 sex 30.680431\n", + "2 bmi 0.032473\n", + "0 age 0.004772\n", + "3 bp -0.004654\n" ] } ], @@ -1717,7 +1713,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "53ea8275012141b09d5d89527c071ba1", + "model_id": "c19e2a36feff41769985072ba1b459b2", "version_major": 2, "version_minor": 0 }, @@ -1734,7 +1730,7 @@ "text": [ "Optimal predictor order: ['sex', 'bmi', 'age', 'bp']\n", "Optimal subset: ['sex', 'bmi', 'age']\n", - "Optimal loss: 2.409709\n" + "Optimal loss: 2.409708\n" ] } ], @@ -1776,28 +1772,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "Formatted DataFrame shape: (49, 9)\n", + "Formatted DataFrame shape: (436, 9)\n", "Result types included: \n", - "['distribution_distance', 'predictor_correlation', 'predictor_target_mi',\n", - " 'predictor_importance', 'progressive_inclusion']\n", - "Length: 5, dtype: str\n" + "[ 'benchmark_loss', 'distribution_distance', 'predictor_correlation',\n", + " 'predictor_target_mi', 'predictor_importance', 'progressive_inclusion',\n", + " 'distribution_bins']\n", + "Length: 7, dtype: str\n" ] } ], "source": [ "output_path = \"microimputation_results.csv\"\n", "\n", - "autoimpute_dict = {\"cv_results\": autoimpute_results.cv_results}\n", - "\n", "formatted_df = format_csv(\n", " output_path=output_path,\n", - " autoimpute_result=autoimpute_dict,\n", + " autoimpute_result=autoimpute_results,\n", " comparison_metrics_df=None,\n", " distribution_comparison_df=distribution_comparison_df,\n", " predictor_correlations=predictor_correlations,\n", " predictor_importance_df=predictor_importance_df,\n", " progressive_inclusion_df=progressive_inclusion_df,\n", " best_method_name=best_method_name,\n", + " donor_data=donor_data,\n", + " receiver_data=autoimpute_results.receiver_data,\n", + " imputed_variables=imputed_variables,\n", ")\n", "\n", "print(f\"Formatted DataFrame shape: {formatted_df.shape}\")\n", diff --git a/examples/pipeline.py b/examples/pipeline.py index 9cb45e5..ceec8c5 100644 --- a/examples/pipeline.py +++ b/examples/pipeline.py @@ -99,7 +99,7 @@ def categorize_risk(s4_value): print(f"Receiver data shape: {receiver_data_without_targets.shape}") print(f"Predictors: {predictors}") print(f"Variables to impute: {imputed_variables}") - print(f"Risk factor distribution in donor data:") + print("Risk factor distribution in donor data:") print(donor_data["risk_factor"].value_counts()) print() @@ -225,13 +225,10 @@ def categorize_risk(s4_value): print("STEP 7: Formatting results for dashboard visualization...") print("-" * 80) - # Convert autoimpute_results to dictionary format expected by format_csv - autoimpute_dict = {"cv_results": autoimpute_results.cv_results} - # Format all results formatted_df = format_csv( output_path=output_path, - autoimpute_result=autoimpute_dict, + autoimpute_result=autoimpute_results, comparison_metrics_df=None, # cv_results already contain this info distribution_comparison_df=distribution_comparison_df, predictor_correlations=predictor_correlations, @@ -259,8 +256,8 @@ def categorize_risk(s4_value): print(f" - Best imputation method: {best_method_name}") print(f" - Number of predictors analyzed: {len(predictors)}") print(f" - Number of imputed variables: {len(imputed_variables)}") - print(f" - Numerical variables: s1, s4") - print(f" - Categorical variables: risk_factor") + print(" - Numerical variables: s1, s4") + print(" - Categorical variables: risk_factor") print() print("Output CSV contains:") for result_type in formatted_df["type"].unique(): diff --git a/microimputation-dashboard/components/BenchmarkLossCharts.tsx b/microimputation-dashboard/components/BenchmarkLossCharts.tsx index d900f10..cd21b74 100644 --- a/microimputation-dashboard/components/BenchmarkLossCharts.tsx +++ b/microimputation-dashboard/components/BenchmarkLossCharts.tsx @@ -8,6 +8,7 @@ import { XAxis, YAxis, CartesianGrid, + ErrorBar, Tooltip, Legend, ResponsiveContainer, @@ -20,6 +21,12 @@ interface BenchmarkLossChartsProps { data: ImputationDataPoint[]; } +const ERROR_BAR_STROKE = '#374151'; + +function isFiniteNumber(value: unknown): value is number { + return typeof value === 'number' && Number.isFinite(value); +} + export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) { // Filter for benchmark_loss data const benchmarkData = useMemo(() => { @@ -63,7 +70,7 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) if (quantileLossData.length === 0) return []; // Group by quantile - const quantileMap = new Map>(); + const quantileMap = new Map>(); quantileLossData.forEach(d => { const quantile = Number(d.quantile); @@ -72,6 +79,9 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) } const entry = quantileMap.get(quantile)!; entry[d.method] = d.metric_value; + if (isFiniteNumber(d.metric_std)) { + entry[`${d.method}__std`] = d.metric_std; + } }); return Array.from(quantileMap.values()).sort( @@ -79,30 +89,50 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) ); }, [quantileLossData]); + const hasQuantileErrorBarsByMethod = useMemo(() => { + const result = new Map(); + methods.forEach(method => { + result.set( + method, + quantileChartData.some(row => isFiniteNumber(row[`${method}__std`])) + ); + }); + return result; + }, [methods, quantileChartData]); + // Transform log loss data for bar chart const logLossChartData = useMemo(() => { if (logLossData.length === 0) return []; // Average log loss per method - const methodMap = new Map(); + const methodMap = new Map(); logLossData.forEach(d => { if (d.metric_value !== null) { if (!methodMap.has(d.method)) { - methodMap.set(d.method, { sum: 0, count: 0 }); + methodMap.set(d.method, { sum: 0, count: 0, stdSum: 0, stdCount: 0 }); } const entry = methodMap.get(d.method)!; entry.sum += d.metric_value; entry.count += 1; + if (isFiniteNumber(d.metric_std)) { + entry.stdSum += d.metric_std; + entry.stdCount += 1; + } } }); - return Array.from(methodMap.entries()).map(([method, { sum, count }]) => ({ + return Array.from(methodMap.entries()).map(([method, { sum, count, stdSum, stdCount }]) => ({ method, value: sum / count, + std: stdCount > 0 ? stdSum / stdCount : undefined, })); }, [logLossData]); + const hasLogLossErrorBars = useMemo(() => { + return logLossChartData.some(row => isFiniteNumber(row.std)); + }, [logLossChartData]); + // Determine best performing model const bestModel = useMemo(() => { if (methods.length === 0) return null; @@ -253,13 +283,24 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) if (!selectedMethod) return { quantile: [], logLoss: [] }; // Quantile loss train vs test - const quantileTrainTest: Array<{ quantile: string; train: number | null; test: number | null }> = []; + const quantileTrainTest: Array<{ + quantile: string; + train: number | null; + test: number | null; + trainStd?: number; + testStd?: number; + }> = []; const quantileData = benchmarkData.filter( d => d.method === selectedMethod && d.metric_name === 'quantile_loss' ); if (quantileData.length > 0) { - const quantileMap = new Map(); + const quantileMap = new Map(); quantileData.forEach(d => { const q = typeof d.quantile === 'number' ? d.quantile.toFixed(2) : String(d.quantile || ''); @@ -270,8 +311,14 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) quantileMap.set(q, { train: null, test: null }); } const entry = quantileMap.get(q)!; - if (d.split === 'train') entry.train = d.metric_value; - if (d.split === 'test') entry.test = d.metric_value; + if (d.split === 'train') { + entry.train = d.metric_value; + if (isFiniteNumber(d.metric_std)) entry.trainStd = d.metric_std; + } + if (d.split === 'test') { + entry.test = d.metric_value; + if (isFiniteNumber(d.metric_std)) entry.testStd = d.metric_std; + } }); quantileMap.forEach((value, quantile) => { @@ -282,7 +329,13 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) } // Log loss train vs test (average across variables) - const logLossTrainTest: Array<{ category: string; train: number; test: number }> = []; + const logLossTrainTest: Array<{ + category: string; + train: number; + test: number; + trainStd?: number; + testStd?: number; + }> = []; const logData = benchmarkData.filter( d => d.method === selectedMethod && d.metric_name === 'log_loss' && d.metric_value !== null ); @@ -290,20 +343,32 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) if (logData.length > 0) { const trainVals: number[] = []; const testVals: number[] = []; + const trainStdVals: number[] = []; + const testStdVals: number[] = []; logData.forEach(d => { - if (d.split === 'train') trainVals.push(d.metric_value!); - if (d.split === 'test') testVals.push(d.metric_value!); + if (d.split === 'train') { + trainVals.push(d.metric_value!); + if (isFiniteNumber(d.metric_std)) trainStdVals.push(d.metric_std); + } + if (d.split === 'test') { + testVals.push(d.metric_value!); + if (isFiniteNumber(d.metric_std)) testStdVals.push(d.metric_std); + } }); if (trainVals.length > 0 || testVals.length > 0) { const trainAvg = trainVals.length > 0 ? trainVals.reduce((a, b) => a + b, 0) / trainVals.length : 0; const testAvg = testVals.length > 0 ? testVals.reduce((a, b) => a + b, 0) / testVals.length : 0; + const trainStdAvg = trainStdVals.length > 0 ? trainStdVals.reduce((a, b) => a + b, 0) / trainStdVals.length : undefined; + const testStdAvg = testStdVals.length > 0 ? testStdVals.reduce((a, b) => a + b, 0) / testStdVals.length : undefined; logLossTrainTest.push({ category: 'Average', train: trainAvg, test: testAvg, + trainStd: trainStdAvg, + testStd: testStdAvg, }); } } @@ -316,6 +381,10 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) const hasQuantileTrainTest = trainTestData.quantile.length > 0; const hasLogLossTrainTest = trainTestData.logLoss.length > 0; + const hasQuantileTrainErrorBars = trainTestData.quantile.some(row => isFiniteNumber(row.trainStd)); + const hasQuantileTestErrorBars = trainTestData.quantile.some(row => isFiniteNumber(row.testStd)); + const hasLogLossTrainErrorBars = trainTestData.logLoss.some(row => isFiniteNumber(row.trainStd)); + const hasLogLossTestErrorBars = trainTestData.logLoss.some(row => isFiniteNumber(row.testStd)); // Filter methods that have train/test data const methodsWithData = useMemo(() => { @@ -446,7 +515,15 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) dataKey={method} fill={getMethodColor(method, index)} name={method} - /> + > + {hasQuantileErrorBarsByMethod.get(method) && ( + + )} + ))} @@ -496,6 +573,13 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) formatter={(value: number) => [value.toFixed(6), 'Log loss']} /> + {hasLogLossErrorBars && ( + + )} {logLossChartData.map((entry) => { const globalIndex = methods.indexOf(entry.method); return ( @@ -580,8 +664,24 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) formatter={(value: number) => value.toFixed(6)} /> } /> - - + + {hasQuantileTrainErrorBars && ( + + )} + + + {hasQuantileTestErrorBars && ( + + )} + @@ -618,8 +718,24 @@ export default function BenchmarkLossCharts({ data }: BenchmarkLossChartsProps) formatter={(value: number) => value.toFixed(6)} /> } /> - - + + {hasLogLossTrainErrorBars && ( + + )} + + + {hasLogLossTestErrorBars && ( + + )} + diff --git a/microimputation-dashboard/components/DistributionOverlay.tsx b/microimputation-dashboard/components/DistributionOverlay.tsx index 0958801..183cb2d 100644 --- a/microimputation-dashboard/components/DistributionOverlay.tsx +++ b/microimputation-dashboard/components/DistributionOverlay.tsx @@ -1,6 +1,6 @@ 'use client'; -import { useMemo, useState } from 'react'; +import { useEffect, useMemo, useState } from 'react'; import { ImputationDataPoint } from '@/types/imputation'; import { BarChart, @@ -108,7 +108,7 @@ export default function DistributionOverlay({ } else if (d.metric_name === 'categorical_distribution') { // Categorical variable (distributions[variable].data as CategoryData[]).push({ - category: info.category, + category: String(info.category), donorProportion: info.donor_proportion, receiverProportion: info.receiver_proportion, }); @@ -135,6 +135,12 @@ export default function DistributionOverlay({ variables[0] || '' ); + useEffect(() => { + if (variables.length > 0 && !variables.includes(selectedVariable)) { + setSelectedVariable(variables[0]); + } + }, [selectedVariable, variables]); + if (variables.length === 0) { return null; } @@ -274,7 +280,7 @@ export default function DistributionOverlay({ contentStyle={{ color: '#000000' }} labelStyle={{ color: '#000000' }} /> - } /> + } /> { try { @@ -177,7 +178,7 @@ export default function FileUpload({ } finally { setIsLoading(false); } - }, [onFileLoad, onDeeplinkLoadComplete, loadArtifactFromDeeplink]); + }, [onFileLoad, onDeeplinkLoadComplete, loadArtifactFromDeeplink, generatePreview]); // Handle deeplink loading on mount useEffect(() => { @@ -195,45 +196,79 @@ export default function FileUpload({ function validateCSVFormat(content: string): void { // Expected column names for microimputation dashboard - const EXPECTED_COLUMNS = [ + const REQUIRED_COLUMNS = [ + 'type', + 'method', + 'variable', + 'quantile', + 'metric_name', + 'metric_value', + 'split', + 'additional_info' + ]; + const OPTIONAL_COLUMNS = ['metric_std']; + const NEW_FORMAT_COLUMNS = [ 'type', 'method', 'variable', 'quantile', 'metric_name', 'metric_value', + 'metric_std', 'split', 'additional_info' ]; - const lines = content.trim().split('\n'); + const parsed = Papa.parse(content.trim(), { + header: false, + skipEmptyLines: true, + }); + + if (parsed.errors.length > 0) { + const firstError = parsed.errors[0]; + throw new Error( + `Invalid CSV format: ${firstError.message} on row ${firstError.row ?? 'unknown'}.` + ); + } + + const rows = parsed.data.filter(row => row.length > 0); + + if (rows.length < 2) { + throw new Error('The file must contain at least a header row and one data row.'); + } // Parse header - const header = lines[0].split(',').map(col => col.trim()); + const header = rows[0].map(col => col.trim()); // Check if all expected columns are present - const missingColumns = EXPECTED_COLUMNS.filter(col => !header.includes(col)); + const missingColumns = REQUIRED_COLUMNS.filter(col => !header.includes(col)); if (missingColumns.length > 0) { throw new Error( - `Invalid CSV format\nMissing required columns: ${missingColumns.join(', ')}\nExpected columns: ${EXPECTED_COLUMNS.join(', ')}` + `Invalid CSV format\nMissing required columns: ${missingColumns.join(', ')}\nExpected columns: ${NEW_FORMAT_COLUMNS.join(', ')}` ); } // Check for extra columns - const extraColumns = header.filter(col => !EXPECTED_COLUMNS.includes(col)); + const allowedColumns = [...REQUIRED_COLUMNS, ...OPTIONAL_COLUMNS]; + const extraColumns = header.filter(col => !allowedColumns.includes(col)); if (extraColumns.length > 0) { throw new Error( `Invalid CSV format: Unexpected columns found: ${extraColumns.join(', ')}. ` + - `Expected columns: ${EXPECTED_COLUMNS.join(', ')}` + `Expected columns: ${NEW_FORMAT_COLUMNS.join(', ')}` ); } // Check column order matches expected - const columnOrderMismatch = EXPECTED_COLUMNS.some((col, idx) => header[idx] !== col); + const expectedOrder = header.includes('metric_std') + ? NEW_FORMAT_COLUMNS + : REQUIRED_COLUMNS; + const columnOrderMismatch = + header.length !== expectedOrder.length || + expectedOrder.some((col, idx) => header[idx] !== col); if (columnOrderMismatch) { throw new Error( `Invalid CSV format: Columns are not in the expected order. ` + - `Expected order: ${EXPECTED_COLUMNS.join(', ')}` + `Expected order: ${NEW_FORMAT_COLUMNS.join(', ')}` ); } @@ -242,33 +277,21 @@ export default function FileUpload({ // Validate JSON format in additional_info column for a sample of rows // We'll check the first 10 data rows to avoid processing large files - const maxRowsToCheck = Math.min(10, lines.length - 1); + const maxRowsToCheck = Math.min(10, rows.length - 1); for (let i = 1; i <= maxRowsToCheck; i++) { - const line = lines[i]; - if (!line.trim()) continue; // Skip empty lines - - // Parse CSV row (basic parsing - doesn't handle quoted commas) - const columns = line.split(','); + const columns = rows[i]; - if (columns.length !== EXPECTED_COLUMNS.length) { + if (columns.length !== expectedOrder.length) { throw new Error( `Invalid CSV format: Row ${i + 1} has ${columns.length} columns, ` + - `but ${EXPECTED_COLUMNS.length} columns are expected.` + `but ${expectedOrder.length} columns are expected.` ); } - let additionalInfo = columns[additionalInfoIndex]?.trim(); + const additionalInfo = columns[additionalInfoIndex]?.trim(); if (additionalInfo) { - // Handle CSV quote escaping: pandas doubles quotes when writing to CSV - // e.g., {"key": "value"} becomes "{""key"": ""value""}" in the CSV - // Remove outer quotes if present and unescape inner quotes - if (additionalInfo.startsWith('"') && additionalInfo.endsWith('"')) { - additionalInfo = additionalInfo.slice(1, -1); // Remove outer quotes - additionalInfo = additionalInfo.replace(/""/g, '"'); // Unescape doubled quotes - } - try { JSON.parse(additionalInfo); } catch { @@ -1187,4 +1210,4 @@ export default function FileUpload({ ); -} \ No newline at end of file +} diff --git a/microimputation-dashboard/components/ImputationResults.tsx b/microimputation-dashboard/components/ImputationResults.tsx index c048221..00ec3e1 100644 --- a/microimputation-dashboard/components/ImputationResults.tsx +++ b/microimputation-dashboard/components/ImputationResults.tsx @@ -46,7 +46,7 @@ export default function ImputationResults({ data }: ImputationResultsProps) { if (info.bin_end !== undefined) { ranges[variable].max = Math.max(ranges[variable].max, info.bin_end); } - } catch (e) { + } catch { // Ignore parsing errors } }); @@ -93,8 +93,9 @@ export default function ImputationResults({ data }: ImputationResultsProps) { const hasWasserstein = wassersteinData.length > 0; const hasKLDivergence = klDivergenceData.length > 0; + const hasDistributionBins = data.some(d => d.type === 'distribution_bins'); - if (!hasWasserstein && !hasKLDivergence) { + if (!hasWasserstein && !hasKLDivergence && !hasDistributionBins) { return null; } diff --git a/microimputation-dashboard/components/PerVariableCharts.tsx b/microimputation-dashboard/components/PerVariableCharts.tsx index 52a7ffc..4ffe344 100644 --- a/microimputation-dashboard/components/PerVariableCharts.tsx +++ b/microimputation-dashboard/components/PerVariableCharts.tsx @@ -8,6 +8,7 @@ import { XAxis, YAxis, CartesianGrid, + ErrorBar, Tooltip, Legend, ResponsiveContainer, @@ -22,6 +23,12 @@ interface PerVariableChartsProps { metricType: 'quantile_loss' | 'log_loss'; } +const ERROR_BAR_STROKE = '#374151'; + +function isFiniteNumber(value: unknown): value is number { + return typeof value === 'number' && Number.isFinite(value); +} + export default function PerVariableCharts({ data, variable, @@ -56,7 +63,7 @@ export default function PerVariableCharts({ typeof d.quantile === 'number' && d.quantile >= 0 && d.quantile <= 1 ); - const quantileMap = new Map>(); + const quantileMap = new Map>(); numericData.forEach((d) => { const quantile = Number(d.quantile); @@ -65,6 +72,9 @@ export default function PerVariableCharts({ } const entry = quantileMap.get(quantile)!; entry[d.method] = d.metric_value; + if (isFiniteNumber(d.metric_std)) { + entry[`${d.method}__std`] = d.metric_std; + } }); return Array.from(quantileMap.values()).sort( @@ -72,31 +82,51 @@ export default function PerVariableCharts({ ); }, [variableData, metricType]); + const hasQuantileErrorBarsByMethod = useMemo(() => { + const result = new Map(); + methods.forEach((method) => { + result.set( + method, + quantileChartData.some((row) => isFiniteNumber(row[`${method}__std`])) + ); + }); + return result; + }, [methods, quantileChartData]); + // For categorical variables (log_loss), show simple bar comparison const logLossChartData = useMemo(() => { if (metricType !== 'log_loss') return []; - const methodMap = new Map(); + const methodMap = new Map(); variableData.forEach((d) => { if (d.metric_value !== null) { if (!methodMap.has(d.method)) { - methodMap.set(d.method, { sum: 0, count: 0 }); + methodMap.set(d.method, { sum: 0, count: 0, stdSum: 0, stdCount: 0 }); } const entry = methodMap.get(d.method)!; entry.sum += d.metric_value; entry.count += 1; + if (isFiniteNumber(d.metric_std)) { + entry.stdSum += d.metric_std; + entry.stdCount += 1; + } } }); return Array.from(methodMap.entries()).map( - ([method, { sum, count }]) => ({ + ([method, { sum, count, stdSum, stdCount }]) => ({ method, value: sum / count, + std: stdCount > 0 ? stdSum / stdCount : undefined, }) ); }, [variableData, metricType]); + const hasLogLossErrorBars = useMemo(() => { + return logLossChartData.some((row) => isFiniteNumber(row.std)); + }, [logLossChartData]); + if (variableData.length === 0) { return (
@@ -156,7 +186,15 @@ export default function PerVariableCharts({ dataKey={method} fill={getMethodColor(method, globalIndex >= 0 ? globalIndex : 0)} name={method} - /> + > + {hasQuantileErrorBarsByMethod.get(method) && ( + + )} + ); })} @@ -197,6 +235,13 @@ export default function PerVariableCharts({ formatter={(value: number) => [value.toFixed(6), 'Log Loss']} /> + {hasLogLossErrorBars && ( + + )} {logLossChartData.map((entry) => { const globalIndex = allMethods.indexOf(entry.method); return ( diff --git a/microimputation-dashboard/components/VisualizationDashboard.tsx b/microimputation-dashboard/components/VisualizationDashboard.tsx index 653c9b4..435bb47 100644 --- a/microimputation-dashboard/components/VisualizationDashboard.tsx +++ b/microimputation-dashboard/components/VisualizationDashboard.tsx @@ -1,6 +1,6 @@ 'use client'; -import { useMemo, useState } from 'react'; +import { useEffect, useMemo, useState } from 'react'; import { ImputationDataPoint } from '@/types/imputation'; import { GitHubArtifactInfo, createShareableUrl } from '@/utils/deeplinks'; import BenchmarkLossCharts from './BenchmarkLossCharts'; @@ -83,9 +83,11 @@ export default function VisualizationDashboard({ // Check for actual distribution distance data (wasserstein or kl_divergence) const distributionData = data.filter(d => d.type === 'distribution_distance'); + const distributionBinsData = data.filter(d => d.type === 'distribution_bins'); const hasWasserstein = distributionData.some(d => d.metric_name === 'wasserstein_distance' && d.metric_value !== null); const hasKLDivergence = distributionData.some(d => d.metric_name === 'kl_divergence' && d.metric_value !== null); const hasDistributionDistance = hasWasserstein || hasKLDivergence; + const hasDistributionBins = distributionBinsData.length > 0; // Check for predictor correlation data const correlationData = data.filter(d => d.type === 'predictor_correlation'); @@ -104,6 +106,11 @@ export default function VisualizationDashboard({ imputedVars.add(d.variable); } }); + distributionBinsData.forEach(d => { + if (d.variable) { + imputedVars.add(d.variable); + } + }); // Calculate best performing model (same logic as BenchmarkLossCharts) let bestModel = ''; @@ -313,6 +320,7 @@ export default function VisualizationDashboard({ return { hasBenchmarkLoss, hasDistributionDistance, + hasDistributionBins, hasPredictorCorrelation, hasPredictorOrdering, numericalVars, @@ -345,7 +353,7 @@ export default function VisualizationDashboard({ tabsList.push({ id: 'overview', label: 'Model benchmarking', description: 'Compare quantile loss and log loss across imputation methods' }); } - if (dataAnalysis.hasDistributionDistance) { + if (dataAnalysis.hasDistributionDistance || dataAnalysis.hasDistributionBins) { tabsList.push({ id: 'imputation', label: 'Imputation results', @@ -390,7 +398,13 @@ export default function VisualizationDashboard({ return tabsList; }, [dataAnalysis]); - if (!dataAnalysis.hasBenchmarkLoss) { + useEffect(() => { + if (tabs.length > 0 && !tabs.some(tab => tab.id === activeTab)) { + setActiveTab(tabs[0].id); + } + }, [activeTab, tabs]); + + if (tabs.length === 0) { return (
{/* Header */} @@ -425,7 +439,7 @@ export default function VisualizationDashboard({

No visualization data found

- Upload a CSV file with benchmark_loss data to see visualizations. + Upload a CSV file with benchmark_loss, distribution_distance, distribution_bins, or predictor data to see visualizations.

@@ -680,4 +694,4 @@ export default function VisualizationDashboard({
); -} \ No newline at end of file +} diff --git a/microimputation-dashboard/types/imputation.ts b/microimputation-dashboard/types/imputation.ts index 799862d..64b7b2a 100644 --- a/microimputation-dashboard/types/imputation.ts +++ b/microimputation-dashboard/types/imputation.ts @@ -6,9 +6,10 @@ export interface ImputationDataPoint { quantile: string | number; // numeric (0.05, 0.1, etc.), "mean", or "N/A" metric_name: string; // e.g., "quantile_loss", "log_loss" metric_value: number | null; // numeric value of the metric + metric_std?: number | null; // optional standard deviation for repeated CV metrics split: string; // e.g., "train", "test", "full" additional_info: string; // JSON-formatted string with metadata - [key: string]: any; // Allow additional fields + [key: string]: unknown; // Allow additional fields } export interface ImputationMetrics { @@ -22,4 +23,4 @@ export interface FileInfo { filename: string; loaded: boolean; data: ImputationDataPoint[]; -} \ No newline at end of file +} diff --git a/microimpute/utils/dashboard_formatter.py b/microimpute/utils/dashboard_formatter.py index d39291f..38d2b59 100644 --- a/microimpute/utils/dashboard_formatter.py +++ b/microimpute/utils/dashboard_formatter.py @@ -3,11 +3,17 @@ """ import json +import logging from typing import Any, Dict, List, Optional, Union import numpy as np import pandas as pd +from microimpute.utils.type_handling import VariableTypeDetector + + +log = logging.getLogger(__name__) + def _compute_histogram_data( donor_values: np.ndarray, @@ -118,7 +124,7 @@ def _compute_categorical_distribution( pd.Series(receiver_props) / receiver_values.count() * 100 ).tolist() else: - categories = sorted(all_categories) + categories = sorted(all_categories, key=lambda value: str(value)) donor_props = [ (donor_counts.get(cat, 0) / donor_values.count() * 100) for cat in categories @@ -264,9 +270,38 @@ def _validate_imputed_variables( ) +def _is_categorical_distribution_variable( + series: pd.Series, + variable_name: str, +) -> bool: + """Return whether a variable should use categorical distribution rows.""" + detector = VariableTypeDetector() + var_type, _ = detector.categorize_variable(series, variable_name, log) + return var_type in ["bool", "categorical", "numeric_categorical"] + + +def _extract_cv_results(autoimpute_result: Any) -> Optional[Dict[str, Dict[str, Any]]]: + """Normalize supported autoimpute result shapes to a cv_results dict.""" + if autoimpute_result is None: + return None + + if hasattr(autoimpute_result, "cv_results"): + cv_results = getattr(autoimpute_result, "cv_results") + return cv_results if isinstance(cv_results, dict) else None + + if not isinstance(autoimpute_result, dict): + return None + + wrapped_cv_results = autoimpute_result.get("cv_results") + if isinstance(wrapped_cv_results, dict): + return wrapped_cv_results + + return autoimpute_result + + def format_csv( output_path: Optional[str] = None, - autoimpute_result: Optional[Dict] = None, + autoimpute_result: Optional[Any] = None, comparison_metrics_df: Optional[pd.DataFrame] = None, distribution_comparison_df: Optional[pd.DataFrame] = None, predictor_correlations: Optional[Dict[str, pd.DataFrame]] = None, @@ -288,7 +323,8 @@ def format_csv( autoimpute_result : Dict, optional Result from autoimpute containing cv_results with benchmark losses. - Expected structure: {method: {'quantile_loss': {...}, 'log_loss': {...}}} + Supports an AutoImputeResult object, a {'cv_results': ...} wrapper, + or the direct structure {method: {'quantile_loss': {...}, 'log_loss': {...}}}. comparison_metrics_df : pd.DataFrame, optional DataFrame from compare_metrics() with columns: @@ -331,7 +367,8 @@ def format_csv( ------- pd.DataFrame Unified long-format DataFrame with columns: - ['type', 'method', 'variable', 'quantile', 'metric_name', 'metric_value', 'split', 'additional_info'] + ['type', 'method', 'variable', 'quantile', 'metric_name', 'metric_value', + 'metric_std', 'split', 'additional_info'] Raises ------ @@ -342,12 +379,13 @@ def format_csv( rows = [] # 1. Process autoimpute benchmark losses from cv_results - if autoimpute_result and isinstance(autoimpute_result, dict): - first_value = next(iter(autoimpute_result.values()), None) + cv_results = _extract_cv_results(autoimpute_result) + if cv_results: + first_value = next(iter(cv_results.values()), None) if isinstance(first_value, dict) and ( "quantile_loss" in first_value or "log_loss" in first_value ): - for method, cv_result in autoimpute_result.items(): + for method, cv_result in cv_results.items(): # Append "_best_method" if this is the best method method_label = ( f"{method}_best_method" if method == best_method_name else method @@ -647,9 +685,7 @@ def format_csv( # Generate histogram data for each imputed variable for var in imputed_variables: # Check if variable is categorical or numerical - if pd.api.types.is_string_dtype(donor_data[var]) or isinstance( - donor_data[var].dtype, pd.CategoricalDtype - ): + if _is_categorical_distribution_variable(donor_data[var], var): # Categorical variable hist_data = _compute_categorical_distribution( donor_data[var], receiver_data[var], var @@ -700,7 +736,7 @@ def convert_quantile(q): else: try: return float(q) - except: + except (TypeError, ValueError): return q df["quantile"] = df["quantile"].apply(convert_quantile) diff --git a/tests/test_dashboard_formatter.py b/tests/test_dashboard_formatter.py index 7dfd0f1..1cf7f8e 100644 --- a/tests/test_dashboard_formatter.py +++ b/tests/test_dashboard_formatter.py @@ -299,6 +299,33 @@ def test_additional_info_is_valid_json( class TestFormatCSVBenchmarkLoss: """Tests for benchmark_loss type formatting.""" + def test_benchmark_loss_from_cv_results_wrapper(self, sample_autoimpute_result): + """Test benchmark loss formatting from {'cv_results': ...} wrappers.""" + result = format_csv( + autoimpute_result={"cv_results": sample_autoimpute_result}, + ) + + benchmark_rows = result[result["type"] == "benchmark_loss"] + assert len(benchmark_rows) > 0 + assert {"OLS", "QRF"}.issubset(set(benchmark_rows["method"])) + + def test_benchmark_loss_from_autoimpute_result_object( + self, sample_autoimpute_result + ): + """Test benchmark loss formatting from objects exposing cv_results.""" + + class ResultLike: + pass + + result_like = ResultLike() + result_like.cv_results = sample_autoimpute_result + + result = format_csv(autoimpute_result=result_like) + + benchmark_rows = result[result["type"] == "benchmark_loss"] + assert len(benchmark_rows) > 0 + assert {"OLS", "QRF"}.issubset(set(benchmark_rows["method"])) + def test_benchmark_loss_from_autoimpute(self, sample_autoimpute_result): """Test benchmark loss formatting from autoimpute results.""" with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as f: @@ -943,6 +970,31 @@ def test_error_when_variable_missing_from_datasets(self): imputed_variables=imputed_variables, ) + def test_numeric_categorical_distribution_uses_categorical_rows(self): + """Test numeric categorical variables produce categorical distributions.""" + donor_data = pd.DataFrame( + { + "rating": [1, 1, 2, 2, 3, 3], + "flag": [0, 1, 1, 0, 1, 0], + } + ) + receiver_data = pd.DataFrame( + { + "rating": [1, 2, 2, 3], + "flag": [1, 1, 0, 0], + } + ) + + result = format_csv( + donor_data=donor_data, + receiver_data=receiver_data, + imputed_variables=["rating", "flag"], + ) + + dist_bins = result[result["type"] == "distribution_bins"] + assert set(dist_bins["variable"]) == {"rating", "flag"} + assert set(dist_bins["metric_name"]) == {"categorical_distribution"} + class TestEdgeCases: """Test edge cases and error handling."""