Tutorial 03 — Advanced usage: boost selection, extraction, and custom configs

This tutorial covers the two iterative pipelines and shows how to author model configs from scratch.

Topics:

  • n_iter_boost_selection: iteratively narrow the feature set before final model selection (useful when starting with thousands of genes)

  • n_iter_extraction: identify top per-class regulatory factors across multiple rounds of selection + explanation

  • Writing custom JSON config files for any supported model family

  • Inspecting per-unit reports stored inside a Deck

  • Saving and reloading trained models

[ ]:
import json
import os
import shutil
import tempfile
import warnings

import pandas as pd
from IPython.display import display  # used to render DataFrames below; built in under Jupyter

from ageas import Hangar, n_iter_boost_selection, n_iter_extraction
from ageas.tool import Multimodal_Corpus, make_fake_adata

warnings.filterwarnings("ignore")

Shared setup

[ ]:
adata = make_fake_adata(n_class=2, n_clusters_per_class=2)
adata_path = "ageas_tut03.h5ad"
adata.write_h5ad(adata_path)

corpus = Multimodal_Corpus(adata_path, label_key="celltype", backed=False)
print(f"Corpus: {len(corpus)} cells, {corpus.adata.n_vars} features")

1. Custom model configs

Each JSON file in the config folder corresponds to one trained model instance (a Unit). The sub-folder name selects the model family.

Sub-folder   Model class              Backend
logreg       LogReg_Classifier        scikit-learn
svc          SVM_Classifier           scikit-learn
mnb          MNB_Classifier           scikit-learn
xgb          XGB_Classifier           XGBoost
mlp          NN_Classifier (MLP)      PyTorch Lightning
resnet       NN_Classifier (ResNet)   PyTorch Lightning
rnn          NN_Classifier (RNN)      PyTorch Lightning

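Every JSON file must contain max_epochs and model_params. Neural-net configs also include a train_config section, and the XGBoost example below uses one as well (for num_boost_round, early_stopping_rounds and verbose_eval).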
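Before handing a folder to Hangar, it can help to check the files against the rules above. The sketch below is a hypothetical pre-flight helper (check_config_dir is not part of AGEAS). It only encodes what this tutorial states: the sub-folder names from the table, the two required keys, and train_config for the neural-net families, which it treats as required as an assumption.

```python
import json
import os
import shutil
import tempfile

# Sub-folder names from the table above; NN families also carry train_config.
KNOWN_FAMILIES = {"logreg", "svc", "mnb", "xgb", "mlp", "resnet", "rnn"}
NN_FAMILIES = {"mlp", "resnet", "rnn"}
REQUIRED_KEYS = {"max_epochs", "model_params"}


def check_config_dir(config_dir):
    """Return a list of problems found in a config folder (hypothetical helper)."""
    problems = []
    for family in sorted(os.listdir(config_dir)):
        family_dir = os.path.join(config_dir, family)
        if not os.path.isdir(family_dir):
            continue  # ignore stray files at the top level
        if family not in KNOWN_FAMILIES:
            problems.append(f"{family}/: unknown model family")
            continue
        for name in sorted(os.listdir(family_dir)):
            if not name.endswith(".json"):
                continue
            with open(os.path.join(family_dir, name)) as fh:
                cfg = json.load(fh)
            missing = REQUIRED_KEYS - set(cfg)
            # Assumption: treat train_config as mandatory for neural-net configs.
            if family in NN_FAMILIES and "train_config" not in cfg:
                missing.add("train_config")
            if missing:
                problems.append(f"{family}/{name}: missing {sorted(missing)}")
    return problems


# Usage: an MLP config that forgot its train_config section.
demo_dir = tempfile.mkdtemp(prefix="cfg_check_")
os.makedirs(os.path.join(demo_dir, "mlp"))
with open(os.path.join(demo_dir, "mlp", "mlp_small.json"), "w") as fh:
    json.dump({"max_epochs": 5, "model_params": {}}, fh)
result = check_config_dir(demo_dir)
print(result)
shutil.rmtree(demo_dir, ignore_errors=True)
```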

[ ]:
config_dir = tempfile.mkdtemp(prefix="ageas_cfg_")

# --- Logistic regression with L1 penalty ---
os.makedirs(os.path.join(config_dir, "logreg"))
logreg_cfg = {
    "max_epochs": 1,
    "model_params": {
        "penalty": "l1",
        "C": 1.0,
        "solver": "saga",
        "tol": 1e-4,
    },
}
with open(os.path.join(config_dir, "logreg", "logreg_l1_C1.json"), "w") as fh:
    json.dump(logreg_cfg, fh, indent=4)

# --- XGBoost ---
os.makedirs(os.path.join(config_dir, "xgb"))
xgb_cfg = {
    "max_epochs": 1,
    "model_params": {
        "booster": "gbtree",
        "max_depth": 4,
        "eta": 0.1,
    },
    "train_config": {
        "num_boost_round": 50,
        "early_stopping_rounds": None,
        "verbose_eval": False,
    },
}
with open(os.path.join(config_dir, "xgb", "xgb_fast.json"), "w") as fh:
    json.dump(xgb_cfg, fh, indent=4)

hangar = Hangar(config_dir)
print(f"Hangar: {len(hangar.units)} units | Units: {list(hangar.units.keys())}")

2. n_iter_boost_selection

Boost selection is most useful when you start with thousands of genes and want to narrow down to a compact candidate set before final training. Each iteration:

  1. Runs n_kfold_selection on the current (shrinking) feature set

  2. Debriefs the surviving models to score features

  3. Keeps the top extract_top_n / extract_ratio^(iter) genes

After all boost iterations, a final selection is run on the pruned feature set.
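The core of the loop is "score the current features, keep the top k, repeat on the survivors." Below is a minimal, library-agnostic sketch of that pruning pattern, not AGEAS's implementation. Two assumptions are worth stating. First, the per-round sizes are passed explicitly as keep_schedule rather than derived from extract_top_n and extract_ratio. Second, the absolute difference of class means stands in for the scores AGEAS obtains by debriefing trained models. boost_prune and mean_gap_score are hypothetical names.

```python
import numpy as np


def boost_prune(X, y, keep_schedule, score_fn):
    """Iteratively keep the top-scoring columns of X.

    keep_schedule: features to keep after each round (hypothetical stand-in
    for the extract_top_n / extract_ratio schedule).
    score_fn: maps (X_subset, y) -> one score per column.
    """
    kept = np.arange(X.shape[1])  # original column ids still in play
    for n_keep in keep_schedule:
        scores = score_fn(X[:, kept], y)
        order = np.argsort(scores)[::-1]           # highest score first
        kept = np.sort(kept[order[:n_keep]])       # map back to original ids
    return kept


def mean_gap_score(X, y):
    # Crude stand-in for model attributions: |mean(class 1) - mean(class 0)|.
    return np.abs(X[y == 1].mean(axis=0) - X[y == 0].mean(axis=0))


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 50))
y = rng.integers(0, 2, size=200)
X[y == 1, :3] += 2.0  # make features 0, 1 and 2 informative

kept = boost_prune(X, y, keep_schedule=[16, 8], score_fn=mean_gap_score)
print(kept)
```

Re-scoring on the shrinking subset matters when the scorer is a trained model, because a feature's attribution can change once its competitors are removed. A fixed univariate score like the stand-in above would give the same ranking in every round.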

[ ]:
deck, fea_kept = n_iter_boost_selection(
    hangar=hangar,
    query_dataset=corpus,
    test_dataset=corpus,
    max_boost_iter=2,
    extract_ratio=0.5,
    extract_top_n=8,          # target: 8 genes after boosting
    seed=42,
    verbose=False,
    selection_args=dict(
        kfold_selection_list=[2],
        valid_fraction=0.1,
        monitor_metric="test.accuracy",
        retention_point=0.5,
        cutoff_point=0.0,
    ),
)
shutil.rmtree("cache", ignore_errors=True)

print(f"Features kept  : {len(fea_kept)}")
print(f"Surviving units: {len(deck.squad)}")
print("Retained genes :", fea_kept)

3. n_iter_extraction

n_iter_extraction is the full regulatory factor discovery pipeline. It goes beyond boost selection by:

  • Removing IQR outlier features between iterations (genes with extreme attribution scores are flagged rather than trusted)

  • L1-aggregating explanation scores across all iterations

  • Returning a ranked factor table per class
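The outlier-removal step relies on the interquartile range. The sketch below is a generic illustration using Tukey's fences and assumes the conventional 1.5 × IQR multiplier. The tutorial does not say which threshold n_iter_extraction uses, and iqr_outliers is a hypothetical helper.

```python
import numpy as np


def iqr_outliers(scores, k=1.5):
    """Boolean mask of scores outside [Q1 - k*IQR, Q3 + k*IQR] (Tukey's fences)."""
    q1, q3 = np.percentile(scores, [25, 75])
    iqr = q3 - q1
    return (scores < q1 - k * iqr) | (scores > q3 + k * iqr)


# Toy attribution scores: one gene dominates the rest.
scores = np.array([0.10, 0.12, 0.09, 0.11, 0.13, 0.10, 0.95])
mask = iqr_outliers(scores)
print(mask)  # only the last score (0.95) falls outside the fences
```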

The returned top_factors table has an Outlier_Iter column: genes pruned as outliers in iteration i have Outlier_Iter = i, and all other genes have Outlier_Iter = -1.
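The bookkeeping behind such a table can be sketched in pandas. The example makes several assumptions. The per-iteration scores and the record of flagged genes are invented. "L1-aggregating" is read here as scaling each iteration's scores to unit L1 norm and then summing them; AGEAS may aggregate differently. The Score column name is hypothetical. Only the Outlier_Iter convention comes from the text above.

```python
import pandas as pd

# Invented attribution scores: rows are genes, one column per iteration.
scores = pd.DataFrame(
    {"iter_0": [0.5, 0.3, 0.2, 0.0], "iter_1": [0.1, 0.6, 0.3, 0.0]},
    index=["GeneA", "GeneB", "GeneC", "GeneD"],
)
# Invented record of genes flagged as IQR outliers, keyed by iteration.
flagged = {0: ["GeneD"]}

# One reading of L1 aggregation: normalise each column to unit L1 norm, then sum.
normed = scores.abs() / scores.abs().sum(axis=0)
table = pd.DataFrame({"Score": normed.sum(axis=1)})

# Outlier_Iter convention: i if pruned in iteration i, otherwise -1.
table["Outlier_Iter"] = -1
for it, genes in flagged.items():
    table.loc[genes, "Outlier_Iter"] = it

table = table.sort_values("Score", ascending=False)
print(table)
```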

[ ]:
top_factors, final_exp = n_iter_extraction(
    hangar=hangar,
    query_dataset=corpus,
    test_dataset=corpus,
    max_extraction_iter=2,
    extract_top_n=5,          # top 5 factors per class
    use_gene_names=True,      # index by adata.var['name']
    seed=42,
    verbose=False,
    selection_args=dict(
        kfold_selection_list=[2],
        valid_fraction=0.1,
        monitor_metric="test.accuracy",
        retention_point=0.5,
        cutoff_point=0.0,
    ),
)
shutil.rmtree("cache", ignore_errors=True)

print("Top regulatory factors:")
display(top_factors)
[ ]:
print("Final integrated explanation (first 5 rows):")
display(final_exp.head())

4. Inspecting per-unit reports

After every selection pass, each Unit in deck.squad accumulates a nested report dict:

unit.report[operation_name]['fold_1']   → per-fold metrics
unit.report[operation_name]['final']    → last-mission metrics

This is useful for diagnosing which models perform well and on which folds.
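The cell below reads only the 'final' entry. To compare folds instead, you can walk the fold_* keys. The toy dict below mirrors the layout above, but the operation name, the fold count and all metric values are invented. The assumption that each fold entry has the same "test"/"vali" shape as 'final' is also unverified. fold_accuracies is a hypothetical helper.

```python
from statistics import mean

# Toy report mirroring unit.report[operation_name][...]; all values invented.
report = {
    "kfold_selection_2": {
        "fold_1": {"test": {"test.accuracy": 0.90}, "vali": None},
        "fold_2": {"test": {"test.accuracy": 0.80}, "vali": None},
        "final": {"test": {"test.accuracy": 0.85}, "vali": None},
    }
}


def fold_accuracies(op_report, split="test"):
    """Collect accuracy from every 'fold_<k>' entry, skipping 'final'."""
    return {
        key: metrics[split][f"{split}.accuracy"]
        for key, metrics in op_report.items()
        if key.startswith("fold_") and metrics.get(split) is not None
    }


accs = fold_accuracies(report["kfold_selection_2"])
print(accs, "mean:", mean(accs.values()))
```

A large spread between folds is a hint that a unit's final score depends on the particular split rather than on a stable signal.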

[ ]:
rows = []
for uid, unit in deck.squad.items():
    for op_name, op_report in unit.report.items():
        if "final" not in op_report:
            continue
        final = op_report["final"]
        # Prefer held-out test metrics; fall back to validation metrics
        prefix = "test" if final.get("test") is not None else "vali"
        source = final.get(prefix) or {}
        rows.append({
            "unit": uid,
            "operation": op_name,
            "accuracy": source.get(f"{prefix}.accuracy"),
            "auroc":    source.get(f"{prefix}.auroc"),
        })

pd.DataFrame(rows).sort_values("accuracy", ascending=False)

5. Saving and reloading trained models

Each Unit has a .model attribute whose save_model / load_model methods persist the underlying estimator in its native format: JSON for scikit-learn models, XGBoost's binary format for XGB models, and a PyTorch state dict for neural-net models.

[ ]:
save_dir = tempfile.mkdtemp(prefix="ageas_saved_")

for uid, unit in deck.squad.items():
    model_path = os.path.join(save_dir, f"{uid}.model")
    unit.model.save_model(model_path)
    print(f"Saved {uid} → {model_path}")

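# To reload later (before the cleanup below deletes save_dir):
#   unit.model.load_model(model_path)

# Cleanup: saved models, generated configs, and the .h5ad written during setup
shutil.rmtree(save_dir, ignore_errors=True)
shutil.rmtree(config_dir, ignore_errors=True)
if os.path.exists(adata_path):
    os.remove(adata_path)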