diff --git a/aaanalysis/feature_engineering/_backend/cpp/cpp_plot_feature_map.py b/aaanalysis/feature_engineering/_backend/cpp/cpp_plot_feature_map.py index a2311c33..4041e00c 100644 --- a/aaanalysis/feature_engineering/_backend/cpp/cpp_plot_feature_map.py +++ b/aaanalysis/feature_engineering/_backend/cpp/cpp_plot_feature_map.py @@ -38,24 +38,19 @@ def plot_feat_importance_bars_subcat(ax=None, list_cat=df_feat[col_cat].to_list(), col_cat=col_cat) if shap_plot: - # Cumulative feature impact per category, stacked in one direction with one - # thin white-edged segment per feature (red=positive, blue=negative) + # Stack signed feature impact per category (positive=red right, negative=blue left). + # Build one value PER category in list_cat (row) order -- categories with no impact + # get 0 -- and pass the whole ordered list to a single ax.barh call, so every bar + # reserves its own row slot and stays aligned with the heatmap rows. (Drawing bars + # one category at a time by name lets matplotlib's categorical converter compact the + # non-empty ones onto the top rows -- see test_shap_impact_bar_aligns_with_correct_row.) df = df_feat[[col_cat, col_imp]] - dict_total = {} - for cat in list_cat: - vals = df.loc[df[col_cat] == cat, col_imp].values - left = 0.0 - for v in vals: - if v > 0: - ax.barh(cat, v, left=left, color=ut.COLOR_SHAP_POS, - edgecolor="white", linewidth=0.3, align="edge") - left += v - for v in vals: - if v < 0: - ax.barh(cat, abs(v), left=left, color=ut.COLOR_SHAP_NEG, - edgecolor="white", linewidth=0.3, align="edge") - left += abs(v) - dict_total[cat] = left + s_pos = df[df[col_imp] > 0].groupby(by=col_cat)[col_imp].sum() + s_neg = df[df[col_imp] < 0].groupby(by=col_cat)[col_imp].sum() + list_pos = [s_pos.get(x, 0) for x in list_cat] + list_neg = [s_neg.get(x, 0) for x in list_cat] + ax.barh(list_cat, list_pos, color=ut.COLOR_SHAP_POS, edgecolor=None, align="edge") + ax.barh(list_cat, list_neg, color=ut.COLOR_SHAP_NEG, edgecolor=None, align="edge") else: # Get feature importance per scale class df_imp = df_feat[[col_cat, col_imp]].groupby(by=col_cat).sum() @@ -77,10 +72,14 @@ def plot_feat_importance_bars_subcat(ax=None, annotation_th = v_max / 2 if annotation_th is None else annotation_th for i, val in enumerate(list_imp): if val >= annotation_th: - ax.text(val, i + 0.45, f"{round(val, 1)}% ", - va="center", ha="right", + # Label just OUTSIDE the bar tip (ha="left" -> extends right, away from the + # heatmap) with clip_on=False, so it is never cut. Drawing it inside the bar + # (right-anchored, white) clips the leading digit on short bars, hiding it + # under the heatmap (e.g. "2.0%" -> "0%"); this restores the never-cut form. + ax.text(val, i + 0.45, f" {round(val, 1)}%", + va="center", ha="left", weight=weight_annotation, - color="white", + clip_on=False, size=fontsize_imp_bar) # Adjust ticks @@ -91,7 +90,7 @@ def plot_feat_importance_bars_subcat(ax=None, label.set_visible(False) if shap_plot: - ax.set_xlim(0, max(list(dict_total.values()) + [0])) + ax.set_xlim(min(list_neg + [0]), max(list_pos + [0])) else: ax.set_xlim(0, v_max) @@ -114,24 +113,16 @@ def plot_feat_importance_bars_pos(ax=None, value_type="sum", normalize=False) # Plot bars if shap_plot: - # Cumulative feature impact per position, stacked in one direction with one - # thin white-edged segment per contributing feature (red=positive, blue=negative) - totals = [] - for j in range(df_pos.shape[1]): - vals = df_pos.iloc[:, j].values - bottom = 0.0 - for v in vals: - if v > 0: - ax.bar(j, v, bottom=bottom, color=ut.COLOR_SHAP_POS, - edgecolor="white", linewidth=0.3, align="edge") - bottom += v - for v in vals: - if v < 0: - ax.bar(j, abs(v), bottom=bottom, color=ut.COLOR_SHAP_NEG, - edgecolor="white", linewidth=0.3, align="edge") - bottom += abs(v) - totals.append(bottom) - ax.set_ylim(0, max(totals + [0])) + # Stack signed feature impact per position (positive=red up, negative=blue down). + # Kept consistent with the per-subcategory (right) bars: both are signed and + # diverging, so a net-negative position reads as a downward blue bar rather than + # being folded into an upward magnitude sum. + list_pos = list(df_pos[df_pos > 0].sum()) + list_neg = list(df_pos[df_pos < 0].sum()) + x_ticks = list(range(0, len(list_pos))) + ax.bar(x_ticks, list_pos, color=ut.COLOR_SHAP_POS, edgecolor=None, align="edge") + ax.bar(x_ticks, list_neg, color=ut.COLOR_SHAP_NEG, edgecolor=None, align="edge") + ax.set_ylim(min(list_neg + [0]), max(list_pos + [0])) else: list_imp = list(df_pos.sum()) x_ticks = list(range(0, len(list_imp))) diff --git a/tests/unit/cpp_plot_tests/test_cpp_plot_feature_map.py b/tests/unit/cpp_plot_tests/test_cpp_plot_feature_map.py index f3d5ac68..95df88c8 100644 --- a/tests/unit/cpp_plot_tests/test_cpp_plot_feature_map.py +++ b/tests/unit/cpp_plot_tests/test_cpp_plot_feature_map.py @@ -818,6 +818,79 @@ def test_shap_plot_true_returns_fig_ax(self): assert isinstance(fig, plt.Figure) and isinstance(ax, plt.Axes) plt.close() + def test_shap_impact_bar_aligns_with_correct_row(self): + # Regression: the per-subcategory SHAP impact bar must sit on its OWN row, not be + # compacted onto the top rows. Put all impact in a single subcategory and check the + # bar's y-position matches that subcategory's heatmap row index. col_val is a + # mean-difference column (not a feat_impact column) so the side bars are shown. + cpp_plot = aa.CPPPlot() + df_feat = aa.load_features("DOM_GSEC") + subcats = list(dict.fromkeys(df_feat["subcategory"]))[:12] + df_feat = df_feat[df_feat["subcategory"].isin(subcats)].copy() + target = subcats[-1] # a subcategory near the bottom of the row order + df_feat[COL_FEAT_IMPACT_TEST] = [5.0 if s == target else 0.0 for s in df_feat["subcategory"]] + fig, hm = cpp_plot.feature_map(df_feat=df_feat, shap_plot=True, + col_val="mean_dif", col_imp=COL_FEAT_IMPACT_TEST) + fig.canvas.draw() + ylabels = [t.get_text() for t in hm.get_yticklabels()] + row = ylabels.index(target) + hm_x0 = hm.get_position().x0 + bar_ax = next(a for a in fig.get_axes() + if a is not hm and a.get_position().x0 > hm_x0 + 0.1 + and a.get_position().width < 0.2 and a.get_position().height > 0.4) + y_centers = {round(p.get_y() + p.get_height() / 2) + for p in bar_ax.patches if getattr(p, "get_width", lambda: 0)() > 0.01} + assert y_centers == {row}, (y_centers, "expected only row", row) + plt.close() + + def test_shap_impact_bars_map_one_to_one_to_rows(self): + # Stronger alignment guard: give several distinct subcategories impact and verify + # each bar sits on exactly its own heatmap row (no compaction, no reordering). + cpp_plot = aa.CPPPlot() + df_feat = aa.load_features("DOM_GSEC") + subcats = list(dict.fromkeys(df_feat["subcategory"]))[:14] + df_feat = df_feat[df_feat["subcategory"].isin(subcats)].copy() + targets = {subcats[1], subcats[6], subcats[-1]} # well-separated rows + df_feat[COL_FEAT_IMPACT_TEST] = [3.0 if s in targets else 0.0 + for s in df_feat["subcategory"]] + fig, hm = cpp_plot.feature_map(df_feat=df_feat, shap_plot=True, + col_val="mean_dif", col_imp=COL_FEAT_IMPACT_TEST) + fig.canvas.draw() + ylabels = [t.get_text() for t in hm.get_yticklabels()] + expected_rows = {ylabels.index(t) for t in targets} + hm_x0 = hm.get_position().x0 + bar_ax = next(a for a in fig.get_axes() + if a is not hm and a.get_position().x0 > hm_x0 + 0.1 + and a.get_position().width < 0.2 and a.get_position().height > 0.4) + bar_rows = {round(p.get_y() + p.get_height() / 2) + for p in bar_ax.patches if getattr(p, "get_width", lambda: 0)() > 0.01} + assert bar_rows == expected_rows, (sorted(bar_rows), "expected", sorted(expected_rows)) + plt.close() + + def test_importance_bar_labels_extend_outward_not_clipped(self): + # The cumulative-importance % labels must extend outward (to the right of the bar + # tip), never left past the x=0 baseline into the heatmap where they'd be hidden/ + # clipped. Regression guard for the inside-bar (right-anchored, white) rendering + # that cut the leading digit of short-bar labels. + cpp_plot = aa.CPPPlot() + df_feat = get_df_feat() + fig, hm = cpp_plot.feature_map(df_feat=df_feat) + fig.canvas.draw() + renderer = fig.canvas.get_renderer() + hm_x0 = hm.get_position().x0 + left_edges = [] + for a in fig.axes: # the right importance bar: narrow + tall, right of the heatmap + p = a.get_position() + if a is hm or p.x0 <= hm_x0 + 0.1 or p.width >= 0.2 or p.height <= 0.4: + continue + for t in a.texts: + if t.get_text().strip().endswith("%"): + ext = t.get_window_extent(renderer) + left_edges.append(a.transData.inverted().transform((ext.x0, ext.y0))[0]) + assert left_edges, "no % importance-bar labels found" + assert min(left_edges) >= -1e-6, f"a label extends left of the baseline: {min(left_edges)}" + plt.close() + def test_default_bars_are_gray_not_signed(self): """shap_plot=False keeps the gray cumulative bars and shows no SHAP +/- colors.""" cpp_plot = aa.CPPPlot() @@ -839,21 +912,27 @@ def test_shap_bars_are_red_and_blue(self): assert _rgba(FEAT_IMP_GRAY) not in colors plt.close() - def test_shap_bars_are_one_direction(self): - """The cumulative impact bars are stacked in one direction (no negative extent).""" + def test_shap_bars_are_signed_diverging(self): + """SHAP impact bars are signed and diverge from the zero baseline: positive impact + extends one way, negative the other. The right per-subcategory bars diverge + horizontally (some negative width); the top per-position bars diverge vertically + (some negative height). (Regression guard: a magnitude-only 'one direction' stack + would fold net-negative rows/positions the wrong way -- checked on get_width()/ + get_height(), since barh/bar keep the signed extent there, not in get_x()/get_y().)""" cpp_plot = aa.CPPPlot() - df_feat = get_df_feat_shap() + df_feat = get_df_feat_shap() # alternating +/- impact -> both signs present fig, _ = cpp_plot.feature_map(df_feat=df_feat, shap_plot=True, col_imp=COL_FEAT_IMPACT_TEST, col_val=COL_MEAN_DIF_TEST) - # SHAP-colored impact bars start at the zero baseline (one-direction cumulative stack) shap_rgba = {_rgba(SHAP_POS), _rgba(SHAP_NEG)} - n_shap_bars = 0 + widths, heights = [], [] for ax in fig.axes: for p in ax.patches: if tuple(round(x, 3) for x in p.get_facecolor()) in shap_rgba: - n_shap_bars += 1 - assert round(p.get_x(), 6) >= 0 and round(p.get_y(), 6) >= 0 - assert n_shap_bars > 0 + widths.append(round(p.get_width(), 6)) + heights.append(round(p.get_height(), 6)) + assert widths and heights + assert any(w < 0 for w in widths), "expected left-extending (negative) subcategory bars" + assert any(h < 0 for h in heights), "expected down-extending (negative) position bars" plt.close() def test_shap_markers_present(self):