{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 03 — Advanced usage: boost selection, extraction, and custom configs\n", "\n", "This tutorial covers the two iterative pipelines and shows how to author\n", "model configs from scratch.\n", "\n", "**Topics:**\n", "- `n_iter_boost_selection`: iteratively narrow the feature set before final\n", "  model selection (useful when starting with thousands of genes)\n", "- `n_iter_extraction`: identify the top regulatory factors for each class across\n", "  multiple rounds of selection and explanation\n", "- Writing custom JSON config files for any supported model family\n", "- Inspecting per-unit reports stored inside a `Deck`\n", "- Saving and reloading trained models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "import os\n", "import shutil\n", "import tempfile\n", "import warnings\n", "\n", "import pandas as pd\n", "from IPython.display import display  # also available implicitly inside Jupyter\n", "\n", "from ageas import Hangar, n_iter_boost_selection, n_iter_extraction\n", "from ageas.tool import Multimodal_Corpus, make_fake_adata\n", "\n", "# Silence library warnings to keep the tutorial output readable\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Shared setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Synthetic AnnData: 2 classes with 2 clusters each, written to disk\n", "adata = make_fake_adata(n_class=2, n_clusters_per_class=2)\n", "adata_path = \"ageas_tut03.h5ad\"\n", "adata.write_h5ad(adata_path)\n", "\n", "corpus = Multimodal_Corpus(adata_path, label_key=\"celltype\", backed=False)\n", "print(f\"Corpus: {len(corpus)} cells, {corpus.adata.n_vars} features\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Custom model configs\n", "\n", "Each JSON file in the config folder corresponds to one trained model instance\n", "(a `Unit`). 
The sub-folder name selects the model family.\n", "\n", "| Sub-folder | Model class | Backend |\n", "|------------|-------------|---------|\n", "| `logreg` | `LogReg_Classifier` | scikit-learn |\n", "| `svc` | `SVM_Classifier` | scikit-learn |\n", "| `mnb` | `MNB_Classifier` | scikit-learn |\n", "| `xgb` | `XGB_Classifier` | XGBoost |\n", "| `mlp` | `NN_Classifier` (MLP) | PyTorch Lightning |\n", "| `resnet` | `NN_Classifier` (ResNet) | PyTorch Lightning |\n", "| `rnn` | `NN_Classifier` (RNN) | PyTorch Lightning |\n", "\n", "Required keys in each JSON: `max_epochs`, `model_params`.\n", "Neural-net configs also have a `train_config` section; the XGBoost example below\n", "uses one as well, to pass training arguments such as `num_boost_round`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config_dir = tempfile.mkdtemp(prefix=\"ageas_cfg_\")\n", "\n", "# --- Logistic regression with L1 penalty ---\n", "# (the saga solver supports the L1 penalty)\n", "os.makedirs(os.path.join(config_dir, \"logreg\"))\n", "logreg_cfg = {\n", "    \"max_epochs\": 1,\n", "    \"model_params\": {\n", "        \"penalty\": \"l1\",\n", "        \"C\": 1.0,\n", "        \"solver\": \"saga\",\n", "        \"tol\": 1e-4,\n", "    },\n", "}\n", "with open(os.path.join(config_dir, \"logreg\", \"logreg_l1_C1.json\"), \"w\") as fh:\n", "    json.dump(logreg_cfg, fh, indent=4)\n", "\n", "# --- XGBoost ---\n", "os.makedirs(os.path.join(config_dir, \"xgb\"))\n", "xgb_cfg = {\n", "    \"max_epochs\": 1,\n", "    \"model_params\": {\n", "        \"booster\": \"gbtree\",\n", "        \"max_depth\": 4,\n", "        \"eta\": 0.1,\n", "    },\n", "    \"train_config\": {\n", "        \"num_boost_round\": 50,\n", "        \"early_stopping_rounds\": None,  # written to the file as JSON null\n", "        \"verbose_eval\": False,\n", "    },\n", "}\n", "with open(os.path.join(config_dir, \"xgb\", \"xgb_fast.json\"), \"w\") as fh:\n", "    json.dump(xgb_cfg, fh, indent=4)\n", "\n", "# One Unit per JSON file, so this Hangar should hold 2 units\n", "hangar = Hangar(config_dir)\n", "print(f\"Hangar: {len(hangar.units)} units | Units: {list(hangar.units.keys())}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
`n_iter_boost_selection`\n", "\n", "Boost selection is most useful when you start with thousands of genes and\n", "want to narrow them down to a compact candidate set before final training. Each\n", "iteration:\n", "1. Runs `n_kfold_selection` on the current (shrinking) feature set\n", "2. Debriefs the surviving models to score features\n", "3. Keeps the top `extract_top_n / extract_ratio^(iter)` genes\n", "\n", "After all boost iterations, a final selection is run on the pruned features." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The same corpus serves as query and test set to keep the tutorial short,\n", "# so test metrics are not a held-out estimate; pass separate data in real analyses.\n", "deck, fea_kept = n_iter_boost_selection(\n", "    hangar=hangar,\n", "    query_dataset=corpus,\n", "    test_dataset=corpus,\n", "    max_boost_iter=2,\n", "    extract_ratio=0.5,\n", "    extract_top_n=8,  # target: 8 genes after boosting\n", "    seed=42,\n", "    verbose=False,\n", "    selection_args=dict(\n", "        kfold_selection_list=[2],\n", "        valid_fraction=0.1,\n", "        monitor_metric=\"test.accuracy\",\n", "        retention_point=0.5,\n", "        cutoff_point=0.0,\n", "    ),\n", ")\n", "# Remove the \"cache\" directory written during the run\n", "shutil.rmtree(\"cache\", ignore_errors=True)\n", "\n", "print(f\"Features kept : {len(fea_kept)}\")\n", "print(f\"Surviving units: {len(deck.squad)}\")\n", "print(\"Retained genes :\", fea_kept)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. `n_iter_extraction`\n", "\n", "`n_iter_extraction` is the full regulatory-factor discovery pipeline.\n", "It goes beyond boost selection by:\n", "- Removing **IQR outlier features** between iterations (genes with extreme\n", "  attribution scores are flagged and pruned rather than trusted)\n", "- **L1-aggregating** explanation scores across all iterations\n", "- Returning a **ranked factor table** per class\n", "\n", "The returned `top_factors` table has an `Outlier_Iter` column: genes pruned\n", "as outliers in iteration `i` have `Outlier_Iter = i`, and genes that were never\n", "flagged have `Outlier_Iter = -1`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "top_factors, final_exp = n_iter_extraction(\n", "    hangar=hangar,\n", "    query_dataset=corpus,\n", "    test_dataset=corpus,\n", "    max_extraction_iter=2,\n", "    extract_top_n=5,  # top 5 factors per class\n", "    use_gene_names=True,  # index by adata.var['name']\n", "    seed=42,\n", "    verbose=False,\n", "    selection_args=dict(\n", "        kfold_selection_list=[2],\n", "        valid_fraction=0.1,\n", "        monitor_metric=\"test.accuracy\",\n", "        retention_point=0.5,\n", "        cutoff_point=0.0,\n", "    ),\n", ")\n", "shutil.rmtree(\"cache\", ignore_errors=True)\n", "\n", "print(\"Top regulatory factors:\")\n", "display(top_factors)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Final integrated explanation (first 5 rows):\")\n", "display(final_exp.head())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Inspecting per-unit reports\n", "\n", "After every selection pass, each `Unit` in `deck.squad` accumulates a\n", "nested report dict:\n", "\n", "```\n", "unit.report[operation_name]['fold_1'] → metrics for fold 1 (one key per fold)\n", "unit.report[operation_name]['final'] → last-mission metrics\n", "```\n", "\n", "These reports help diagnose which models perform well, and on which folds.\n", "The cell below tabulates the final metrics of every unit in the `deck`\n", "returned by `n_iter_boost_selection` in section 2." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rows = []\n", "for uid, unit in deck.squad.items():\n", "    for op_name, op_report in unit.report.items():\n", "        # Skip operations that never produced a final report\n", "        if \"final\" not in op_report:\n", "            continue\n", "        final = op_report[\"final\"]\n", "        # Prefer test-set metrics; fall back to validation metrics if no test set was scored\n", "        prefix = \"test\" if final.get(\"test\") is not None else \"vali\"\n", "        metrics = final.get(prefix) or {}\n", "        rows.append({\n", "            \"unit\": uid,\n", "            \"operation\": op_name,\n", "            \"metrics_from\": prefix,\n", "            \"accuracy\": metrics.get(f\"{prefix}.accuracy\"),\n", "            \"auroc\": metrics.get(f\"{prefix}.auroc\"),\n", "        })\n", "\n", "# Explicit columns keep sort_values working even when rows is empty\n", "report_df = pd.DataFrame(rows, columns=[\"unit\", \"operation\", \"metrics_from\", \"accuracy\", \"auroc\"])\n", "report_df.sort_values(\"accuracy\", ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Saving and reloading trained models\n", "\n", "Each `Unit` has a `.model` attribute whose `save_model` / `load_model`\n", "methods persist the underlying estimator in its native format: JSON for\n", "scikit-learn models, the XGBoost binary format for XGBoost models, and a\n", "PyTorch state dict for neural-network models." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_dir = tempfile.mkdtemp(prefix=\"ageas_saved_\")\n", "\n", "for uid, unit in deck.squad.items():\n", "    model_path = os.path.join(save_dir, f\"{uid}.model\")\n", "    unit.model.save_model(model_path)\n", "    print(f\"Saved {uid} → {model_path}\")\n", "\n", "# To reload later, e.g. into a unit built from the same config:\n", "#     unit.model.load_model(os.path.join(save_dir, f\"{uid}.model\"))\n", "\n", "# Cleanup: this also deletes the saved models; skip it to keep them\n", "shutil.rmtree(save_dir, ignore_errors=True)\n", "shutil.rmtree(config_dir, ignore_errors=True)\n", "os.remove(adata_path)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }