# Source code for hugiml.pruning

# Copyright 2026 Srikumar Krishnamoorthy
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Regulated "remove / refit / calibrate" workflow for HUG-IML.

EBMs are valued partly because model terms can be inspected and sometimes
edited (e.g. to remove an ethically problematic interaction term).  This
module gives HUG-IML an analogous *controlled editing* workflow that is
rigorous enough for regulated-domain review cycles.

Workflow
--------
1. Inspect patterns via ``clf.feature_importances()`` or ``clf.get_pattern_info()``.
2. Create a ``PatternEditor`` and call ``remove()`` with a list of pattern
   indices, or use ``remove_by_keyword()`` / ``remove_low_support()`` to
   select patterns by label or by training support.
3. Call ``refit(X_tr, y_tr)`` to re-train the *downstream* classifier on
   the pruned pattern matrix.  The C++ mining results are unchanged.
4. Optionally call ``calibrate(X_cal, y_cal)`` to wrap the refitted model
   with Platt scaling / isotonic regression.
5. Call ``finalize()`` to get a new classifier instance with the edited
   pattern set baked in, and ``audit_report()`` for a JSON audit trail.

Example
-------
    from hugiml.pruning import PatternEditor

    editor = PatternEditor(clf)
    editor.remove([3, 7, 12], reason="pattern references protected attribute 'gender'")
    editor.remove_by_keyword("income", reason="unstable feature (high PSI)")
    new_clf = editor.refit(X_tr, y_tr).calibrate(X_cal, y_cal).finalize()

    print(editor.audit_report())
    new_clf.predict_proba(X_te)
