3 changes: 3 additions & 0 deletions .gitignore
@@ -171,3 +171,6 @@ cython_debug/

# Weights and Biases files
wandb/

# geo-prior model checkpoints
*.pth
34 changes: 34 additions & 0 deletions src/geoprior/.env.sample
@@ -0,0 +1,34 @@
# Geo-prior pipeline configuration.
# Copy to `src/geoprior/.env` and fill in. `.env` is gitignored — never commit it.
# Real environment variables take precedence over values in `.env`.

# ---- BigQuery (source occurrence data) ----
GEOPRIOR_BQ_PROJECT=leps-ai
GEOPRIOR_BQ_DATASET=global_butterflies_2604
# BigQuery auth: point to a service-account key, OR run
# `gcloud auth application-default login` and leave this unset.
# GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json

# ---- Filesystem paths ----
# Frozen category map. Defaults to the committed src/geoprior/geoprior_categ_map.json
# — only set this to override.
# GEOPRIOR_CATEG_MAP=src/geoprior/geoprior_categ_map.json
# Working dir for generated train/val/test.json (+ sibling artifacts).
GEOPRIOR_DATA_DIR=/path/to/data/geoprior
# Vision split CSVs (val.csv / test.csv) that provide the hold-out gbif_ids.
GEOPRIOR_SPLITS_DIR=/path/to/data/splits
# Where trained checkpoints are written.
GEOPRIOR_MODEL_DIR=/path/to/models/geoprior

# ---- Weights & Biases ----
WANDB_ENTITY=moth-ai
WANDB_PROJECT=Global-Butterfly
# SECRET — get from https://wandb.ai/authorize. Do NOT commit a real key.
# Alternatively, run training with --wandb_offline and omit this.
WANDB_API_KEY=your-wandb-api-key

# ---- Fusion eval (downstream, optional) ----
# GEOPRIOR_CLF_VAL_PREDS=/path/to/clf/val_predictions.csv
# GEOPRIOR_CLF_TEST_PREDS=/path/to/clf/test_predictions.csv
# GEOPRIOR_VAL_PREDS=/path/to/geoprior_preds/val
# GEOPRIOR_TEST_PREDS=/path/to/geoprior_preds/test
148 changes: 148 additions & 0 deletions src/geoprior/README.md
@@ -0,0 +1,148 @@
# Geo-prior pipeline

Trains and evaluates a **geographic prior** for global butterflies: a model of
`p(species | latitude, longitude, date)` that re-ranks / gates the image
classifier's predictions. The network is a SINR-style **FCNet** (no images —
only coordinates, dates, and species labels).

This package is self-contained: BigQuery → training data → trained model, with
all configuration in `.env` and the model code kept in-tree.

---

## Directory layout

```
src/geoprior/
├── README.md                        ← you are here
├── config.py                        ← central config (reads .env); nothing hardcoded
├── .env.sample                      ← copy to .env and fill in (real .env is gitignored)
├── requirements.txt                 ← all dependencies
├── geoprior_categ_map.json          ← FROZEN species→class_id artifact (12,317 classes)
├── geoprior_categ_map.PROVENANCE.md ← where the frozen map comes from
├── build_geoprior_categ_map.py      ← Stage 0: BQ → category map (regenerate/verify)
├── build_geoprior_json.py           ← Stage 1: BQ + splits + map → train/val/test.json
├── train_geoprior.py                ← Stage 2: JSON → trained FCNet (+ wandb)
├── predict_geoprior.py              ← Stage 3: per-occurrence priors for fusion
├── fusion_eval_top5.py              ← Stage 4: geo-prior + classifier top-5 eval
└── geoprior_fagner/                 ← FCNet network, in-tree (frozen @ upstream commit)
    └── README.md, models.py, losses.py, dataloader.py
```

All scripts are run **as modules from the repo root** so package imports
(`from src.geoprior…`) resolve:

```bash
python -m src.geoprior.<script_name> [flags]
```

---

## Setup

```bash
pip install -r src/geoprior/requirements.txt # see notes for the CUDA torch wheel
cp src/geoprior/.env.sample src/geoprior/.env # then edit .env
```

### Configuration (`.env`)

Every path / identifier / secret lives in `src/geoprior/.env` (loaded by
`config.py`). See `.env.sample` for the annotated list. Summary:

| Variable | Purpose | Default |
|---|---|---|
| `GEOPRIOR_BQ_PROJECT` / `GEOPRIOR_BQ_DATASET` | BigQuery source | `leps-ai` / `global_butterflies_2604` |
| `GOOGLE_APPLICATION_CREDENTIALS` | BQ auth (service-account key) | unset → uses gcloud ADC |
| `GEOPRIOR_CATEG_MAP` | frozen category map | in-repo `geoprior_categ_map.json` |
| `GEOPRIOR_DATA_DIR` | generated train/val/test.json | `<repo>/data/geoprior` |
| `GEOPRIOR_SPLITS_DIR` | vision `val.csv`/`test.csv` (hold-out gbif_ids) | `<repo>/data/splits` |
| `GEOPRIOR_MODEL_DIR` | checkpoint output | `<repo>/models/geoprior` |
| `WANDB_ENTITY` / `WANDB_PROJECT` | W&B logging | `moth-ai` / `Global-Butterfly` |
| `WANDB_API_KEY` | **secret** W&B key | — (or use `--wandb_offline`) |

**Secrets** (`WANDB_API_KEY`, `GOOGLE_APPLICATION_CREDENTIALS`) are never
hardcoded and never committed — `.env` is gitignored.
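
The precedence rule stated in `.env.sample` (real environment variables win over values in `.env`) can be sketched as follows. This is a minimal illustration and not the actual `config.py`, which is not shown here: `parse_env_lines` and `resolve` are hypothetical helpers, and the real loader may rely on a library instead.

```python
import os


def parse_env_lines(lines):
    """Parse KEY=VALUE lines, skipping blanks and '#' comments (hypothetical helper)."""
    values = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


def resolve(name, file_values, environ=None, default=None):
    """Look in the real environment first, then the .env values, then the default."""
    environ = os.environ if environ is None else environ
    if name in environ:
        return environ[name]
    return file_values.get(name, default)


sample = [
    "# ---- BigQuery (source occurrence data) ----",
    "GEOPRIOR_BQ_PROJECT=leps-ai",
    "# GEOPRIOR_CATEG_MAP=src/geoprior/geoprior_categ_map.json",
]
file_values = parse_env_lines(sample)
fake_environ = {"GEOPRIOR_BQ_PROJECT": "my-override"}  # stands in for os.environ

# The exported variable beats the value from the file.
project = resolve("GEOPRIOR_BQ_PROJECT", file_values, environ=fake_environ)
# A commented-out entry falls through to the built-in default.
categ_map = resolve(
    "GEOPRIOR_CATEG_MAP",
    file_values,
    environ=fake_environ,
    default="src/geoprior/geoprior_categ_map.json",
)
```

Passing `environ` explicitly keeps the sketch testable without touching the real process environment.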

---

## Pipeline

### Stage 0 — Category map (frozen)
The `species → class_id` map is the **frozen class-space contract**: the trained
model's output indices are bound to its alphabetical ordering. It is committed
(`geoprior_categ_map.json`) and only regenerated/verified, never silently
overwritten. See `geoprior_categ_map.PROVENANCE.md`.

```bash
# verify the committed map still matches BigQuery (no writes)
python -m src.geoprior.build_geoprior_categ_map
# (re)materialise all artifacts
python -m src.geoprior.build_geoprior_categ_map --write --out-dir "$GEOPRIOR_DATA_DIR"
```
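
Why the ordering is a contract can be seen in a small sketch. The enumeration mirrors what `build_maps` in `build_geoprior_categ_map.py` does (`sorted()` followed by `enumerate`); the helper name `build_categ_map` and the four species used here are only illustrative.

```python
def build_categ_map(species_names):
    """Enumerate species in Python's default sorted() order, as Stage 0 does."""
    return {name: idx for idx, name in enumerate(sorted(species_names))}


frozen = build_categ_map(["Danaus plexippus", "Papilio machaon", "Vanessa atalanta"])

# One extra species that sorts before existing entries shifts every later index.
regenerated = build_categ_map(
    ["Danaus plexippus", "Morpho menelaus", "Papilio machaon", "Vanessa atalanta"]
)

reindexed = sorted(
    name
    for name in frozen.keys() & regenerated.keys()
    if frozen[name] != regenerated[name]
)
print(reindexed)  # ['Papilio machaon', 'Vanessa atalanta']
```

A model trained against `frozen` would report the wrong species for every shifted index, which is why the verify step treats any drift as a failure.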

### Stage 1 — Build training JSON
Pulls geocoded occurrences from BigQuery, excludes the vision `val`/`test`
gbif_ids from train (prevents leakage), maps species via the frozen map, and
writes COCO-style `train/val/test.json` to `GEOPRIOR_DATA_DIR`.

```bash
python -m src.geoprior.build_geoprior_json
```
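
The leakage rule can be illustrated with plain Python. This is a sketch under assumptions, not the code in `build_geoprior_json.py`: the record layout and the `split_occurrences` helper are hypothetical, and it assumes that hold-out occurrences are routed to the matching `val`/`test` file. Only the core idea (vision hold-out `gbif_id`s never reach train) comes from the description above.

```python
def split_occurrences(occurrences, val_ids, test_ids):
    """Route each occurrence to train/val/test by its gbif_id (hypothetical helper).

    Occurrences whose gbif_id is in the vision val/test hold-out go to that
    split; everything else goes to train.
    """
    splits = {"train": [], "val": [], "test": []}
    for occ in occurrences:
        gid = occ["gbif_id"]
        if gid in val_ids:
            splits["val"].append(occ)
        elif gid in test_ids:
            splits["test"].append(occ)
        else:
            splits["train"].append(occ)
    return splits


occurrences = [
    {"gbif_id": 101, "species": "Danaus plexippus"},
    {"gbif_id": 102, "species": "Papilio machaon"},
    {"gbif_id": 103, "species": "Vanessa atalanta"},
    {"gbif_id": 104, "species": "Danaus plexippus"},
]
val_ids = {102}   # would come from the gbif_id column of val.csv
test_ids = {104}  # would come from the gbif_id column of test.csv

splits = split_occurrences(occurrences, val_ids, test_ids)
train_ids = {occ["gbif_id"] for occ in splits["train"]}
print(sorted(train_ids))  # [101, 103]
```

Using sets for the hold-out ids keeps each membership check constant-time, which matters when the occurrence table is large.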

### Stage 2 — Train
```bash
python -m src.geoprior.train_geoprior \
    --train_data_json "$GEOPRIOR_DATA_DIR/train.json" \
    --model_save_path "$GEOPRIOR_MODEL_DIR" \
    --epochs 30 --batch_size 1024 --embed_dim 256 \
    --max_instances_per_class 100   # add --wandb_offline to skip W&B
```
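Saves a checkpoint after every epoch, plus a final `model_final_*.pth`. The
`$GEOPRIOR_*` variables in these commands are expanded by the shell, so export
them (or substitute the paths) before running; `config.py` reads `.env`, but the
shell does not.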

### Stage 3 — Predict (for fusion)
```bash
python -m src.geoprior.predict_geoprior \
    --test_data_json "$GEOPRIOR_DATA_DIR/val.json" \
    --model_path "$GEOPRIOR_MODEL_DIR"/model_final_*.pth \
    --results_dir "$GEOPRIOR_DATA_DIR/preds/val"
```

### Stage 4 — Fusion eval (top-5)
```bash
python -m src.geoprior.fusion_eval_top5 # paths come from .env
```
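
This README does not spell out the fusion rule used by `fusion_eval_top5.py`. One common way to let a geographic prior re-rank classifier output is to multiply the two probability vectors and renormalise; the sketch below shows that approach on toy numbers. The product rule, the function names, and the values are assumptions for illustration only.

```python
def fuse(clf_probs, geo_prior, eps=1e-9):
    """Multiply classifier probabilities by the geo-prior and renormalise."""
    scores = [c * (g + eps) for c, g in zip(clf_probs, geo_prior)]
    total = sum(scores)
    return [s / total for s in scores]


def top_k(probs, k=5):
    """Indices of the k highest-probability classes, best first."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]


clf_probs = [0.40, 0.35, 0.15, 0.10]  # image classifier, classes 0..3
geo_prior = [0.01, 0.60, 0.30, 0.09]  # p(species | lat, lon, date)

fused = fuse(clf_probs, geo_prior)
print(top_k(clf_probs, k=2))  # [0, 1]
print(top_k(fused, k=2))      # [1, 2]
```

Class 0 leads on image evidence alone but is unlikely at this location and date, so it drops out of the top 2 after fusion. Both vectors must be indexed by the same frozen class ids for the element-wise product to be meaningful.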

---

## Inputs you must provide

- **BigQuery access** to `GEOPRIOR_BQ_PROJECT.GEOPRIOR_BQ_DATASET`
(`gbif_inat_occurrences`, `gbif_occurrence_location`), via ADC or a
service-account key.
- **Vision split CSVs** `val.csv` / `test.csv` in `GEOPRIOR_SPLITS_DIR`
(produced by the vision pipeline's `src/dataset_tools/bq_squashfs/split.py`).
Only their `gbif_id` column is used.
- A **GPU** is optional — the model is ~3.7 M params; CPU works.

The category map is already provided (committed, frozen).

---

## The model

- **Network:** `FCNet` (4× residual blocks over a coordinate/date encoding),
kept in-tree in `geoprior_fagner/`, from Fagner Cunha's lepsAI (Apache-2.0).
- **Class space:** 12,317 species (every species with ≥1 geocoded occurrence in
the `public_gbif_2026-05` snapshot).
- **Inputs:** 6 features (cos/sin of lat, lon, day-of-year).
- **Reference run** `geoprior-fcnet-global-12317cls-v1`: 30 epochs, batch 1024,
lr 5e-4 (decay 0.98/epoch), embed_dim 256, BalancedSampler cap 100/class;
final loss ≈ 0.15; checkpoint ≈ 14.7 MB.
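
The 6-feature input can be sketched as below. The exact scaling used in `geoprior_fagner/dataloader.py` is not shown here, so the normalisation (lat/90, lon/180, date mapped linearly onto [-1, 1]) and the helper name `encode_inputs` are assumptions in the style of SINR-like encoders.

```python
import math


def encode_inputs(lat_deg, lon_deg, day_of_year, days_in_year=365):
    """Map (lat, lon, day-of-year) to 6 features: sin and cos of each input.

    Assumed scaling: every input is first normalised to [-1, 1], then
    passed through sin(pi * x) and cos(pi * x).
    """
    lat = lat_deg / 90.0
    lon = lon_deg / 180.0
    date = 2.0 * (day_of_year - 1) / days_in_year - 1.0
    feats = []
    for x in (lon, lat, date):
        feats.extend((math.sin(math.pi * x), math.cos(math.pi * x)))
    return feats


east = encode_inputs(lat_deg=10.0, lon_deg=180.0, day_of_year=200)
west = encode_inputs(lat_deg=10.0, lon_deg=-180.0, day_of_year=200)
```

The sin/cos pair makes the encoding periodic: `east` and `west` describe the same meridian and produce (numerically) the same features, which a raw longitude value would not.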

## Reproducibility notes

- The frozen category map + the in-tree `geoprior_fagner` (pinned at an upstream
commit) together fix the model's architecture/class contract. Changing either
means retraining.
- The build is deterministic given the same BigQuery snapshot, splits, and map.
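
One lightweight way to notice an accidental change to the class contract is to record a fingerprint of the category map next to each checkpoint. This is not part of the pipeline as described; `categ_map_fingerprint` is a hypothetical helper shown only as a sketch.

```python
import hashlib
import json


def categ_map_fingerprint(categ_map):
    """SHA-256 over a canonical JSON dump (sorted keys, fixed separators)."""
    canonical = json.dumps(
        categ_map, sort_keys=True, separators=(",", ":"), ensure_ascii=False
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


frozen = {"Danaus plexippus": 0, "Papilio machaon": 1}
same_content = {"Papilio machaon": 1, "Danaus plexippus": 0}  # different key order
drifted = {"Danaus plexippus": 0, "Papilio machaon": 2}

fp_frozen = categ_map_fingerprint(frozen)
fp_same = categ_map_fingerprint(same_content)
fp_drifted = categ_map_fingerprint(drifted)
```

Sorting keys before hashing means the fingerprint depends only on the species-to-id assignment, not on how the JSON file happens to be laid out.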
1 change: 1 addition & 0 deletions src/geoprior/__init__.py
@@ -0,0 +1 @@
"""Geo-prior (FCNet) pipeline: BigQuery -> training data -> model."""
181 changes: 181 additions & 0 deletions src/geoprior/build_geoprior_categ_map.py
@@ -0,0 +1,181 @@
#!/usr/bin/env python3
"""
Build / verify the geo-prior category map (species_name -> class_id).

This is the FROZEN class-space contract for the geo-prior FCNet model. The
trained model's output indices are bound to this exact alphabetical ordering,
so the committed ``geoprior_categ_map.json`` must NOT change once a model has
been trained against it. This script regenerates the map from BigQuery and
*verifies* it against the committed artifact; it refuses to silently overwrite
on drift (use --write to (re)materialise the artifacts intentionally).

Source (BigQuery, project ``leps-ai``)
--------------------------------------
  - leps-ai.global_butterflies_2604.gbif_inat_occurrences
        gbifID -> verbatimSpeciesScientificName
  - leps-ai.global_butterflies_2604.gbif_occurrence_location
        gbif_id -> decimallatitude / decimallongitude / eventdate
        (itself derived from the public GBIF mirror, snapshot
        ``public_gbif_2026-05``; see src/dataset_tools/bq_squashfs/README.md)

Class-space definition
----------------------
  Every species (``verbatimSpeciesScientificName``) with >= 1 geocoded
  occurrence (non-null lat, lon, eventdate), sorted alphabetically with
  Python's default ``sorted()`` over the species-name strings, enumerated
  0..N-1. As of the 2026-05 snapshot this is 12,317 classes.

Outputs (only written with --write)
-----------------------------------
  geoprior_categ_map.json            species -> class_id (int)   [FROZEN artifact]
  geoprior_label_map.json            class_id (str) -> species   (HF-style reverse)
  geoprior_metadata.json             species -> {class_id, n_geocoded_occurrences}
  master_species_list.txt            sorted species, one per line
  master_species_with_counts.json    species -> n_geocoded_occurrences

Usage
-----
  # verify the committed frozen map still matches BigQuery (default; no writes)
  python -m src.geoprior.build_geoprior_categ_map

  # (re)generate every artifact into the configured data dir
  python -m src.geoprior.build_geoprior_categ_map --write
"""
import argparse
import json
import sys
from pathlib import Path

from google.cloud import bigquery

from src.geoprior import config

# Counts the geocoded occurrences per species. The WHERE clause mirrors
# build_geoprior_json.py::fetch_all_geocoded so the class space is exactly the
# set of species that survive into the geo-prior JSON pipeline.
SPECIES_COUNT_QUERY = f"""
SELECT
  o.verbatimSpeciesScientificName AS species_name,
  COUNT(*) AS n_geocoded_occurrences
FROM `{config.TBL_OCCURRENCES}` o
JOIN `{config.TBL_LOCATION}` l
  ON l.gbif_id = o.gbifID
WHERE l.decimallatitude IS NOT NULL
  AND l.decimallongitude IS NOT NULL
  AND l.eventdate IS NOT NULL
  AND o.verbatimSpeciesScientificName IS NOT NULL
GROUP BY species_name
"""


def fetch_species_counts():
    """Return {species_name: n_geocoded_occurrences} and bytes billed."""
    client = bigquery.Client(project=config.BQ_PROJECT)
    job = client.query(SPECIES_COUNT_QUERY)
    rows = list(job.result())
    counts = {r["species_name"]: int(r["n_geocoded_occurrences"]) for r in rows}
    return counts, (job.total_bytes_billed or 0)


def build_maps(counts):
    """Derive every artifact from the {species: count} dict (alphabetical)."""
    species = sorted(counts.keys())
    categ_map = {s: i for i, s in enumerate(species)}
    label_map = {str(i): s for s, i in categ_map.items()}
    metadata = {
        s: {"class_id": categ_map[s], "n_geocoded_occurrences": counts[s]}
        for s in species
    }
    with_counts = {s: counts[s] for s in species}
    return species, categ_map, label_map, metadata, with_counts


def _diff_categ_map(frozen, regenerated):
    """Human-readable summary of how a regenerated map differs from frozen."""
    fk, rk = set(frozen), set(regenerated)
    added = sorted(rk - fk)
    removed = sorted(fk - rk)
    reindexed = sorted(s for s in (fk & rk) if frozen[s] != regenerated[s])
    return added, removed, reindexed


def verify_against(frozen_path, regenerated, label):
    """Compare a regenerated dict to an on-disk JSON.

    Returns True if identical, False if they differ, and None if the
    on-disk file does not exist (verification skipped).
    """
    frozen_path = Path(frozen_path)
    if not frozen_path.exists():
        print(f" [{label}] frozen file not found: {frozen_path} (skipping)")
        return None
    frozen = json.loads(frozen_path.read_text())
    if frozen == regenerated:
        print(f" [{label}] VERIFY OK — {len(regenerated):,} entries, identical to {frozen_path}")
        return True
    print(f" [{label}] VERIFY FAILED — differs from {frozen_path}")
    if label == "categ_map":
        # Break the mismatch down so the cause of the drift is visible.
        added, removed, reindexed = _diff_categ_map(frozen, regenerated)
        print(f" added species: {len(added)} e.g. {added[:5]}")
        print(f" removed species: {len(removed)} e.g. {removed[:5]}")
        print(f" reindexed (id changed): {len(reindexed)} e.g. {reindexed[:5]}")
    else:
        print(f" frozen has {len(frozen):,} entries, regenerated has {len(regenerated):,}")
    return False


def write_artifacts(out_dir, species, categ_map, label_map, metadata, with_counts):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "geoprior_categ_map.json").write_text(json.dumps(categ_map))
    (out_dir / "geoprior_label_map.json").write_text(json.dumps(label_map))
    (out_dir / "geoprior_metadata.json").write_text(json.dumps(metadata))
    (out_dir / "master_species_with_counts.json").write_text(json.dumps(with_counts))
    (out_dir / "master_species_list.txt").write_text("\n".join(species) + "\n")
    for name in ("geoprior_categ_map.json", "geoprior_label_map.json",
                 "geoprior_metadata.json", "master_species_with_counts.json",
                 "master_species_list.txt"):
        p = out_dir / name
        print(f" wrote {p} ({p.stat().st_size/1e3:.0f} KB)")


def main():
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--frozen", default=str(config.CATEG_MAP_PATH),
                    help="Path to the committed frozen geoprior_categ_map.json to verify against")
    ap.add_argument("--verify-counts", default=None,
                    help="Optional path to an existing master_species_with_counts.json "
                         "to validate the BQ count query against")
    ap.add_argument("--write", action="store_true",
                    help="Materialise all five artifacts into --out-dir")
    ap.add_argument("--out-dir", default=str(config.DATA_DIR),
                    help="Directory to write artifacts to (with --write)")
    ap.add_argument("--force", action="store_true",
                    help="Write even if verification against --frozen fails")
    args = ap.parse_args()

    print(f"Querying BigQuery (project={config.BQ_PROJECT}) for geocoded species counts ...")
    counts, billed = fetch_species_counts()
    species, categ_map, label_map, metadata, with_counts = build_maps(counts)
    total_occ = sum(counts.values())
    print(f" {len(species):,} species, {total_occ:,} geocoded occurrences, "
          f"scanned {billed/1e6:.1f} MB (~${billed/1e12*5:.4f})")

    print("Verifying ...")
    ok = verify_against(args.frozen, categ_map, "categ_map")
    if args.verify_counts:
        verify_against(args.verify_counts, with_counts, "counts")

    if args.write:
        if ok is False and not args.force:
            print("Refusing to --write: regenerated map differs from the frozen "
                  "artifact. Re-run with --force only if you intend to retire the "
                  "current class space (and retrain the model).")
            sys.exit(1)
        print(f"Writing artifacts to {args.out_dir} ...")
        write_artifacts(args.out_dir, species, categ_map, label_map, metadata, with_counts)

    # Exit non-zero on a real mismatch so CI / callers can catch drift.
    if ok is False and not args.force:
        sys.exit(1)


if __name__ == "__main__":
    main()