Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 30 additions & 39 deletions aaanalysis/feature_engineering/_backend/cpp/cpp_plot_feature_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,19 @@ def plot_feat_importance_bars_subcat(ax=None,
list_cat=df_feat[col_cat].to_list(),
col_cat=col_cat)
if shap_plot:
# Cumulative feature impact per category, stacked in one direction with one
# thin white-edged segment per feature (red=positive, blue=negative)
# Stack signed feature impact per category (positive=red right, negative=blue left).
# Build one value PER category in list_cat (row) order -- categories with no impact
# get 0 -- and pass the whole ordered list to a single ax.barh call, so every bar
# reserves its own row slot and stays aligned with the heatmap rows. (Drawing bars
# one category at a time by name lets matplotlib's categorical converter compact the
# non-empty ones onto the top rows -- see test_shap_impact_bar_aligns_with_correct_row.)
df = df_feat[[col_cat, col_imp]]
dict_total = {}
for cat in list_cat:
vals = df.loc[df[col_cat] == cat, col_imp].values
left = 0.0
for v in vals:
if v > 0:
ax.barh(cat, v, left=left, color=ut.COLOR_SHAP_POS,
edgecolor="white", linewidth=0.3, align="edge")
left += v
for v in vals:
if v < 0:
ax.barh(cat, abs(v), left=left, color=ut.COLOR_SHAP_NEG,
edgecolor="white", linewidth=0.3, align="edge")
left += abs(v)
dict_total[cat] = left
s_pos = df[df[col_imp] > 0].groupby(by=col_cat)[col_imp].sum()
s_neg = df[df[col_imp] < 0].groupby(by=col_cat)[col_imp].sum()
list_pos = [s_pos.get(x, 0) for x in list_cat]
list_neg = [s_neg.get(x, 0) for x in list_cat]
ax.barh(list_cat, list_pos, color=ut.COLOR_SHAP_POS, edgecolor=None, align="edge")
ax.barh(list_cat, list_neg, color=ut.COLOR_SHAP_NEG, edgecolor=None, align="edge")
else:
# Get feature importance per scale class
df_imp = df_feat[[col_cat, col_imp]].groupby(by=col_cat).sum()
Expand All @@ -77,10 +72,14 @@ def plot_feat_importance_bars_subcat(ax=None,
annotation_th = v_max / 2 if annotation_th is None else annotation_th
for i, val in enumerate(list_imp):
if val >= annotation_th:
ax.text(val, i + 0.45, f"{round(val, 1)}% ",
va="center", ha="right",
# Label just OUTSIDE the bar tip (ha="left" -> extends right, away from the
# heatmap) with clip_on=False, so it is never cut. Drawing it inside the bar
# (right-anchored, white) clips the leading digit on short bars, hiding it
# under the heatmap (e.g. "2.0%" -> "0%"); this restores the never-cut form.
ax.text(val, i + 0.45, f" {round(val, 1)}%",
va="center", ha="left",
weight=weight_annotation,
color="white",
clip_on=False,
size=fontsize_imp_bar)

# Adjust ticks
Expand All @@ -91,7 +90,7 @@ def plot_feat_importance_bars_subcat(ax=None,
label.set_visible(False)

if shap_plot:
ax.set_xlim(0, max(list(dict_total.values()) + [0]))
ax.set_xlim(min(list_neg + [0]), max(list_pos + [0]))
else:
ax.set_xlim(0, v_max)

Expand All @@ -114,24 +113,16 @@ def plot_feat_importance_bars_pos(ax=None,
value_type="sum", normalize=False)
# Plot bars
if shap_plot:
# Cumulative feature impact per position, stacked in one direction with one
# thin white-edged segment per contributing feature (red=positive, blue=negative)
totals = []
for j in range(df_pos.shape[1]):
vals = df_pos.iloc[:, j].values
bottom = 0.0
for v in vals:
if v > 0:
ax.bar(j, v, bottom=bottom, color=ut.COLOR_SHAP_POS,
edgecolor="white", linewidth=0.3, align="edge")
bottom += v
for v in vals:
if v < 0:
ax.bar(j, abs(v), bottom=bottom, color=ut.COLOR_SHAP_NEG,
edgecolor="white", linewidth=0.3, align="edge")
bottom += abs(v)
totals.append(bottom)
ax.set_ylim(0, max(totals + [0]))
# Stack signed feature impact per position (positive=red up, negative=blue down).
# Kept consistent with the per-subcategory (right) bars: both are signed and
# diverging, so a net-negative position reads as a downward blue bar rather than
# being folded into an upward magnitude sum.
list_pos = list(df_pos[df_pos > 0].sum())
list_neg = list(df_pos[df_pos < 0].sum())
x_ticks = list(range(0, len(list_pos)))
ax.bar(x_ticks, list_pos, color=ut.COLOR_SHAP_POS, edgecolor=None, align="edge")
ax.bar(x_ticks, list_neg, color=ut.COLOR_SHAP_NEG, edgecolor=None, align="edge")
ax.set_ylim(min(list_neg + [0]), max(list_pos + [0]))
else:
list_imp = list(df_pos.sum())
x_ticks = list(range(0, len(list_imp)))
Expand Down
95 changes: 87 additions & 8 deletions tests/unit/cpp_plot_tests/test_cpp_plot_feature_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,79 @@ def test_shap_plot_true_returns_fig_ax(self):
assert isinstance(fig, plt.Figure) and isinstance(ax, plt.Axes)
plt.close()

def test_shap_impact_bar_aligns_with_correct_row(self):
# Regression: the per-subcategory SHAP impact bar must sit on its OWN row, not be
# compacted onto the top rows. Put all impact in a single subcategory and check the
# bar's y-position matches that subcategory's heatmap row index. col_val is a
# mean-difference column (not a feat_impact column) so the side bars are shown.
cpp_plot = aa.CPPPlot()
df_feat = aa.load_features("DOM_GSEC")
subcats = list(dict.fromkeys(df_feat["subcategory"]))[:12]
df_feat = df_feat[df_feat["subcategory"].isin(subcats)].copy()
target = subcats[-1] # a subcategory near the bottom of the row order
df_feat[COL_FEAT_IMPACT_TEST] = [5.0 if s == target else 0.0 for s in df_feat["subcategory"]]
fig, hm = cpp_plot.feature_map(df_feat=df_feat, shap_plot=True,
col_val="mean_dif", col_imp=COL_FEAT_IMPACT_TEST)
fig.canvas.draw()
ylabels = [t.get_text() for t in hm.get_yticklabels()]
row = ylabels.index(target)
hm_x0 = hm.get_position().x0
bar_ax = next(a for a in fig.get_axes()
if a is not hm and a.get_position().x0 > hm_x0 + 0.1
and a.get_position().width < 0.2 and a.get_position().height > 0.4)
y_centers = {round(p.get_y() + p.get_height() / 2)
for p in bar_ax.patches if getattr(p, "get_width", lambda: 0)() > 0.01}
assert y_centers == {row}, (y_centers, "expected only row", row)
plt.close()

def test_shap_impact_bars_map_one_to_one_to_rows(self):
# Stronger alignment guard: give several distinct subcategories impact and verify
# each bar sits on exactly its own heatmap row (no compaction, no reordering).
cpp_plot = aa.CPPPlot()
df_feat = aa.load_features("DOM_GSEC")
subcats = list(dict.fromkeys(df_feat["subcategory"]))[:14]
df_feat = df_feat[df_feat["subcategory"].isin(subcats)].copy()
targets = {subcats[1], subcats[6], subcats[-1]} # well-separated rows
df_feat[COL_FEAT_IMPACT_TEST] = [3.0 if s in targets else 0.0
for s in df_feat["subcategory"]]
fig, hm = cpp_plot.feature_map(df_feat=df_feat, shap_plot=True,
col_val="mean_dif", col_imp=COL_FEAT_IMPACT_TEST)
fig.canvas.draw()
ylabels = [t.get_text() for t in hm.get_yticklabels()]
expected_rows = {ylabels.index(t) for t in targets}
hm_x0 = hm.get_position().x0
bar_ax = next(a for a in fig.get_axes()
if a is not hm and a.get_position().x0 > hm_x0 + 0.1
and a.get_position().width < 0.2 and a.get_position().height > 0.4)
bar_rows = {round(p.get_y() + p.get_height() / 2)
for p in bar_ax.patches if getattr(p, "get_width", lambda: 0)() > 0.01}
assert bar_rows == expected_rows, (sorted(bar_rows), "expected", sorted(expected_rows))
plt.close()

def test_importance_bar_labels_extend_outward_not_clipped(self):
# The cumulative-importance % labels must extend outward (to the right of the bar
# tip), never left past the x=0 baseline into the heatmap where they'd be hidden/
# clipped. Regression guard for the inside-bar (right-anchored, white) rendering
# that cut the leading digit of short-bar labels.
cpp_plot = aa.CPPPlot()
df_feat = get_df_feat()
fig, hm = cpp_plot.feature_map(df_feat=df_feat)
fig.canvas.draw()
renderer = fig.canvas.get_renderer()
hm_x0 = hm.get_position().x0
left_edges = []
for a in fig.axes: # the right importance bar: narrow + tall, right of the heatmap
p = a.get_position()
if a is hm or p.x0 <= hm_x0 + 0.1 or p.width >= 0.2 or p.height <= 0.4:
continue
for t in a.texts:
if t.get_text().strip().endswith("%"):
ext = t.get_window_extent(renderer)
left_edges.append(a.transData.inverted().transform((ext.x0, ext.y0))[0])
assert left_edges, "no % importance-bar labels found"
assert min(left_edges) >= -1e-6, f"a label extends left of the baseline: {min(left_edges)}"
plt.close()

def test_default_bars_are_gray_not_signed(self):
"""shap_plot=False keeps the gray cumulative bars and shows no SHAP +/- colors."""
cpp_plot = aa.CPPPlot()
Expand All @@ -839,21 +912,27 @@ def test_shap_bars_are_red_and_blue(self):
assert _rgba(FEAT_IMP_GRAY) not in colors
plt.close()

def test_shap_bars_are_one_direction(self):
"""The cumulative impact bars are stacked in one direction (no negative extent)."""
def test_shap_bars_are_signed_diverging(self):
"""SHAP impact bars are signed and diverge from the zero baseline: positive impact
extends one way, negative the other. The right per-subcategory bars diverge
horizontally (some negative width); the top per-position bars diverge vertically
(some negative height). (Regression guard: a magnitude-only 'one direction' stack
would fold net-negative rows/positions the wrong way -- checked on get_width()/
get_height(), since barh/bar keep the signed extent there, not in get_x()/get_y().)"""
cpp_plot = aa.CPPPlot()
df_feat = get_df_feat_shap()
df_feat = get_df_feat_shap() # alternating +/- impact -> both signs present
fig, _ = cpp_plot.feature_map(df_feat=df_feat, shap_plot=True,
col_imp=COL_FEAT_IMPACT_TEST, col_val=COL_MEAN_DIF_TEST)
# SHAP-colored impact bars start at the zero baseline (one-direction cumulative stack)
shap_rgba = {_rgba(SHAP_POS), _rgba(SHAP_NEG)}
n_shap_bars = 0
widths, heights = [], []
for ax in fig.axes:
for p in ax.patches:
if tuple(round(x, 3) for x in p.get_facecolor()) in shap_rgba:
n_shap_bars += 1
assert round(p.get_x(), 6) >= 0 and round(p.get_y(), 6) >= 0
assert n_shap_bars > 0
widths.append(round(p.get_width(), 6))
heights.append(round(p.get_height(), 6))
assert widths and heights
assert any(w < 0 for w in widths), "expected left-extending (negative) subcategory bars"
assert any(h < 0 for h in heights), "expected down-extending (negative) position bars"
plt.close()

def test_shap_markers_present(self):
Expand Down
Loading