"""

from __future__ import annotations

import copy
import dataclasses
import datetime
import json
import warnings
from typing import Any

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

# Public API of this module: the editor and its audit-record type.
__all__ = ["PatternEditor", "RemovalRecord"]


# ---------------------------------------------------------------------------
# Audit entry
# ---------------------------------------------------------------------------


@dataclasses.dataclass()
class RemovalRecord:
    """Audit record for a single pattern-removal action.

    ``PatternEditor.remove`` appends one of these whenever it actually drops
    at least one pattern; ``PatternEditor.audit_report`` serialises them.
    """

    # UTC ISO-8601 time at which the removal was recorded.
    timestamp: str
    # Removed indices, relative to the working pattern list at removal time.
    pattern_indices: list[int]
    # Labels of the removed patterns, in the same order as pattern_indices.
    pattern_labels: list[str]
    # Free-text justification supplied by the caller.
    reason: str
    # Operator name configured on the PatternEditor.
    removed_by: str
# ---------------------------------------------------------------------------
# PatternEditor
# ---------------------------------------------------------------------------
class PatternEditor:
    """Controlled pattern editing with full audit trail.

    The editor works on a deep copy of a fitted classifier.  Removals prune
    ``patterns_`` (and aligned metadata) on that copy, ``refit()`` retrains the
    downstream model on the pruned pattern space, ``calibrate()`` optionally
    wraps it with probability calibration, and ``finalize()`` returns an
    independent copy.  Every successful removal invalidates any earlier
    refit/calibration, so ``refit()`` must follow the last removal before
    ``calibrate()`` or ``finalize()`` is accepted.

    Parameters
    ----------
    clf : fitted HUGIMLClassifierNative
        The *original* model.  This object is **not** mutated; all edits
        produce a fresh copy stored internally.
    operator_name : str
        Human-readable identifier of the person/process making the edits
        (for the audit trail).

    Raises
    ------
    RuntimeError
        If ``clf`` has no ``patterns_`` attribute (i.e. is not fitted).
    """

    # Downstream-construction caches that are aligned to one specific pattern
    # space.  They are dropped before refitting so the classifier rebuilds them
    # for the pruned patterns instead of reusing stale masks/catalogs.
    _DOWNSTREAM_CACHE_ATTRS: tuple[str, ...] = (
        "_downstream_feature_names_full_",
        "_strict_topk_feature_scores_",
        "_strict_topk_feature_mask_",
        "_strict_topk_selected_feature_names_",
        "_strict_topk_applied_during_construction_",
        "_downstream_pattern_support_",
        "_downstream_non_missing_rate_",
        "_downstream_variance_",
        "_original_feature_mask_downstream_",
        "_original_feature_scores_downstream_",
        "_original_selected_feature_names_downstream_",
        "_original_feature_names_downstream_full_",
    )

    def __init__(self, clf: Any, operator_name: str = "analyst") -> None:
        if not hasattr(clf, "patterns_"):
            raise RuntimeError("Classifier must be fitted before creating a PatternEditor.")
        self._orig_clf = clf
        self._operator = operator_name
        self._working_clf: Any = copy.deepcopy(clf)
        self._audit_log: list[RemovalRecord] = []
        self._calibrated: bool = False
        self._calibration_method: str | None = None
        # Downstream model as it was before calibrate() wrapped it, so a
        # repeated calibrate() starts from the uncalibrated model.
        self._pre_calibration_model: Any = None
        self._finalized: bool = False
        # True only while the downstream model matches the current pattern set.
        self._refitted: bool = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current time as a timezone-aware UTC ISO-8601 string."""
        return datetime.datetime.now(datetime.timezone.utc).isoformat()

    def _invalidate_downstream(self) -> None:
        """Mark any refit/calibration as stale after the pattern set changed."""
        self._refitted = False
        self._calibrated = False
        self._calibration_method = None
        self._pre_calibration_model = None

    # ------------------------------------------------------------------
    # Removal operations
    # ------------------------------------------------------------------

    def remove(
        self,
        pattern_indices: list[int],
        reason: str = "unspecified",
    ) -> PatternEditor:
        """Remove patterns by index (0-based, relative to the *current* working set).

        Parameters
        ----------
        pattern_indices : list of int, or a single int
            Indices into the *current* pattern list.  Use ``list_patterns()``
            to preview indices.  NumPy integers are accepted and recorded as
            plain ``int``; duplicates are recorded once.  Out-of-range
            (including negative) indices are ignored with a warning.
        reason : str
            Audit reason (e.g. 'protected attribute', 'operationally invalid').

        Returns
        -------
        self (for method chaining)

        Raises
        ------
        TypeError
            If an index is not an integer.
        RuntimeError
            If the editor has already been finalized.

        Notes
        -----
        A successful removal invalidates any previous ``refit()`` /
        ``calibrate()``; call ``refit()`` again before ``finalize()``.
        """
        import operator

        self._check_not_finalized()
        clf = self._working_clf
        n = len(clf.patterns_)

        if isinstance(pattern_indices, (int, np.integer)):
            pattern_indices = [pattern_indices]
        # operator.index converts NumPy integers to plain ints (so the JSON
        # audit trail does not stringify them) and rejects non-integers.
        requested = [operator.index(i) for i in pattern_indices]
        # dict.fromkeys de-duplicates while keeping the caller's order.
        valid = list(dict.fromkeys(i for i in requested if 0 <= i < n))
        invalid = set(requested) - set(valid)
        if invalid:
            warnings.warn(
                f"PatternEditor.remove: ignoring out-of-range indices {invalid}.",
                stacklevel=2,
            )
        if not valid:
            return self

        all_labels = clf.get_hug_features()
        removed_labels = [all_labels[i] for i in valid]

        keep_mask = np.ones(n, dtype=bool)
        keep_mask[valid] = False
        keep_idx = np.flatnonzero(keep_mask)

        # Update patterns, structural pattern metadata, and the retained
        # training pattern matrix so downstream refits see a consistent
        # pattern space.
        clf.patterns_ = [clf.patterns_[i] for i in keep_idx]
        if hasattr(clf, "_pattern_orders_"):
            clf._pattern_orders_ = np.asarray(clf._pattern_orders_)[keep_idx]
        if hasattr(clf, "_interaction_pattern_mask_"):
            clf._interaction_pattern_mask_ = np.asarray(
                clf._interaction_pattern_mask_, dtype=bool
            )[keep_idx]
        if getattr(clf, "x_train_hup_", None) is not None:
            clf.x_train_hup_ = clf.x_train_hup_[:, keep_idx]

        self._audit_log.append(
            RemovalRecord(
                timestamp=self._utc_timestamp(),
                pattern_indices=valid,
                pattern_labels=removed_labels,
                reason=reason,
                removed_by=self._operator,
            )
        )
        # The downstream model (and any calibration) was trained on the old
        # pattern space, so it must be refitted before calibrate()/finalize().
        self._invalidate_downstream()
        return self

    def remove_by_keyword(
        self,
        keyword: str,
        reason: str = "keyword match",
        case_sensitive: bool = False,
    ) -> PatternEditor:
        """Remove all patterns whose label contains ``keyword``.

        Parameters
        ----------
        keyword : str
            Non-empty substring to search for in each pattern label.
        reason : str
            Audit reason; the keyword is appended to it.
        case_sensitive : bool
            If False (default), matching uses Unicode case folding.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``keyword`` is empty (it would match every pattern).
        """
        self._check_not_finalized()
        if not keyword:
            raise ValueError("PatternEditor.remove_by_keyword: keyword must be a non-empty string.")
        labels = self._working_clf.get_hug_features()
        if case_sensitive:
            matches = [i for i, lbl in enumerate(labels) if keyword in lbl]
        else:
            needle = keyword.casefold()
            matches = [i for i, lbl in enumerate(labels) if needle in lbl.casefold()]
        if not matches:
            warnings.warn(
                f"PatternEditor.remove_by_keyword: no patterns matched '{keyword}'.",
                stacklevel=2,
            )
            return self
        return self.remove(matches, reason=f"{reason} (keyword='{keyword}')")

    def remove_low_support(
        self,
        min_support: float = 0.01,
        reason: str = "support below threshold",
    ) -> PatternEditor:
        """Remove patterns with training support strictly below ``min_support``.

        Parameters
        ----------
        min_support : float
            Minimum fraction of training samples (0 to 1).
        reason : str
            Audit reason; the threshold is appended to it.

        Returns
        -------
        self

        Raises
        ------
        RuntimeError
            If the working model has no (or a ``None``) ``x_train_hup_``.
        """
        self._check_not_finalized()
        hup = getattr(self._working_clf, "x_train_hup_", None)
        if hup is None:
            raise RuntimeError("x_train_hup_ not available — cannot filter by support.")
        n_train = hup.shape[0]
        if n_train == 0:
            warnings.warn(
                "PatternEditor.remove_low_support: x_train_hup_ has no rows; nothing removed.",
                stacklevel=2,
            )
            return self
        # float64 division so e.g. 1/100 compares equal to a 0.01 threshold
        # instead of falling just below it in float32.
        supports = np.asarray(hup.sum(axis=0), dtype=np.float64).ravel() / n_train
        to_remove = np.flatnonzero(supports < min_support).tolist()
        return self.remove(to_remove, reason=f"{reason} (min_support={min_support})")

    # ------------------------------------------------------------------
    # Refit downstream classifier
    # ------------------------------------------------------------------

    def refit(
        self,
        X_tr: Any,
        y_tr: Any,
        estimator: Any = None,
    ) -> PatternEditor:
        """Refit the downstream classifier on the (pruned) pattern matrix.

        The HUG mining results (``patterns_``) are unchanged; only the
        downstream ``Pipeline`` (``model_``) is replaced.  Any previous
        calibration is discarded.

        Parameters
        ----------
        X_tr : array-like or DataFrame
            Training data (should be the same split used to fit the original
            model).  If it has as many rows as the retained ``x_train_hup_``
            matrix, that pruned matrix is reused instead of re-transforming
            ``X_tr``.
        y_tr : array-like
            Integer class labels (converted to ``int64``).
        estimator : sklearn estimator, optional
            If None, a deep copy of the original model's ``'clf'`` pipeline
            step is refitted (same class and hyperparameters).

        Returns
        -------
        self

        Raises
        ------
        RuntimeError
            If the editor is finalized, or ``estimator`` is None and the
            original ``model_`` has no ``'clf'`` step.
        """
        self._check_not_finalized()
        from sklearn.pipeline import Pipeline

        # Resolve the estimator before touching the working model, so a bad
        # configuration leaves the editor state unchanged.
        if estimator is not None:
            new_est = copy.deepcopy(estimator)
        else:
            orig_clf_step = self._orig_clf.model_.named_steps.get("clf")
            if orig_clf_step is None:
                raise RuntimeError(
                    "The original model_ pipeline has no 'clf' step; pass estimator= to refit()."
                )
            new_est = copy.deepcopy(orig_clf_step)

        y_arr = np.asarray(y_tr, dtype=np.int64)
        self._refresh_training_hup(X_tr, len(y_arr))
        X_downstream = self._rebuild_downstream_features(X_tr, y_arr)

        new_model = Pipeline([("clf", new_est)])
        new_model.fit(X_downstream, y_arr)
        self._working_clf.model_ = new_model

        self._refitted = True
        self._calibrated = False
        self._calibration_method = None
        self._pre_calibration_model = None
        return self

    def _refresh_training_hup(self, X_tr: Any, n_rows: int) -> None:
        """Store the (pruned) training pattern matrix as float32 CSR on the working model."""
        clf = self._working_clf
        # After remove(), the working classifier already carries the original
        # training HUP matrix with the same columns pruned.  Reuse it when the
        # caller passes a split with the same row count (assumed to be the
        # original training split); rebuilding it by scanning every pattern
        # can dominate refit time for large adaptive models.
        cached_hup = getattr(clf, "x_train_hup_", None)
        if cached_hup is not None and cached_hup.shape[0] == n_rows:
            clf.x_train_hup_ = csr_matrix(cached_hup, dtype=np.float32)
        else:
            clf.x_train_hup_ = csr_matrix(clf.transform(X_tr), dtype=np.float32)

    def _rebuild_downstream_features(self, X_tr: Any, y_arr: np.ndarray) -> Any:
        """Recompute the downstream training matrix for the current pattern set."""
        clf = self._working_clf
        # Hybrid feature modes may hold strict TopK masks and original-feature
        # catalogs aligned to the pre-pruned pattern space; drop them.
        for attr in self._DOWNSTREAM_CACHE_ATTRS:
            clf.__dict__.pop(attr, None)

        clf._current_y_for_downstream_topk_ = y_arr
        try:
            X_downstream = clf._make_downstream_features(X_tr, clf.x_train_hup_, fit=True)
            X_downstream = clf._apply_strict_topk_budget_fit(X_downstream, y_arr)
            clf.x_train_downstream_ = X_downstream
            cache_metadata = getattr(clf, "_cache_downstream_feature_metadata", None)
            if cache_metadata is not None:
                try:
                    cache_metadata()
                except Exception as exc:  # metadata caching is best-effort
                    warnings.warn(
                        "PatternEditor.refit: could not cache downstream feature "
                        f"metadata ({exc!r}).",
                        stacklevel=3,
                    )
        finally:
            # Always drop the temporary label hint, even if construction failed.
            clf.__dict__.pop("_current_y_for_downstream_topk_", None)
        return X_downstream

    # ------------------------------------------------------------------
    # Calibration
    # ------------------------------------------------------------------

    def calibrate(
        self,
        X_cal: Any,
        y_cal: Any,
        method: str = "isotonic",
    ) -> PatternEditor:
        """Wrap the refitted downstream model with probability calibration.

        Uses ``sklearn.calibration.CalibratedClassifierCV`` on the already
        fitted downstream classifier; the classifier itself is not retrained.
        The calibration set should be held out from both training and test.
        Calling ``calibrate()`` again re-calibrates the uncalibrated model
        rather than stacking calibrators.

        Parameters
        ----------
        X_cal : array-like or DataFrame
        y_cal : array-like
        method : {'sigmoid', 'isotonic'}

        Returns
        -------
        self

        Raises
        ------
        RuntimeError
            If patterns were removed since the last ``refit()`` (the
            downstream model would not match the pruned pattern space), or if
            the editor has been finalized.
        """
        self._check_not_finalized()
        if self._audit_log and not self._refitted:
            raise RuntimeError(
                "Call refit() before calibrate() — patterns were removed after the "
                "downstream model was last trained."
            )
        from sklearn.calibration import CalibratedClassifierCV
        from sklearn.pipeline import Pipeline

        clf = self._working_clf
        hup_cal = csr_matrix(clf.transform(X_cal), dtype=np.float32)
        X_downstream_cal = clf._make_downstream_features(X_cal, hup_cal, fit=False)
        X_downstream_cal = clf._apply_strict_topk_budget_transform(X_downstream_cal)
        y_arr = np.asarray(y_cal, dtype=np.int64)

        if self._calibrated and self._pre_calibration_model is not None:
            base_model = self._pre_calibration_model
        else:
            base_model = clf.model_
        inner_clf = base_model.named_steps["clf"]

        try:
            # scikit-learn >= 1.6 replaces the deprecated cv="prefit" (removed
            # in 1.8) with FrozenEstimator, whose fit() is a no-op, so the
            # pre-fitted classifier is never retrained.  ensemble=False fits a
            # single calibrator on its predictions, matching prefit behaviour.
            from sklearn.frozen import FrozenEstimator
        except ImportError:
            cal_clf = CalibratedClassifierCV(inner_clf, cv="prefit", method=method)
        else:
            cal_clf = CalibratedClassifierCV(
                FrozenEstimator(inner_clf), method=method, ensemble=False
            )
        cal_clf.fit(X_downstream_cal, y_arr)

        self._pre_calibration_model = base_model
        clf.model_ = Pipeline([("clf", cal_clf)])
        self._calibrated = True
        self._calibration_method = method
        return self

    # ------------------------------------------------------------------
    # Finalize
    # ------------------------------------------------------------------

    def finalize(self) -> Any:
        """Return the edited classifier as a new standalone instance.

        After calling ``finalize()``, further edits on this editor are
        blocked.  The returned object is a fully independent copy.

        Returns
        -------
        HUGIMLClassifierNative (edited copy)

        Raises
        ------
        RuntimeError
            If ``refit()`` has not been called since the last removal.
        """
        if not self._refitted:
            raise RuntimeError(
                "Call refit() before finalize() — the downstream model must be retrained "
                "after pattern removal."
            )
        # Copy first so a failing deepcopy does not leave the editor locked.
        finalized_clf = copy.deepcopy(self._working_clf)
        self._finalized = True
        return finalized_clf

    # ------------------------------------------------------------------
    # Inspection helpers
    # ------------------------------------------------------------------

    def list_patterns(self) -> pd.DataFrame:
        """Return editable HUG patterns in the current working model.

        PatternEditor edits mined HUG patterns only.  Original features and
        augmented-pair downstream features are visible through
        ``list_downstream_features()`` but are not directly removable by this
        editor.

        Coefficients and supports come from ``feature_importances()`` when it
        is available (rows without a ``feature_type`` column are all treated
        as patterns); otherwise coefficients are NaN and supports are computed
        from ``x_train_hup_`` when present.
        """
        clf = self._working_clf
        labels = clf.get_hug_features()

        try:
            imp = clf.feature_importances()
            if "feature_type" in imp.columns:
                imp_pat = imp[imp["feature_type"] == "pattern"]
            else:
                imp_pat = imp
            support_col = "pattern_support" if "pattern_support" in imp_pat.columns else "support"
            coef_map = dict(zip(imp_pat["pattern"], imp_pat["coefficient"]))
            sup_map = dict(zip(imp_pat["pattern"], imp_pat[support_col]))
        except Exception:
            # Importances are optional context; fall back to NaN coefficients
            # and supports derived from the training pattern matrix.
            coef_map, sup_map = {}, {}

        # Per-pattern support from the retained training matrix, computed once.
        hup = getattr(clf, "x_train_hup_", None)
        fallback_support = None
        if hup is not None and hup.shape[0]:
            fallback_support = np.asarray(hup.sum(axis=0), dtype=np.float64).ravel() / hup.shape[0]

        rows = []
        for i, lbl in enumerate(labels):
            sup = sup_map.get(lbl)
            if sup is None and fallback_support is not None:
                sup = float(fallback_support[i])
            sup_value = round(sup, 4) if sup is not None else float("nan")
            rows.append(
                {
                    "idx": i,
                    "pattern": lbl,
                    "feature_type": "pattern",
                    "editable": True,
                    "coefficient": coef_map.get(lbl, float("nan")),
                    "pattern_support": sup_value,
                    "support": sup_value,
                }
            )
        return pd.DataFrame(rows)

    def list_downstream_features(self) -> pd.DataFrame:
        """Return all downstream features with PatternEditor editability.

        The returned table includes original features, HUG patterns, and
        augmented-pair transforms when present.  Only rows with
        ``feature_type == 'pattern'`` are directly editable through
        ``remove()`` and related PatternEditor methods.  If
        ``feature_importances()`` returns no ``feature_type`` column, every
        row is treated as a pattern.
        """
        clf = self._working_clf
        try:
            imp = clf.feature_importances().copy()
        except Exception:
            names = list(getattr(clf, "get_downstream_features", lambda: [])())
            imp = pd.DataFrame({"feature": names})
            imp["display_name"] = names
            imp["feature_type"] = [
                "pattern"
                if str(name).startswith("pattern:")
                else "augmented_pair"
                if str(name).startswith("augmented_pair:")
                else "original"
                for name in names
            ]
        if "feature_type" not in imp.columns:
            imp["feature_type"] = "pattern"
        imp["editable"] = imp["feature_type"].eq("pattern")
        imp["editor_scope"] = np.where(
            imp["editable"], "editable_pattern", "downstream_context_only"
        )
        return imp.reset_index(drop=True)

    def diff(self) -> dict:
        """Return a summary of changes made relative to the original model.

        Returns
        -------
        dict with keys: scope, n_original, n_current, n_removed,
        removed_patterns, n_downstream_features_current,
        n_non_editable_downstream_features_current,
        non_editable_feature_types_current
        """
        original_labels = set(self._orig_clf.get_hug_features())
        current_labels = set(self._working_clf.get_hug_features())
        removed = sorted(original_labels - current_labels)

        downstream = self.list_downstream_features()
        if "editable" in downstream:
            non_editable = downstream[~downstream["editable"]]
        else:
            non_editable = pd.DataFrame()
        if "feature_type" in non_editable:
            non_editable_types = sorted(non_editable["feature_type"].dropna().unique().tolist())
        else:
            non_editable_types = []

        return {
            "scope": "hug_patterns_only",
            "n_original": len(original_labels),
            "n_current": len(current_labels),
            "n_removed": len(removed),
            "removed_patterns": removed,
            "n_downstream_features_current": int(len(downstream)),
            "n_non_editable_downstream_features_current": int(len(non_editable)),
            "non_editable_feature_types_current": non_editable_types,
        }

    # ------------------------------------------------------------------
    # Audit report
    # ------------------------------------------------------------------

    def audit_report(self, indent: int = 2) -> str:
        """Return a JSON string describing all edits made.

        The report includes operator name, timezone-aware UTC timestamps,
        reasons, calibration status and the diff summary.
        """
        report = {
            "operator": self._operator,
            "generated_at": self._utc_timestamp(),
            "diff": self.diff(),
            "calibration": {
                "applied": self._calibrated,
                "method": self._calibration_method,
            },
            "editor_scope": "hug_patterns_only",
            "non_editable_downstream_features_visible": True,
            "removals": [dataclasses.asdict(r) for r in self._audit_log],
        }
        return json.dumps(report, indent=indent, default=str)

    def save_audit_report(self, path: str) -> None:
        """Write the audit report to a UTF-8 JSON file at ``path``."""
        with open(path, "w", encoding="utf-8") as fh:
            fh.write(self.audit_report())

    # ------------------------------------------------------------------
    # Context manager support (optional)
    # ------------------------------------------------------------------

    def __enter__(self) -> PatternEditor:
        return self

    def __exit__(self, *exc) -> None:
        pass  # Caller is responsible for calling finalize()

    def __repr__(self) -> str:
        # Count patterns directly instead of via diff(), which would also
        # build the full downstream feature table.
        original = set(self._orig_clf.get_hug_features())
        current = set(self._working_clf.get_hug_features())
        return (
            f"PatternEditor("
            f"original={len(original)}, "
            f"current={len(current)}, "
            f"removed={len(original - current)}, "
            f"calibrated={self._calibrated})"
        )

    # ------------------------------------------------------------------

    def _check_not_finalized(self) -> None:
        """Raise RuntimeError if ``finalize()`` has already been called."""
        if self._finalized:
            raise RuntimeError(
                "PatternEditor has already been finalized. "
                "Create a new PatternEditor to make further edits."
            )