# NexaCompute Distillation Guide

Complete guide for running knowledge distillation workflows in NexaCompute.

## Overview

Nexa Distill transforms raw scientific text into high-quality, falsifiable, and reproducible hypothesis–method pairs for supervised fine-tuning. The engine modularizes teacher generation, filtering, inspection, regeneration, and packaging into final training datasets.

## Architecture

```text
nexa_distill/
├── collect_teacher.py        # Generate teacher completions
├── filter_pairs.py            # Clean + filter teacher outputs
├── ui_inspect.py              # Streamlit review interface
├── regenerate_bad.py          # Re-run rejected samples
├── to_sft.py                  # Package final dataset
├── prompts/
│   ├── hypothesis.txt
│   ├── methodology.txt
│   └── rubric.json
├── utils/
│   ├── io.py
│   ├── openai_api.py
│   ├── filters.py
│   ├── texttools.py
│   └── logger.py
└── configs/
    ├── distill_config.yaml
    ├── teacher_models.yaml
    └── filters.yaml
```

## Complete Pipeline

### Stage 1: Prepare Teacher Inputs

Generate the teacher input dataset from enhanced prompts:

```bash
# Run analysis notebook to curate teacher inputs
jupyter notebook nexa_data/data_analysis/distill_data_overview.ipynb
```

**Output:** `data/processed/distillation/teacher_inputs/teacher_inputs_v1.parquet`

**Data Format:**
| Column          | Type   | Description                          |
|-----------------|--------|--------------------------------------|
| domain          | str    | Domain (biology, physics, materials) |
| template_name   | str    | Prompt template identifier           |
| user_prompt     | str    | User instruction                      |
| template_prompt | str    | Template body                         |
| system_prompt   | str    | System instruction                   |
| source_file     | str    | Source enhanced file                 |
| generated_at    | str    | ISO timestamp                        |
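
A quick sanity check before moving on is to load the parquet and confirm it matches the schema above (a minimal sketch using pandas):

```python
import pandas as pd

# Load the Stage 1 output and verify the expected schema
df = pd.read_parquet("data/processed/distillation/teacher_inputs/teacher_inputs_v1.parquet")

expected = {"domain", "template_name", "user_prompt", "template_prompt",
            "system_prompt", "source_file", "generated_at"}
missing = expected - set(df.columns)
assert not missing, f"missing columns: {missing}"

print(df["domain"].value_counts())
print(f"{len(df)} teacher inputs ready")
```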

### Stage 2: Collect Teacher Completions

Generate teacher outputs using a strong model (e.g., GPT-4 or Claude Sonnet).

```bash
python -m nexa_distill.collect_teacher \
  --src data/processed/distillation/teacher_inputs/teacher_inputs_v1.parquet \
  --dst data/processed/distillation/teacher_outputs/teacher_outputs_v1.parquet \
  --teacher openrouter:gpt-4o \
  --max-samples 6000
```

**Input:** Teacher input parquet
**Output:** `data/processed/distillation/teacher_outputs/teacher_outputs_v1.parquet`
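
Collection is essentially a loop over the input rows that queries the teacher once per prompt. A minimal sketch of that loop, assuming the `generate(prompt, model, temp)` helper from `utils/openai_api.py` returns completion text; the prompt assembly, retry policy, and `teacher_completion` column name are illustrative:

```python
import time
import pandas as pd
from nexa_distill.utils.openai_api import generate  # assumed import path, per the repo layout above

def collect_batch(df: pd.DataFrame, teacher_id: str, temp: float = 0.7) -> pd.DataFrame:
    """Generate one teacher completion per row and append it as a new column."""
    completions = []
    for _, row in df.iterrows():
        prompt = f"{row['system_prompt']}\n\n{row['template_prompt']}\n\n{row['user_prompt']}"
        for attempt in range(3):  # simple illustrative retry on transient API errors
            try:
                completions.append(generate(prompt, teacher_id, temp))
                break
            except Exception:
                time.sleep(2 ** attempt)
        else:
            completions.append(None)  # mark failed rows for later regeneration
    out = df.copy()
    out["teacher_completion"] = completions  # hypothetical column name
    return out
```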

### Stage 3: Filter Teacher Outputs

Drop weak, incomplete, or low-quality completions.

```bash
python -m nexa_distill.filter_pairs \
  --src data/processed/distillation/teacher_outputs/teacher_outputs_v1.parquet \
  --dst data/processed/distillation/filtered/teacher_filtered_v1.parquet
```

**Filtering Rules:**
- Completion length > 120 characters
- Contains action verbs ("prepare", "simulate", "evaluate", "compare")
- No hallucinated citations or broken formatting
- No bracketed references or citation markers

**Output:** `data/processed/distillation/filtered/teacher_filtered_v1.parquet`
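
These rules translate into simple per-row predicates. A sketch of what such a ruleset can look like (illustrative only; the actual thresholds live in `configs/filters.yaml`):

```python
import re

ACTION_VERBS = ("prepare", "simulate", "evaluate", "compare")
# Bracketed references like [12] or informal citation markers like (Smith et al., 2020)
CITATION_PATTERN = re.compile(r"\[\d+\]|\(\w+ et al\.,? \d{4}\)")

def passes_filters(text: str) -> bool:
    """Return True if a teacher completion survives the Stage 3 rules."""
    if not text or len(text) <= 120:
        return False
    lowered = text.lower()
    if not any(verb in lowered for verb in ACTION_VERBS):
        return False
    if CITATION_PATTERN.search(text):
        return False
    return True
```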

### Stage 4: Human Review (Optional)

Streamlit UI for visual inspection and labeling.

```bash
streamlit run nexa_distill/ui_inspect.py \
  -- --src data/processed/distillation/filtered/teacher_filtered_v1.parquet
```

Produces:
- `accepted.jsonl` - Human-approved samples
- `rejected.jsonl` - Samples to regenerate

**Labels stored:** `data/processed/distillation/labels/<date>.jsonl`

### Stage 5: Regenerate Rejected Samples

Re-generate rejected rows via stricter teacher prompts emphasizing falsifiability and reproducibility.

```bash
python -m nexa_distill.regenerate_bad \
  --rejected data/processed/distillation/labels/rejected.jsonl \
  --dst data/processed/distillation/filtered/teacher_regenerated_v1.parquet
```

**Output:** `data/processed/distillation/filtered/teacher_regenerated_v1.parquet`
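
Conceptually, regeneration reuses the Stage 2 collection loop with a stricter instruction prepended to each rejected prompt. A sketch of that idea with a hypothetical instruction string (the real prompts live in `prompts/`):

```python
# Hypothetical stricter prefix emphasizing falsifiability and reproducibility
STRICTER_PREFIX = (
    "Revise the following task. The hypothesis MUST state a measurable, falsifiable "
    "prediction, and the methodology MUST list ordered, parameterized, reproducible steps.\n\n"
)

def build_regen_prompt(row) -> str:
    # Prepend the stricter instruction before re-querying the teacher
    return STRICTER_PREFIX + row["template_prompt"] + "\n\n" + row["user_prompt"]
```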

### Stage 6: Package for Training

Convert all accepted data into SFT-ready JSONL format.

```bash
python -m nexa_distill.to_sft \
  --src data/processed/distillation/filtered/teacher_filtered_v1.parquet \
  --dst data/processed/distillation/sft_datasets/sft_scientific_v1.jsonl
```

**SFT Format:**
```json
{
  "input": "<scientific context + task>",
  "output": "<teacher or regenerated answer>",
  "meta": {
    "domain": "BIO",
    "quality": "human_accept",
    "template_name": hypothesis
  }
}
```

**Output:** `data/processed/distillation/sft_datasets/sft_scientific_v1.jsonl`
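
Packaging is a row-to-JSON mapping into the format above. A minimal sketch of what `make_sft_dataset` can look like; the `teacher_completion` and `quality` columns are assumptions, the rest come from the Stage 1 schema:

```python
import json
import pandas as pd

def make_sft_dataset(df: pd.DataFrame, out_path: str) -> None:
    """Write one SFT example per accepted row in the JSONL format shown above."""
    with open(out_path, "w", encoding="utf-8") as f:
        for _, row in df.iterrows():
            example = {
                "input": f"{row['system_prompt']}\n\n{row['user_prompt']}",
                "output": row["teacher_completion"],            # assumed completion column
                "meta": {
                    "domain": row["domain"],
                    "quality": row.get("quality", "filter_pass"),  # assumed label column
                    "template_name": row["template_name"],
                },
            }
            f.write(json.dumps(example, ensure_ascii=False) + "\n")
```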

### Stage 7: Train Student Model

Train student model on distilled dataset.

```bash
python -m nexa_train.distill \
  --dataset data/processed/distillation/sft_datasets/sft_scientific_v1.jsonl \
  --config nexa_train/configs/baseline.yaml \
  --tags distill-v1 scientific-assistant
```

## Data Formats

### Raw Data Sources

**Enhanced Prompts (JSON):**
```json
{
  "texts": "["prompt 1", "prompt 2, ...],
  "generated_at": 2025-11-03T12:00:00Z
}
```text
**Location: "** `data/raw/*_enhanced.json`"

**Training Datasets (JSONL):**
One JSON object per line:
```json
{"title": "...", "abstract": "...", "domain": "...", "text": "..."}
```
**Location:** `data/raw/*.jsonl`

### Processed Data Formats

**Teacher Inputs (Parquet):** See Stage 1 above.

**SFT Training Data (JSONL):** See Stage 6 above.

**Distilled Datasets (JSON):**
Teacher probabilities stored in `distilled_dataset.json`:
```json
{
  "teacher_probs": [0.1, 0.9],
  "target": 1
}
```
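
Soft labels like these are what a distillation loss consumes during student training. A minimal sketch of the standard formulation (KL divergence between the teacher probabilities and the student's softened distribution), assuming PyTorch; this illustrates the idea, not the `nexa_train.distill` implementation:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_probs, temperature=1.0):
    # Soften the student with the temperature; the stored teacher_probs are used as-is
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, teacher_probs, reduction="batchmean") * temperature ** 2

# Example with the record above: teacher_probs = [0.1, 0.9], target = 1
loss = distillation_loss(torch.tensor([[0.2, 1.5]]), torch.tensor([[0.1, 0.9]]))
```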

### Querying Distillation Data

```python
from nexa_data.data_analysis.query_data import DataQuery

query = DataQuery()
teacher_df = query.get_teacher_inputs(version="v1")
```

## Evaluation

### Evaluation Pipeline

After training, evaluate the distilled model:

1. **Prediction Generation** — `nexa_eval.generate.generate_predictions` builds validation dataloaders and generates predictions.
2. **Metric Computation** — `nexa_compute.evaluation.metrics.MetricRegistry` computes accuracy, precision, recall, F1, AUROC, RMSE.
3. **Judging** — `nexa_eval.judge.judge_metrics` compares metrics against rubric thresholds.
4. **Reporting** — `nexa_eval.analyze.evaluate_checkpoint` writes evaluation reports.

```bash
python scripts/cli.py evaluate \
  --config configs/default.yaml \
  --checkpoint artifacts/checkpoints/checkpoint_epoch0.pt
```
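
Conceptually, the judging step compares each computed metric against a rubric threshold. A minimal sketch of that comparison (metric names and thresholds are placeholders, and this is not the actual `judge_metrics` signature):

```python
def judge(metrics: dict, thresholds: dict) -> dict:
    """Return pass/fail per metric plus an overall verdict."""
    per_metric = {name: metrics.get(name, float("-inf")) >= minimum
                  for name, minimum in thresholds.items()}
    return {"per_metric": per_metric, "passed": all(per_metric.values())}

# Hypothetical thresholds for illustration only
report = judge({"accuracy": 0.91, "f1": 0.87, "auroc": 0.93},
               {"accuracy": 0.85, "f1": 0.80, "auroc": 0.90})
```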

### Scientific Assistant Evaluation

**Purpose:** Evaluate whether the model generates:
1. **Plausible** scientific hypotheses
2. **Falsifiable/testable** hypotheses
3. **Methodology blocks** that are structured and reproducible
4. Outputs matching or approaching the teacher's distribution

**Evaluation Rubric:**

```json
{
  "dimensions": [
    {
      "name": "plausibility",
      "description": "Is the hypothesis consistent with basic domain knowledge?",
      "scale": [1, 5]
    },
    {
      "name": "falsifiability",
      "description": "Does the hypothesis define a measurable outcome or experiment?",
      "scale": [1, 5]
    },
    {
      "name": "method_rigour",
      "description": "Are the steps ordered, parameterized, and unambiguous?",
      "scale": [1, 5]
    },
    {
      "name": "reproducibility",
      "description": "Can another lab reproduce this with typical equipment/data?",
      "scale": [1, 5]
    },
    {
      "name": "teacher_agreement",
      "description": "How close is the student to the teacher structure/content?",
      "scale": [1, 5]
    }
  ]
}
```

**Success Criteria:**
- Median score ≥ 3.5 on plausibility and falsifiability
- ≥ 70% of samples produce a methodology block with at least 3 ordered steps
- Teacher–student agreement ≥ 3.0 for the top 500 samples
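
Given per-sample rubric scores, these criteria can be checked mechanically. A minimal sketch assuming a pandas DataFrame with one column per rubric dimension plus a hypothetical `n_method_steps` count; the top-500 criterion is read here as the mean agreement of the 500 highest-agreement samples:

```python
import pandas as pd

def meets_success_criteria(scores: pd.DataFrame) -> dict:
    """Check the three success criteria above against per-sample rubric scores."""
    top500 = scores.nlargest(500, "teacher_agreement")
    return {
        "plausibility_median": scores["plausibility"].median() >= 3.5,
        "falsifiability_median": scores["falsifiability"].median() >= 3.5,
        "methodology_coverage": (scores["n_method_steps"] >= 3).mean() >= 0.70,
        "teacher_agreement_top500": top500["teacher_agreement"].mean() >= 3.0,
    }
```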

**Evaluation Outputs:**
- Raw: `data/processed/evaluation/predictions/predictions_<run_id>.parquet`
- Summary: `data/processed/evaluation/reports/eval_report_<run_id>.json`
- Metrics: pushed to W&B as a table

## Compute Planning

### Hardware Requirements

**Hardware Pricing (current):**
- A100: $1.00/hr (1×) - for data prep, eval, and ablations
- H100: $3.64/hr (1×) - for main training runs
- 4× 5090 test pod: $3.64/hr total

### Storage Layout

- **Ephemeral (GPU node):** `/workspace/tmp/` - temp checkpoints, logs
- **Durable:** `/mnt/nexa_durable/` - permanent datasets, checkpoints, manifests
- **Shared:** `/mnt/nexa_shared/` - prompt templates, rubric JSONs, configs

### Typical Workflow Phases

**Phase 0 — Sanity Check (A100, 8h, ~$8)**
- Test on a 20k-row sample
- Verify tokenization, W&B logging, S3 sync, manifest writing

**Phase 1 — Half-Epoch (A100×2, 24h, ~$48)**
- Tune batch size, grad_accum, sequence length
- Ensure stable LR schedule + no NCCL errors

**Phase 2 — Full Run (H100×2, 27h, ~$197)**
- Full 120–130k rows
- Final checkpoint in `/mnt/nexa_durable/checkpoints/<run_id>/final`

**Phase 3 — Eval + Distill (A100, 4h, ~$4)**
- Run rubric-based eval + teacher scoring
- Generate candidate distill set

**Total Budget:** ~$257 (+ buffer ~$40)

## Quality Gates

| Check                  | Pass Condition                         |
| ---------------------- | -------------------------------------- |
| Falsifiable Hypothesis | ≥ 60%                                  |
| Methodology Present    | ≥ 50%                                  |
| Length/Format Valid    | ≥ 90%                                  |
| Plausibility           | Manual review or heuristic avg ≥ 3.5/5 |

Rows failing quality gates are sent to regeneration.
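
In code, the gates reduce to per-row boolean flags checked against the dataset-level thresholds above, with failing rows routed back to Stage 5. A sketch with hypothetical flag columns (`is_falsifiable`, `has_methodology`, `format_valid`):

```python
import pandas as pd

# Dataset-level pass-rate thresholds from the table above
GATES = {"is_falsifiable": 0.60, "has_methodology": 0.50, "format_valid": 0.90}

def check_gates(df: pd.DataFrame) -> dict:
    """Return whether each gate's pass rate meets its threshold."""
    return {col: df[col].mean() >= threshold for col, threshold in GATES.items()}

def split_for_regeneration(df: pd.DataFrame):
    """Rows failing any per-row flag are routed back to the regeneration stage."""
    passed = df["is_falsifiable"] & df["has_methodology"] & df["format_valid"]
    return df[passed], df[~passed]
```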

## Core Interfaces

| Module             | Function                         | Description                                  |
| ------------------ | -------------------------------- | -------------------------------------------- |
| collect_teacher.py | `collect_batch(df, teacher_id)`  | Generate completions and append to DataFrame |
| filter_pairs.py    | `apply_filters(df, ruleset)`     | Apply cleaning and drop low-quality pairs    |
| ui_inspect.py      | `launch_inspector(path)`         | Launch Streamlit review                      |
| regenerate_bad.py  | `regen(df_rejects)`              | Recreate rejected outputs                    |
| to_sft.py          | `make_sft_dataset(df, out_path)` | Export JSONL/HF dataset                      |
| openai_api.py      | `generate(prompt, model, temp)`  | Unified teacher inference                    |
| filters.py         | `detect_citations(text)`         | Text heuristics                              |
| texttools.py       | `num_tokens(text)`               | Token counting & normalization               |
| logger.py          | `log_event(stage, msg)`          | Simple timestamped logging                   |
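
As one concrete example of these helpers, `num_tokens` can be a thin wrapper over a tokenizer. A sketch assuming `tiktoken` with the `cl100k_base` encoding; the actual `texttools.py` implementation may differ:

```python
import tiktoken

_ENCODING = tiktoken.get_encoding("cl100k_base")  # assumed encoding choice

def num_tokens(text: str) -> int:
    """Count tokens for budgeting teacher prompts and filtering overlong completions."""
    return len(_ENCODING.encode(text))
```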

## Artifact Organization

All distillation artifacts are organized in `data/processed/distillation/`:

```text
data/processed/distillation/
├── teacher_inputs/      # Teacher request data
│   └── teacher_inputs_v1.parquet
├── teacher_outputs/      # Teacher completions
│   └── teacher_outputs_v1.parquet
├── filtered/             # Filtered teacher data
│   ├── teacher_filtered_v1.parquet
│   └── teacher_regenerated_v1.parquet
├── sft_datasets/         # Final SFT-ready datasets
│   └── sft_scientific_v1.jsonl
├── labels/               # Human review labels
│   └── <date>.jsonl
└── manifests/            # Dataset manifests
    └── distillation_manifest_v1.json
```

## Quick Start Workflow

1. **Prepare Data:**
   ```bash
   jupyter notebook nexa_data/data_analysis/distill_data_overview.ipynb
   ```
2. **Collect Teacher:**
   ```bash
   python -m nexa_distill.collect_teacher \
     --src data/processed/distillation/teacher_inputs/teacher_inputs_v1.parquet \
     --teacher openrouter:gpt-4o
   ```
3. **Filter & Package:**
   ```bash
   python -m nexa_distill.filter_pairs --src <teacher_outputs>
   python -m nexa_distill.to_sft --src <filtered>
   ```
4. **Train Student:**
   ```bash
   python -m nexa_train.distill --dataset <sft_dataset>
   ```
5. **Evaluate:**
   ```bash
   python scripts/cli.py evaluate --checkpoint <checkpoint>
   ```

## Best Practices

- Always write manifests per run for reproducibility
- Sync final checkpoints from ephemeral → durable → S3
- Log to W&B for experiment tracking
- Version all datasets (v1, v2, etc.)
- Use the query interface for reliable data access
- Pass the quality gates before proceeding to the next stage

## Troubleshooting

- **Teacher API failures:** Check rate limits and retry logic
- **Filtering too aggressive:** Adjust thresholds in `filters.yaml`
- **Low-quality outputs:** Use the regeneration stage with stricter prompts
- **Evaluation failures:** Verify the rubric format and model checkpoint paths