diff --git a/lars/util/confusion_matrix.py b/lars/util/confusion_matrix.py index c5ef756..e019673 100644 --- a/lars/util/confusion_matrix.py +++ b/lars/util/confusion_matrix.py @@ -25,10 +25,10 @@ def plot_confusion_matrix(df, label_col='label', pred_col='llm_label', normalize ax = plt.gca() cm = confusion_matrix(true_labels, pred_labels, normalize=normalize) - disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_) + disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_,) - disp.plot(ax=ax, cmap=plt.cm.Blues) + disp.plot(ax=ax, cmap=plt.cm.Blues, xticks_rotation=45) ax.set_title('Confusion Matrix')