diff --git a/.gitignore b/.gitignore index c5b1c9f3..d1f34d2c 100644 --- a/.gitignore +++ b/.gitignore @@ -95,4 +95,8 @@ benchmark_results/ *.dat # CatBoost -catboost_info/ \ No newline at end of file +catboost_info/ + +# Dev artifacts +*.pt +data/* diff --git a/CHANGELOG.md b/CHANGELOG.md index 15e77808..5139bf4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added +- `rectools.fast_transformers` module — standalone transformer-based sequential recommenders that work directly with torch tensors, bypassing the `Dataset`/pandas pipeline. GPU-native sequence building via `build_sequences()` gives ~30x preprocessing speedup over `SASRecDataPreparator` on ML-20M +- `FlatSASRec` network — flat SASRec implementation without the ItemNet hierarchy. Pre-norm transformer encoder with id-embeddings, causal masking, softmax and BCE losses +- `UniSRec` network and `UniSRecModel` — sequential recommender with pretrained text embeddings (e.g. Qwen) and a learnable PCA/BN adaptor. Joint training of adaptor + transformer on pretrained embeddings. Configurable losses (softmax, BCE, gBCE, sampled_softmax), optimizers (Adam, AdamW), cosine warmup scheduler, early stopping, checkpoint save/load. 
`UniSRecModel.fit()` accepts raw `(user_ids, item_ids, timestamps)` tensors +- `align_embeddings()` for mapping pretrained embedding matrices to internal item ID order +- `SequenceBatchDataset` — lightweight torch Dataset wrapper for sequence training data +- Configurable FFN blocks in `UniSRec`: `conv1d` (original paper), `linear_gelu`, `linear_relu` with adjustable expansion factor + + ## [0.18.0] - 21.02.2026 ### Added diff --git a/benchmark/compare_sasrec_unisrec.py b/benchmark/compare_sasrec_unisrec.py new file mode 100644 index 00000000..a14abd04 --- /dev/null +++ b/benchmark/compare_sasrec_unisrec.py @@ -0,0 +1,510 @@ +"""Compare RecTools SASRec vs UniSRec-ID on ML-20M. + +Both use full softmax, Adam, n_factors=256, 10 epochs. +MIN_RATING=-1 (no filter), MIN_ITEM_INTERACTIONS=5, MIN_USER_INTERACTIONS=2. +Writes results to benchmark/comparison_report.md. + +Usage: + python benchmark/compare_sasrec_unisrec.py + +Data is downloaded automatically if not present. +If pretrained embeddings are not found, random embeddings are generated +(sufficient for ID-only comparison). 
+""" + +import gc +import io +import time +import zipfile +from datetime import datetime +from pathlib import Path + +import numpy as np +import pandas as pd +import requests +import torch +from tqdm import tqdm + +from rectools import Columns +from rectools.dataset import Dataset +from rectools.fast_transformers import UniSRecModel +from rectools.fast_transformers.preprocessing import build_sequences +from rectools.models import SASRecModel + +BENCHMARK_DIR = Path(__file__).resolve().parent +DATA_DIR = BENCHMARK_DIR / "data" / "ml-20m" +RATINGS_PATH = DATA_DIR / "ratings.csv" +CACHE_EMB_PATH = DATA_DIR / "qwen_embeddings.pt" +REPORT_PATH = BENCHMARK_DIR / "comparison_report.md" + +ML20M_URL = "https://files.grouplens.org/datasets/movielens/ml-20m.zip" + +MIN_RATING = -1 +MIN_ITEM_INTERACTIONS = 5 +MIN_USER_INTERACTIONS = 2 + +EPOCHS = 10 +PATIENCE = None +BATCH_SIZE = 128 +SESSION_MAX_LEN = 200 +N_FACTORS = 256 +N_BLOCKS = 2 +N_HEADS = 1 +LR = 1e-3 + + +def download_ml20m() -> None: + """Download and extract ML-20M if not present.""" + if RATINGS_PATH.exists(): + return + print(f"Downloading ML-20M from {ML20M_URL} ...") + DATA_DIR.mkdir(parents=True, exist_ok=True) + resp = requests.get(ML20M_URL, stream=True, timeout=600) + resp.raise_for_status() + buf = io.BytesIO() + total = int(resp.headers.get("content-length", 0)) + with tqdm(total=total, unit="B", unit_scale=True, desc="Download") as pbar: + for chunk in resp.iter_content(chunk_size=1 << 20): + buf.write(chunk) + pbar.update(len(chunk)) + print("Extracting...") + with zipfile.ZipFile(buf) as zf: + for member in zf.namelist(): + # ml-20m/ratings.csv -> DATA_DIR/ratings.csv + basename = Path(member).name + if not basename: + continue + target = DATA_DIR / basename + with zf.open(member) as src, open(target, "wb") as dst: + dst.write(src.read()) + print(f"Extracted to {DATA_DIR}") + + +def load_and_preprocess() -> pd.DataFrame: + download_ml20m() + ratings = pd.read_csv(RATINGS_PATH) + ratings.columns = 
["user_id", "item_id", "rating", "timestamp"] + + if MIN_RATING > 0: + ratings = ratings[ratings["rating"] >= MIN_RATING] + + if MIN_ITEM_INTERACTIONS > 0: + item_counts = ratings.groupby("item_id").size() + popular = item_counts[item_counts >= MIN_ITEM_INTERACTIONS].index + ratings = ratings[ratings["item_id"].isin(popular)] + + if MIN_USER_INTERACTIONS > 0: + user_counts = ratings.groupby("user_id").size() + valid = user_counts[user_counts >= MIN_USER_INTERACTIONS].index + ratings = ratings[ratings["user_id"].isin(valid)] + + return ratings + + +def split_eval(ratings: pd.DataFrame): + ratings = ratings.sort_values(["user_id", "timestamp"]) + grouped = ratings.groupby("user_id") + test_idx = grouped.tail(1).index + remaining = ratings.drop(test_idx) + val_idx = remaining.groupby("user_id").tail(1).index + train_idx = remaining.drop(val_idx).index + return ratings.loc[train_idx], ratings.loc[val_idx], ratings.loc[test_idx] + + +def to_tensors(df: pd.DataFrame): + return ( + torch.tensor(df["user_id"].values, dtype=torch.long), + torch.tensor(df["item_id"].values, dtype=torch.long), + torch.tensor(df["timestamp"].values, dtype=torch.long), + ) + + +def get_pretrained_embeddings(item_ids: pd.Series, dim: int = 1024) -> torch.Tensor: + """Load cached embeddings or generate random ones for ID-only comparison.""" + if CACHE_EMB_PATH.exists(): + print(f"Loading pretrained embeddings from {CACHE_EMB_PATH}") + return torch.load(CACHE_EMB_PATH, weights_only=True) + + max_id = int(item_ids.max()) + print(f"No pretrained embeddings found at {CACHE_EMB_PATH}") + print(f"Generating random embeddings ({max_id + 1}, {dim}) for ID-only comparison") + torch.manual_seed(42) + emb = torch.randn(max_id + 1, dim) + emb[0] = 0.0 + return emb + + +@torch.no_grad() +def evaluate_unisrec(model, train_df, test_df, k=10, batch_size=256): + net = model.net + net.cuda().eval() + device = torch.device("cuda") + maxlen = model.session_max_len + + item_embs = net.project_all() + unique_items = 
model.item_id_mapping + ext_to_int = {int(unique_items[i].item()): i + 1 for i in range(len(unique_items))} + + train_grouped = train_df.sort_values("timestamp").groupby("user_id")["item_id"].agg(list).to_dict() + test_grouped = test_df.groupby("user_id")["item_id"].first().to_dict() + test_users = list(test_grouped.keys()) + + hits, ndcg_sum, mrr_sum, total = 0, 0.0, 0.0, 0 + for start in tqdm(range(0, len(test_users), batch_size), desc="Eval UniSRec"): + batch_users = test_users[start : start + batch_size] + seqs, targets = [], [] + for uid in batch_users: + history = train_grouped.get(uid, []) + mapped = [ext_to_int[iid] for iid in history if iid in ext_to_int] + if not mapped: + continue + seq = mapped[-maxlen:] + seqs.append([0] * (maxlen - len(seq)) + seq) + targets.append(ext_to_int.get(test_grouped[uid])) + if not seqs: + continue + x = torch.tensor(seqs, dtype=torch.long, device=device) + h = net.encode_last(x) + scores = h @ item_embs.T + scores[:, 0] = float("-inf") + for i, target_int in enumerate(targets): + if target_int is None: + continue + _, topk_idx = scores[i].topk(k) + topk = topk_idx.cpu().tolist() + if target_int in topk: + rank = topk.index(target_int) + hits += 1 + ndcg_sum += 1.0 / np.log2(rank + 2) + mrr_sum += 1.0 / (rank + 1) + total += 1 + return {"HR@10": hits / total, "NDCG@10": ndcg_sum / total, "MRR@10": mrr_sum / total, "n_users": total} + + +def evaluate_sasrec(model, dataset_for_recommend, test_df, k=10): + test_users = test_df["user_id"].unique() + reco = model.recommend(users=test_users, dataset=dataset_for_recommend, k=k, filter_viewed=False) + + test_targets = test_df.groupby("user_id")["item_id"].first().to_dict() + hits, ndcg_sum, mrr_sum, total = 0, 0.0, 0.0, 0 + for uid, group in reco.groupby(Columns.User): + target = test_targets.get(uid) + if target is None: + continue + items = group[Columns.Item].tolist() + if target in items: + rank = items.index(target) + hits += 1 + ndcg_sum += 1.0 / np.log2(rank + 2) + mrr_sum += 
1.0 / (rank + 1) + total += 1 + return {"HR@10": hits / total, "NDCG@10": ndcg_sum / total, "MRR@10": mrr_sum / total, "n_users": total} + + +def cleanup(): + gc.collect() + torch.cuda.empty_cache() + + +def write_report(timings: dict, metrics: dict, data_info: dict) -> str: + gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A" + date_str = datetime.now().strftime("%Y-%m-%d %H:%M") + dataset_str = ( + f"ML-20M (min_rating={MIN_RATING}," f" min_item={MIN_ITEM_INTERACTIONS}," f" min_user={MIN_USER_INTERACTIONS})" + ) + lines = [ + "# SASRec vs UniSRec-ID Comparison", + "", + f"**Date:** {date_str} ", + f"**GPU:** {gpu_name} ", + f"**Dataset:** {dataset_str}", + "", + "## Data", + "", + "| | Count |", + "|---|---:|", + f"| Interactions | {data_info['n_interactions']:,} |", + f"| Users | {data_info['n_users']:,} |", + f"| Items | {data_info['n_items']:,} |", + f"| Train | {data_info['n_train']:,} |", + f"| Val | {data_info['n_val']:,} |", + f"| Test | {data_info['n_test']:,} |", + "", + "## Config", + "", + "| Parameter | Value |", + "|---|---|", + f"| n_factors | {N_FACTORS} |", + f"| n_blocks | {N_BLOCKS} |", + f"| n_heads | {N_HEADS} |", + f"| session_max_len | {SESSION_MAX_LEN} |", + f"| batch_size | {BATCH_SIZE} |", + f"| lr | {LR} |", + "| loss | softmax |", + "| optimizer | Adam |", + f"| epochs | {EPOCHS} |", + f"| patience | {PATIENCE} |", + "| dropout | 0.1 |", + "", + "## Timing", + "", + "| Stage | SASRec | UniSRec ID |", + "|---|---:|---:|", + ] + + for stage in ["data_load", "preprocessing", "model_init", "training", "eval"]: + s = timings.get(f"sasrec_{stage}", 0) + u = timings.get(f"unisrec_{stage}", 0) + label = { + "data_load": "Data load & split", + "preprocessing": "Preprocessing", + "model_init": "Model init", + "training": f"Training ({EPOCHS} epochs)", + "eval": "Evaluation", + }[stage] + lines.append(f"| {label} | {s:.1f}s | {u:.1f}s |") + + s_total = sum(timings.get(f"sasrec_{s}", 0) for s in ["preprocessing", 
"model_init", "training", "eval"]) + u_total = sum(timings.get(f"unisrec_{s}", 0) for s in ["preprocessing", "model_init", "training", "eval"]) + lines.append(f"| **Total** | **{s_total:.1f}s** | **{u_total:.1f}s** |") + + s_epoch = timings.get("sasrec_training", 0) / max(timings.get("sasrec_epochs_done", 1), 1) + u_epoch = timings.get("unisrec_training", 0) / max(timings.get("unisrec_epochs_done", 1), 1) + s_epochs_done = timings.get("sasrec_epochs_done", EPOCHS) + u_epochs_done = timings.get("unisrec_epochs_done", EPOCHS) + prep_speedup = timings.get("prep_speedup", 0) + lines.extend( + [ + "", + "| | SASRec | UniSRec ID |", + "|---|---:|---:|", + f"| Epochs completed | {s_epochs_done} | {u_epochs_done} |", + f"| Time per epoch | {s_epoch:.1f}s | {u_epoch:.1f}s |", + f"| Preprocessing speedup | — | {prep_speedup:.0f}x |", + ] + ) + + n_test_users = metrics["sasrec"]["n_users"] + lines.extend( + [ + "", + f"## Quality (test set, {n_test_users:,} users)", + "", + "| Model | HR@10 | NDCG@10 | MRR@10 |", + "|---|---:|---:|---:|", + ] + ) + for name, key in [("SASRec", "sasrec"), ("UniSRec ID", "unisrec")]: + m = metrics[key] + lines.append(f"| {name} | {m['HR@10']:.4f} | {m['NDCG@10']:.4f} | {m['MRR@10']:.4f} |") + + hr_diff = (metrics["unisrec"]["HR@10"] / metrics["sasrec"]["HR@10"] - 1) * 100 + ndcg_diff = (metrics["unisrec"]["NDCG@10"] / metrics["sasrec"]["NDCG@10"] - 1) * 100 + lines.extend( + [ + "", + f"UniSRec vs SASRec: HR@10 {hr_diff:+.1f}%, NDCG@10 {ndcg_diff:+.1f}%", + ] + ) + + report = "\n".join(lines) + "\n" + REPORT_PATH.write_text(report) + print(f"\nReport written to {REPORT_PATH}") + return report + + +def main(): + if not torch.cuda.is_available(): + raise RuntimeError("This benchmark requires CUDA. 
No GPU detected.") + torch.set_float32_matmul_precision("high") + timings = {} + + print(f"SASRec vs UniSRec-ID | {EPOCHS} epochs | n_factors={N_FACTORS} | Adam | softmax") + print("=" * 70) + + # ── Data ── + t0 = time.time() + ratings = load_and_preprocess() + train_ratings, val_ratings, test_ratings = split_eval(ratings) + train_with_val = pd.concat([train_ratings, val_ratings]) + timings["data_load"] = time.time() - t0 + + data_info = { + "n_interactions": len(ratings), + "n_users": ratings["user_id"].nunique(), + "n_items": ratings["item_id"].nunique(), + "n_train": len(train_ratings), + "n_val": len(val_ratings), + "n_test": len(test_ratings), + } + n_int = data_info["n_interactions"] + n_usr = data_info["n_users"] + n_itm = data_info["n_items"] + print(f"Data: {n_int:,} interactions, {n_usr:,} users, {n_itm:,} items") + print(f"Split: train={data_info['n_train']:,}, val={data_info['n_val']:,}, test={data_info['n_test']:,}") + + user_ids_t, item_ids_t, timestamps_t = to_tensors(train_with_val) + pretrained = get_pretrained_embeddings(ratings["item_id"]) + + # ══════════════════════════════════════════════════════════════ + # 1. SASRec (RecTools) + # ══════════════════════════════════════════════════════════════ + print(f"\n{'=' * 70}") + print(f"1. 
SASRec (RecTools) — {EPOCHS} epochs") + print(f"{'=' * 70}") + + # Preprocessing + t0 = time.time() + df_rectools = pd.DataFrame( + { + Columns.User: train_with_val["user_id"].values, + Columns.Item: train_with_val["item_id"].values, + Columns.Weight: 1.0, + Columns.Datetime: pd.to_datetime(train_with_val["timestamp"], unit="s"), + } + ) + dataset = Dataset.construct(df_rectools) + timings["sasrec_preprocessing"] = time.time() - t0 + print(f" Preprocessing (Dataset.construct): {timings['sasrec_preprocessing']:.2f}s") + + # Model init + training + def sasrec_trainer(**kwargs): + import pytorch_lightning as pl + + callbacks = [] + if PATIENCE is not None: + from pytorch_lightning.callbacks import EarlyStopping + + callbacks.append(EarlyStopping(monitor="val_loss", patience=PATIENCE, mode="min")) + return pl.Trainer( + max_epochs=EPOCHS, + min_epochs=1, + callbacks=callbacks or None, + enable_checkpointing=False, + enable_model_summary=False, + logger=True, + enable_progress_bar=True, + devices=1, + ) + + sasrec_kwargs = dict( + n_factors=N_FACTORS, + n_blocks=N_BLOCKS, + n_heads=N_HEADS, + session_max_len=SESSION_MAX_LEN, + dropout_rate=0.1, + loss="softmax", + lr=LR, + batch_size=BATCH_SIZE, + epochs=EPOCHS, + train_min_user_interactions=MIN_USER_INTERACTIONS, + dataloader_num_workers=0, + verbose=1, + get_trainer_func=sasrec_trainer, + ) + if PATIENCE is not None: + + def sasrec_val_mask(interactions_df, **kwargs): + idx = interactions_df.groupby(Columns.User).tail(1).index + mask = pd.Series(False, index=interactions_df.index) + mask.loc[idx] = True + return mask + + sasrec_kwargs["get_val_mask_func"] = sasrec_val_mask + + t0 = time.time() + sasrec = SASRecModel(**sasrec_kwargs) + timings["sasrec_model_init"] = time.time() - t0 + + t0 = time.time() + sasrec.fit(dataset) + timings["sasrec_training"] = time.time() - t0 + timings["sasrec_epochs_done"] = sasrec.fit_trainer.current_epoch + 1 + print(f" Training: {timings['sasrec_training']:.1f}s, 
{timings['sasrec_epochs_done']} epochs") + + # Eval + print(" Evaluating...") + t0 = time.time() + sasrec_metrics = evaluate_sasrec(sasrec, dataset, test_ratings) + timings["sasrec_eval"] = time.time() - t0 + print(f" Eval: {timings['sasrec_eval']:.1f}s") + hr = sasrec_metrics["HR@10"] + ndcg = sasrec_metrics["NDCG@10"] + mrr = sasrec_metrics["MRR@10"] + print(f" HR@10={hr:.4f} NDCG@10={ndcg:.4f} MRR@10={mrr:.4f}") + del sasrec + cleanup() + + # ══════════════════════════════════════════════════════════════ + # 2. UniSRec ID + # ══════════════════════════════════════════════════════════════ + print(f"\n{'=' * 70}") + print(f"2. UniSRec ID — {EPOCHS} epochs") + print(f"{'=' * 70}") + + # Preprocessing + torch.cuda.synchronize() + t0 = time.time() + _ = build_sequences(user_ids_t, item_ids_t, timestamps_t, max_len=SESSION_MAX_LEN, device="cuda") + torch.cuda.synchronize() + timings["unisrec_preprocessing"] = time.time() - t0 + print(f" Preprocessing (build_sequences): {timings['unisrec_preprocessing']:.4f}s") + timings["prep_speedup"] = timings["sasrec_preprocessing"] / max(timings["unisrec_preprocessing"], 1e-6) + print(f" Speedup vs Dataset.construct: {timings['prep_speedup']:.0f}x") + + # Model init + t0 = time.time() + unisrec_id = UniSRecModel( + pretrained_item_embeddings=pretrained, + n_factors=N_FACTORS, + projection_hidden=N_FACTORS, + n_blocks=N_BLOCKS, + n_heads=N_HEADS, + session_max_len=SESSION_MAX_LEN, + dropout=0.1, + adaptor_dropout=0.2, + adaptor_type="pca", + use_adaptor_ffn=True, + epochs=EPOCHS, + lr=LR, + optimizer="adam", + grad_clip=1.0, + weight_decay=0.0, + loss="softmax", + patience=PATIENCE, + batch_size=BATCH_SIZE, + dataloader_num_workers=0, + train_min_user_interactions=MIN_USER_INTERACTIONS, + device="cuda", + verbose=1, + ) + timings["unisrec_model_init"] = time.time() - t0 + + # Training + t0 = time.time() + unisrec_id.fit(user_ids_t, item_ids_t, timestamps_t) + timings["unisrec_training"] = time.time() - t0 + 
timings["unisrec_epochs_done"] = EPOCHS + print(f" Training (total fit): {timings['unisrec_training']:.1f}s") + + # Eval + print(" Evaluating...") + t0 = time.time() + unisrec_metrics = evaluate_unisrec(unisrec_id, train_with_val, test_ratings) + timings["unisrec_eval"] = time.time() - t0 + print(f" Eval: {timings['unisrec_eval']:.1f}s") + hr = unisrec_metrics["HR@10"] + ndcg = unisrec_metrics["NDCG@10"] + mrr = unisrec_metrics["MRR@10"] + print(f" HR@10={hr:.4f} NDCG@10={ndcg:.4f} MRR@10={mrr:.4f}") + del unisrec_id + cleanup() + + # ── Report ── + metrics = {"sasrec": sasrec_metrics, "unisrec": unisrec_metrics} + report = write_report(timings, metrics, data_info) + print("\n" + report) + + +if __name__ == "__main__": + main() diff --git a/benchmark/comparison_report.md b/benchmark/comparison_report.md new file mode 100644 index 00000000..fd136387 --- /dev/null +++ b/benchmark/comparison_report.md @@ -0,0 +1,58 @@ +# SASRec vs UniSRec-ID Comparison + +**Date:** 2026-04-24 19:59 +**GPU:** NVIDIA GeForce RTX 4090 +**Dataset:** ML-20M (min_rating=-1, min_item=5, min_user=2) + +## Data + +| | Count | +|---|---:| +| Interactions | 19,984,024 | +| Users | 138,493 | +| Items | 18,345 | +| Train | 19,707,038 | +| Val | 138,493 | +| Test | 138,493 | + +## Config + +| Parameter | Value | +|---|---| +| n_factors | 256 | +| n_blocks | 2 | +| n_heads | 1 | +| session_max_len | 200 | +| batch_size | 128 | +| lr | 0.001 | +| loss | softmax | +| optimizer | Adam | +| epochs | 10 | +| patience | None | +| dropout | 0.1 | + +## Timing + +| Stage | SASRec | UniSRec ID | +|---|---:|---:| +| Data load & split | 0.0s | 0.0s | +| Preprocessing | 14.6s | 0.5s | +| Model init | 0.0s | 0.0s | +| Training (10 epochs) | 911.8s | 639.5s | +| Evaluation | 175.6s | 28.0s | +| **Total** | **1102.1s** | **668.0s** | + +| | SASRec | UniSRec ID | +|---|---:|---:| +| Epochs completed | 11 | 10 | +| Time per epoch | 82.9s | 63.9s | +| Preprocessing speedup | — | 29x | + +## Quality (test set, 138,493 
users) + +| Model | HR@10 | NDCG@10 | MRR@10 | +|---|---:|---:|---:| +| SASRec | 0.2417 | 0.1410 | 0.1103 | +| UniSRec ID | 0.2528 | 0.1495 | 0.1179 | + +UniSRec vs SASRec: HR@10 +4.6%, NDCG@10 +6.0% diff --git a/rectools/fast_transformers/__init__.py b/rectools/fast_transformers/__init__.py new file mode 100644 index 00000000..1803f728 --- /dev/null +++ b/rectools/fast_transformers/__init__.py @@ -0,0 +1,27 @@ +"""Fast Transformers: flat sequential recommenders without ItemNet hierarchy.""" + +from .metrics import compute_metrics, hitrate_at_k, mrr_at_k, ndcg_at_k +from .net import FlatSASRec, SASRecBlock +from .preprocessing import ( + SequenceBatchDataset, + align_embeddings, + build_sequences, +) +from .unisrec import UniSRec, UniSRecLightning, UniSRecModel +from .unisrec.net import FeedForward + +__all__ = [ + "build_sequences", + "align_embeddings", + "SequenceBatchDataset", + "FlatSASRec", + "SASRecBlock", + "UniSRec", + "FeedForward", + "UniSRecLightning", + "UniSRecModel", + "hitrate_at_k", + "ndcg_at_k", + "mrr_at_k", + "compute_metrics", +] diff --git a/rectools/fast_transformers/metrics.py b/rectools/fast_transformers/metrics.py new file mode 100644 index 00000000..80bcbc06 --- /dev/null +++ b/rectools/fast_transformers/metrics.py @@ -0,0 +1,150 @@ +"""GPU-friendly ranking metrics for leave-one-out evaluation. + +All functions operate on PyTorch tensors and stay on the original device +(CPU or CUDA), avoiding numpy/pandas roundtrips. Results are numerically +identical to the corresponding RecTools metrics with default settings: + +- :class:`rectools.metrics.HitRate` (k=K) +- :class:`rectools.metrics.NDCG` (k=K, log_base=2, divide_by_achievable=False) +- :class:`rectools.metrics.MRR` (k=K) + +These functions assume **leave-one-out** evaluation: each user has exactly +one ground-truth target item. 
+""" + +import typing as tp + +import torch + + +@torch.no_grad() +def hitrate_at_k( + topk_ids: torch.Tensor, + targets: torch.Tensor, +) -> torch.Tensor: + """Hit Rate @ K (leave-one-out). + + Parameters + ---------- + topk_ids : LongTensor (B, K) + Top-K predicted item IDs per user. + targets : LongTensor (B,) + Ground-truth item ID per user. + + Returns + ------- + Tensor (scalar) + Mean hit rate across users. + """ + hits = (topk_ids == targets.unsqueeze(1)).any(dim=1) + return hits.float().mean() + + +@torch.no_grad() +def ndcg_at_k( + topk_ids: torch.Tensor, + targets: torch.Tensor, + log_base: int = 2, +) -> torch.Tensor: + """NDCG @ K (leave-one-out, divide_by_achievable=False). + + Matches :class:`rectools.metrics.NDCG` with default parameters. + IDCG is computed as the maximum possible DCG when all K positions are + relevant (constant across users), which is the RecTools default. + + Parameters + ---------- + topk_ids : LongTensor (B, K) + Top-K predicted item IDs per user. + targets : LongTensor (B,) + Ground-truth item ID per user. + log_base : int, default 2 + Logarithm base for the discount factor. + + Returns + ------- + Tensor (scalar) + Mean NDCG across users. + """ + k = topk_ids.shape[1] + hits = (topk_ids == targets.unsqueeze(1)).float() # (B, K) + ranks = torch.arange(1, k + 1, device=topk_ids.device, dtype=torch.float) + discounts = 1.0 / torch.log(ranks + 1) * (1.0 / _log(log_base)) + dcg = (hits * discounts.unsqueeze(0)).sum(dim=1) # (B,) + idcg = discounts.sum() + return (dcg / idcg).mean() + + +@torch.no_grad() +def mrr_at_k( + topk_ids: torch.Tensor, + targets: torch.Tensor, +) -> torch.Tensor: + """MRR @ K (leave-one-out). + + Parameters + ---------- + topk_ids : LongTensor (B, K) + Top-K predicted item IDs per user. + targets : LongTensor (B,) + Ground-truth item ID per user. + + Returns + ------- + Tensor (scalar) + Mean reciprocal rank across users. 
+ """ + hits = topk_ids == targets.unsqueeze(1) # (B, K) + # For each user find the rank of the first hit (1-based), 0 if no hit + has_hit = hits.any(dim=1) + # argmax returns the first True index + first_hit_rank = hits.float().argmax(dim=1) + 1 # (B,) + rr = torch.zeros_like(first_hit_rank, dtype=torch.float) + rr[has_hit] = 1.0 / first_hit_rank[has_hit].float() + return rr.mean() + + +@torch.no_grad() +def compute_metrics( + topk_ids: torch.Tensor, + targets: torch.Tensor, + ks: tp.Optional[tp.List[int]] = None, + log_base: int = 2, +) -> tp.Dict[str, float]: + """Compute HR, NDCG, MRR at multiple K values. + + Parameters + ---------- + topk_ids : LongTensor (B, K_max) + Top-K_max predicted item IDs per user. + targets : LongTensor (B,) + Ground-truth item ID per user. + ks : list of int, optional + K values to evaluate. Defaults to ``[K_max]``. + log_base : int, default 2 + Logarithm base for NDCG discount. + + Returns + ------- + dict + Keys like ``"HR@10"``, ``"NDCG@10"``, ``"MRR@10"``. 
+ """ + k_max = topk_ids.shape[1] + if ks is None: + ks = [k_max] + results: tp.Dict[str, float] = {} + for k in ks: + if k > k_max: + raise ValueError(f"k={k} exceeds topk_ids width {k_max}") + top = topk_ids[:, :k] + results[f"HR@{k}"] = hitrate_at_k(top, targets).item() + results[f"NDCG@{k}"] = ndcg_at_k(top, targets, log_base=log_base).item() + results[f"MRR@{k}"] = mrr_at_k(top, targets).item() + return results + + +def _log(base: int) -> float: + """Natural log of base (cached constant).""" + import math + + return math.log(base) diff --git a/rectools/fast_transformers/net.py b/rectools/fast_transformers/net.py new file mode 100644 index 00000000..f9e06b00 --- /dev/null +++ b/rectools/fast_transformers/net.py @@ -0,0 +1,175 @@ +"""Flat SASRec network: pre-norm transformer encoder with plain id embeddings.""" + +import typing as tp + +import torch +from torch import nn + + +class SASRecBlock(nn.Module): + """Pre-norm transformer block: LayerNorm -> MHA -> residual -> LayerNorm -> FFN -> residual.""" + + def __init__(self, n_factors: int, n_heads: int, dropout: float = 0.1) -> None: + super().__init__() + self.ln1 = nn.LayerNorm(n_factors) + self.mha = nn.MultiheadAttention(n_factors, n_heads, dropout=dropout, batch_first=True) + self.ln2 = nn.LayerNorm(n_factors) + self.ffn = nn.Sequential( + nn.Linear(n_factors, n_factors * 4), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(n_factors * 4, n_factors), + nn.Dropout(dropout), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: tp.Optional[torch.Tensor] = None, + key_padding_mask: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + h = self.ln1(x) + h, _ = self.mha(h, h, h, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) + x = x + h + h = self.ln2(x) + x = x + self.ffn(h) + return x + + +class FlatSASRec(nn.Module): + """ + Flat SASRec: sequential recommender with plain id-embedding table + (no ItemNet hierarchy). 
class SASRecBlock(nn.Module):
    """One pre-norm transformer block.

    Order of operations: LayerNorm -> multi-head self-attention -> residual,
    then LayerNorm -> position-wise FFN -> residual.
    """

    def __init__(self, n_factors: int, n_heads: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(n_factors)
        self.mha = nn.MultiheadAttention(n_factors, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(n_factors)
        # Standard 4x expansion FFN with GELU and dropout after each linear.
        self.ffn = nn.Sequential(
            nn.Linear(n_factors, n_factors * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(n_factors * 4, n_factors),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor] = None,
        key_padding_mask: tp.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention and FFN sub-layers with residual connections."""
        normed = self.ln1(x)
        attn_out, _ = self.mha(
            normed, normed, normed, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False
        )
        x = x + attn_out
        return x + self.ffn(self.ln2(x))


class FlatSASRec(nn.Module):
    """
    Flat SASRec: sequential recommender backed by a single id-embedding
    table (no ItemNet hierarchy).

    Parameters
    ----------
    n_items : int
        Total number of items (excluding padding token 0).
    n_factors : int
        Embedding / hidden dimension.
    n_blocks : int
        Number of transformer blocks.
    n_heads : int
        Number of attention heads.
    session_max_len : int
        Maximum sequence length.
    dropout : float
        Dropout rate.
    """

    PADDING_IDX = 0

    def __init__(
        self,
        n_items: int,
        n_factors: int,
        n_blocks: int,
        n_heads: int,
        session_max_len: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.n_items = n_items
        self.n_factors = n_factors
        self.session_max_len = session_max_len

        # Table has one extra row: index 0 is the (frozen-at-zero) padding slot.
        self.item_emb = nn.Embedding(n_items + 1, n_factors, padding_idx=self.PADDING_IDX)
        self.pos_emb = nn.Embedding(session_max_len, n_factors)
        self.emb_dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([SASRecBlock(n_factors, n_heads, dropout) for _ in range(n_blocks)])
        self.final_ln = nn.LayerNorm(n_factors)

    def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Boolean (L, L) mask that is True strictly above the diagonal."""
        ones = torch.ones(seq_len, seq_len, device=device, dtype=torch.bool)
        return torch.triu(ones, diagonal=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the full sequence.

        Parameters
        ----------
        x : LongTensor (B, L)
            Item id sequences (0 = padding).

        Returns
        -------
        Tensor (B, L, D)
        """
        _, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        state = self.emb_dropout(self.item_emb(x) + self.pos_emb(positions))

        not_pad = x != self.PADDING_IDX
        keep = not_pad.unsqueeze(-1).float()  # (B, L, 1), zeroes out padded slots
        causal = self._causal_mask(seq_len, x.device)
        pad_mask = ~not_pad
        # NOTE(review): a fully-masked (all-padding) query row can still yield
        # NaN from softmax, and 0 * NaN == NaN — verify the loss masks padding.
        for blk in self.blocks:
            state = blk(state * keep, attn_mask=causal, key_padding_mask=pad_mask)
        return self.final_ln(state * keep)

    def encode_last(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode and return only the representation of the final position.

        Sequences are left-padded, so the rightmost slot is always the most
        recent non-padding item.

        Parameters
        ----------
        x : LongTensor (B, L)

        Returns
        -------
        Tensor (B, D)
        """
        return self.encode(x)[:, -1, :]

    def all_item_embeddings(self) -> torch.Tensor:
        """
        Embeddings for all real items (1..n_items), excluding padding.

        Returns
        -------
        Tensor (n_items, D)
        """
        device = self.item_emb.weight.device
        return self.item_emb(torch.arange(1, self.n_items + 1, device=device))

    def forward(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Training forward pass.

        Parameters
        ----------
        batch : dict
            Must contain 'x' (B, L) and 'y' (B, L).
            Optionally 'negatives' (B, L, N) for candidate-logits branch.

        Returns
        -------
        logits : Tensor
            If negatives present: (B, L, 1 + N) — positive + negative logits.
            Otherwise: (B, L, n_items) — full catalog logits.
        """
        states = self.encode(batch["x"])  # (B, L, D)

        if "negatives" not in batch:
            # Score against the entire catalog.
            return states @ self.all_item_embeddings().T  # (B, L, n_items)

        pos = self.item_emb(batch["y"]).unsqueeze(3)  # (B, L, D, 1)
        neg = self.item_emb(batch["negatives"]).transpose(2, 3)  # (B, L, D, N)
        candidates = torch.cat([pos, neg], dim=3)  # (B, L, D, 1+N)
        # Column 0 of the result is the positive logit, the rest are negatives.
        return (states.unsqueeze(2) @ candidates).squeeze(2)  # (B, L, 1+N)
+""" + +import typing as tp + +import torch +from torch.utils.data import Dataset as TorchDataset + + +def build_sequences( + user_ids: torch.Tensor, + item_ids: torch.Tensor, + timestamps: torch.Tensor, + max_len: int, + min_interactions: int = 2, + device: tp.Optional[str] = None, +) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Build left-padded input/target sequence pairs from interaction data. + + Groups interactions by user, sorts by timestamp, and produces + ``(x, y)`` pairs where ``y[i, j] = x[i, j+1]`` (next-item prediction). + Item IDs are remapped to contiguous internal indices ``1..N`` + (0 is reserved for padding). + + Parameters + ---------- + user_ids : LongTensor (N,) + User ID for each interaction. + item_ids : LongTensor (N,) + Item ID for each interaction. + timestamps : LongTensor (N,) + Timestamp for each interaction (any monotonic int64 values). + max_len : int + Maximum sequence length. + min_interactions : int, default 2 + Minimum interactions per user to be included. + device : str, optional + Device for computation. Defaults to the device of ``user_ids`` + (pass ``"cuda"`` explicitly for GPU acceleration). + + Returns + ------- + x : LongTensor (U, max_len) + Left-padded input sequences (0 = padding). + y : LongTensor (U, max_len) + Left-padded target sequences. + unique_items : LongTensor + External item IDs that appear in the data. + result_users : LongTensor + External user IDs that passed the ``min_interactions`` filter. 
+ + Examples + -------- + >>> users = torch.tensor([0, 0, 0, 1, 1, 1]) + >>> items = torch.tensor([10, 20, 30, 40, 50, 60]) + >>> times = torch.tensor([1, 2, 3, 1, 2, 3]) + >>> x, y, uniq_items, uniq_users = build_sequences(users, items, times, max_len=4) + >>> x.shape[1] + 4 + """ + if device is None: + device = str(user_ids.device) + user_ids = user_ids.to(device) + item_ids = item_ids.to(device) + timestamps = timestamps.to(device) + + unique_items, item_inv = torch.unique(item_ids, return_inverse=True) + internal_items = item_inv + 1 + + unique_users, user_inv = torch.unique(user_ids, return_inverse=True) + + order1 = torch.argsort(timestamps, stable=True) + order2 = torch.argsort(user_inv[order1], stable=True) + order = order1[order2] + + sorted_user_inv = user_inv[order] + sorted_items = internal_items[order] + + changes = torch.where(sorted_user_inv[1:] != sorted_user_inv[:-1])[0] + 1 + starts = torch.cat([torch.tensor([0], device=device), changes]) + ends = torch.cat([changes, torch.tensor([len(sorted_user_inv)], device=device)]) + lengths = ends - starts + + mask = lengths >= min_interactions + starts = starts[mask] + ends = ends[mask] + lengths = lengths[mask] + n_users = len(starts) + + capped_lens = torch.clamp(lengths, max=max_len + 1) + + effective_lens = torch.clamp(capped_lens - 1, min=0) + total_elements = effective_lens.sum().item() + + x = torch.zeros(n_users, max_len, dtype=torch.long, device=device) + y = torch.zeros(n_users, max_len, dtype=torch.long, device=device) + + if total_elements > 0: + user_indices = torch.repeat_interleave(torch.arange(n_users, device=device), effective_lens) + cumsum = effective_lens.cumsum(0) + offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave( + cumsum - effective_lens, effective_lens + ) + + x_src = torch.repeat_interleave(ends - capped_lens, effective_lens) + offsets + y_src = x_src + 1 + col_indices = max_len - torch.repeat_interleave(effective_lens, effective_lens) + offsets + + 
x[user_indices, col_indices] = sorted_items[x_src] + y[user_indices, col_indices] = sorted_items[y_src] + + valid_user_indices = torch.where(mask)[0] + result_users = unique_users[valid_user_indices] if len(valid_user_indices) < len(unique_users) else unique_users + + return x, y, unique_items, result_users + + +def align_embeddings( + pretrained: torch.Tensor, + unique_items: torch.Tensor, + n_items: int, +) -> torch.Tensor: + """Reorder a pretrained embedding matrix to match internal item ID order. + + Internal IDs are contiguous ``1..n_items`` as produced by + :func:`build_sequences`. Index 0 is padding (zeros). + + Parameters + ---------- + pretrained : Tensor (V, D) or (V, K, D) + Pretrained embeddings indexed by external item ID. + unique_items : LongTensor + External item IDs returned by :func:`build_sequences`. + n_items : int + Number of unique items. + + Returns + ------- + Tensor (n_items + 1, D) or (n_items + 1, K, D) + Aligned embeddings with padding row at index 0. + """ + device = pretrained.device + idx = unique_items.long().to(device) + valid = (idx >= 0) & (idx < pretrained.shape[0]) + + if pretrained.ndim == 2: + aligned = torch.zeros(n_items + 1, pretrained.shape[1], device=device) + else: + aligned = torch.zeros(n_items + 1, pretrained.shape[1], pretrained.shape[2], device=device) + + aligned[1:][valid] = pretrained[idx[valid]] + return aligned + + +class SequenceBatchDataset(TorchDataset): + """Lightweight Dataset wrapping prebuilt (x, y) sequence tensors.""" + + def __init__(self, x: torch.Tensor, y: torch.Tensor, transform: tp.Optional[tp.Callable] = None): + self.x = x + self.y = y + self.transform = transform + + def __len__(self) -> int: + return len(self.x) + + def __getitem__(self, idx: int) -> tp.Dict[str, torch.Tensor]: + batch = {"x": self.x[idx], "y": self.y[idx]} + if self.transform: + batch = self.transform(batch) + return batch diff --git a/rectools/fast_transformers/unisrec/__init__.py 
b/rectools/fast_transformers/unisrec/__init__.py new file mode 100644 index 00000000..dac2611d --- /dev/null +++ b/rectools/fast_transformers/unisrec/__init__.py @@ -0,0 +1,12 @@ +"""UniSRec: sequential recommender with pretrained text embeddings.""" + +from .lightning import UniSRecLightning +from .model import UniSRecModel +from .net import FeedForward, UniSRec + +__all__ = [ + "UniSRec", + "FeedForward", + "UniSRecLightning", + "UniSRecModel", +] diff --git a/rectools/fast_transformers/unisrec/demo_kion.md b/rectools/fast_transformers/unisrec/demo_kion.md new file mode 100644 index 00000000..9d715124 --- /dev/null +++ b/rectools/fast_transformers/unisrec/demo_kion.md @@ -0,0 +1,262 @@ +# UniSRec Training Demo: KION Dataset + +This guide demonstrates training a UniSRec sequential recommender on the KION movie dataset using real text embeddings from movie descriptions. + +## Overview + +UniSRec jointly trains a PCA-based adaptor and a SASRec transformer encoder on frozen pretrained text embeddings. This allows the model to leverage semantic item representations without requiring collaborative item IDs. + +## Prerequisites + +```bash +pip install torch pytorch-lightning sentence-transformers +``` + +## 1. 
Prepare Data + +### Download the KION dataset + +```bash +git clone https://github.com/irsafilo/KION_DATASET kion_data +``` + +### Load and filter interactions + +```python +import pandas as pd +import torch + +# Load interactions +interactions = pd.read_csv("kion_data/interactions.csv") +interactions = interactions.rename(columns={"last_watch_dt": "timestamp"}) +interactions["timestamp"] = pd.to_datetime(interactions["timestamp"]).astype(int) // 10**9 + +# Filter: min 5 interactions per item, min 2 per user +item_counts = interactions.groupby("item_id").size() +interactions = interactions[interactions["item_id"].isin(item_counts[item_counts >= 5].index)] +user_counts = interactions.groupby("user_id").size() +interactions = interactions[interactions["user_id"].isin(user_counts[user_counts >= 2].index)] + +print(f"Interactions: {len(interactions):,}") +print(f"Users: {interactions['user_id'].nunique():,}") +print(f"Items: {interactions['item_id'].nunique():,}") +# Interactions: 643,786 +# Users: 201,851 +# Items: 6,228 +``` + +### Leave-last-out split + +```python +interactions = interactions.sort_values(["user_id", "timestamp"]) +test = interactions.groupby("user_id").tail(1) +train_val = interactions.drop(test.index) + +print(f"Train+Val: {len(train_val):,}, Test: {len(test):,}") +# Train+Val: 441,935, Test: 201,851 +``` + +## 2. 
Generate Text Embeddings + +Use English movie descriptions from the dataset with Qwen3-Embedding-0.6B: + +```bash +pip install transformers +``` + +```python +from transformers import AutoTokenizer, AutoModel + +# Load item metadata (English descriptions) +items = pd.read_csv("kion_data/data_en/items_en.csv") +items = items.set_index("item_id") + +# Build description text +texts = {} +for item_id, row in items.iterrows(): + parts = [str(row.get("title", ""))] + if pd.notna(row.get("description")): + parts.append(str(row["description"])) + if pd.notna(row.get("genres")): + parts.append(f"Genres: {row['genres']}") + texts[item_id] = " ".join(parts) + +# Encode with Qwen3-Embedding-0.6B +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") +encoder = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B", dtype=torch.float16) +encoder.cuda().eval() + +max_item_id = items.index.max() +embeddings = torch.zeros(max_item_id + 1, 1024) + +item_ids_list = list(texts.keys()) +text_list = list(texts.values()) + +with torch.no_grad(): + for start in range(0, len(text_list), 32): + batch_texts = text_list[start:start + 32] + batch_ids = item_ids_list[start:start + 32] + encoded = tokenizer(batch_texts, return_tensors="pt", padding=True, + truncation=True, max_length=512).to("cuda") + outputs = encoder(**encoded) + mask = encoded["attention_mask"].unsqueeze(-1).half() + pooled = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1) + pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1) + for i, item_id in enumerate(batch_ids): + embeddings[item_id] = pooled[i].cpu().float() + +torch.save(embeddings, "item_embeddings.pt") +print(f"Embeddings: {embeddings.shape}") +# Embeddings: torch.Size([16519, 1024]) +``` + +## 3. 
Train UniSRec + +```python +from rectools.fast_transformers import UniSRecModel + +embeddings = torch.load("item_embeddings.pt", weights_only=True) + +user_ids = torch.tensor(train_val["user_id"].values, dtype=torch.long) +item_ids = torch.tensor(train_val["item_id"].values, dtype=torch.long) +timestamps = torch.tensor(train_val["timestamp"].values, dtype=torch.long) + +model = UniSRecModel( + pretrained_item_embeddings=embeddings, + # Architecture + n_factors=256, + projection_hidden=512, + n_blocks=2, + n_heads=2, + session_max_len=50, + dropout=0.1, + adaptor_dropout=0.2, + adaptor_type="pca", + use_adaptor_ffn=True, + ffn_type="conv1d", + ffn_expansion=1, + # Training + epochs=10, + lr=1e-4, + lr_head=0.3, + lr_wp=0.1, + lr_transformer=3.0, + optimizer="adamw", + scheduler="cosine_warmup", + warmup_ratio=0.05, + min_lr_ratio=0.1, + grad_clip=1.0, + weight_decay=0.01, + loss="softmax", + batch_size=128, + train_min_user_interactions=2, + verbose=1, +) + +model.fit(user_ids, item_ids, timestamps) +# Training: ~194s on RTX 3090 (10 epochs) +``` + +### Save / load checkpoint + +```python +model.save_checkpoint("unisrec_kion.pt") + +# Later: +model2 = UniSRecModel(pretrained_item_embeddings=embeddings, n_factors=256, ...) +model2.load_checkpoint("unisrec_kion.pt", device="cuda") +``` + +## 4. 
Evaluate + +Leave-last-out evaluation with HR@K and NDCG@K: + +```python +import numpy as np + +net = model.net +net.eval().cuda() +device = torch.device("cuda") + +# Get projected item embeddings +item_embs = net.project_all() +unique_items = model.item_id_mapping +ext_to_int = {int(unique_items[i].item()): i + 1 for i in range(len(unique_items))} + +# Build user histories +train_grouped = train_val.sort_values("timestamp").groupby("user_id")["item_id"].agg(list).to_dict() +test_grouped = test.groupby("user_id")["item_id"].first().to_dict() +test_users = list(test_grouped.keys()) + +hits10, ndcg10, total = 0, 0.0, 0 +maxlen = model.session_max_len + +with torch.no_grad(): + for start in range(0, len(test_users), 256): + batch_users = test_users[start:start + 256] + seqs, targets = [], [] + for uid in batch_users: + history = train_grouped.get(uid, []) + mapped = [ext_to_int[iid] for iid in history if iid in ext_to_int] + if not mapped: + continue + seq = mapped[-maxlen:] + seqs.append([0] * (maxlen - len(seq)) + seq) + targets.append(ext_to_int.get(test_grouped[uid])) + if not seqs: + continue + x = torch.tensor(seqs, dtype=torch.long, device=device) + h = net.encode_last(x) + scores = h @ item_embs.T + scores[:, 0] = float("-inf") + for i, target_int in enumerate(targets): + if target_int is None: + continue + _, topk = scores[i].topk(10) + topk = topk.cpu().tolist() + if target_int in topk: + rank = topk.index(target_int) + hits10 += 1 + ndcg10 += 1.0 / np.log2(rank + 2) + total += 1 + +print(f"HR@10 = {hits10/total:.4f}") +print(f"NDCG@10 = {ndcg10/total:.4f}") +``` + +## 5. 
Results + +Trained on NVIDIA RTX 3090, 10 epochs, same architecture (256d, 2 blocks, 2 heads, max_len=50): + +| Model | Embedder | HR@5 | NDCG@5 | HR@10 | NDCG@10 | Train Time | +|-------|----------|------|--------|-------|---------|------------| +| **UniSRec** | all-MiniLM-L6-v2 (384d) | 0.1421 | 0.0988 | 0.1896 | 0.1145 | ~194s | +| **UniSRec** | Qwen3-Embedding-0.6B (1024d) | 0.1529 | 0.1012 | 0.2018 | 0.1171 | ~178s | +| **SASRec** (RecTools) | ID embeddings | 0.1606 | 0.1081 | 0.2175 | 0.1265 | ~166s | + +Qwen3-Embedding-0.6B closes most of the gap to SASRec (HR@10 delta: 1.6pp vs 2.8pp with MiniLM). SASRec with learned ID embeddings is stronger when sufficient interaction data is available. UniSRec's advantage is in cold-start and transfer scenarios where text embeddings provide semantic signal for items with no interaction history. + +## Key Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `n_factors` | Hidden dimension of the transformer | 256 | +| `adaptor_type` | Adaptor type: `"pca"` or `"bn"` | `"pca"` | +| `session_max_len` | Maximum sequence length | 200 | +| `epochs` | Number of training epochs | 10 | +| `lr` | Base learning rate (adaptor layernorms) | 1e-4 | +| `lr_wp` | Multiplier for PCA whitening projection | 0.1 | +| `lr_transformer` | Multiplier for transformer layers | 3.0 | +| `lr_head` | Multiplier for head layer | 0.3 | +| `loss` | Loss function: `"softmax"`, `"BCE"`, `"gBCE"`, `"sampled_softmax"` | `"softmax"` | +| `patience` | Early stopping patience (None = disabled) | None | +| `scheduler` | LR scheduler: `None` or `"cosine_warmup"` | None | + +## ONNX Export + +```python +model.export_to_onnx( + encoder_path="unisrec_encoder.onnx", + items_path="unisrec_items.onnx", +) +``` diff --git a/rectools/fast_transformers/unisrec/lightning.py b/rectools/fast_transformers/unisrec/lightning.py new file mode 100644 index 00000000..e579e32f --- /dev/null +++ 
"""Lightning wrapper for UniSRec with configurable loss, optimizer, scheduler."""

import math
import typing as tp

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

from .net import UniSRec

SUPPORTED_LOSSES = ("softmax", "BCE", "gBCE", "sampled_softmax")
SUPPORTED_OPTIMIZERS = ("adam", "adamw")
SUPPORTED_SCHEDULERS = (None, "cosine_warmup")


class UniSRecLightning(pl.LightningModule):
    """
    Thin Lightning wrapper for joint UniSRec training.

    Wraps a :class:`UniSRec` network with configurable loss, optimizer,
    and learning-rate scheduler.

    Parameters
    ----------
    net : UniSRec
        The network to train.
    param_groups : list of dict
        Pre-built optimizer parameter groups (each with its own ``lr`` /
        ``weight_decay``).
    loss : str
        One of ``SUPPORTED_LOSSES``. All losses except ``"softmax"``
        require a ``negatives`` field in the batch.
    n_negatives : int, optional
        Number of sampled negatives; used by the gBCE calibration.
    gbce_t : float
        gBCE calibration temperature ``t``.
    optimizer : str
        One of ``SUPPORTED_OPTIMIZERS``.
    scheduler : str, optional
        ``None`` or ``"cosine_warmup"``.
    warmup_ratio, min_lr_ratio : float
        Warmup fraction and LR floor (as a fraction of base LR) for the
        cosine schedule.
    total_steps : int, optional
        Total optimizer steps; required for a meaningful cosine schedule.
    """

    def __init__(
        self,
        net: UniSRec,
        param_groups: tp.List[tp.Dict[str, tp.Any]],
        loss: str = "softmax",
        n_negatives: tp.Optional[int] = None,
        gbce_t: float = 0.2,
        optimizer: str = "adamw",
        scheduler: tp.Optional[str] = None,
        warmup_ratio: float = 0.05,
        min_lr_ratio: float = 0.1,
        total_steps: tp.Optional[int] = None,
    ) -> None:
        super().__init__()
        self.net = net
        self._param_groups = param_groups
        self.loss_name = loss
        self.n_negatives = n_negatives
        self.gbce_t = gbce_t
        self.optimizer_name = optimizer
        self.scheduler_name = scheduler
        self.warmup_ratio = warmup_ratio
        self.min_lr_ratio = min_lr_ratio
        self.total_steps = total_steps

    # ── helpers ──

    def _get_item_embs(self, item_ids: torch.Tensor) -> torch.Tensor:
        # Adapted (projected) embeddings of the frozen pretrained vectors
        # for the given internal item IDs.
        return self.net._adapt_score(self.net._sample_frozen(item_ids))

    def _get_all_embs(self) -> torch.Tensor:
        # Projected embeddings for the full catalog (row 0 = padding).
        return self.net.project_all()

    def _get_pos_neg_logits(
        self,
        hidden: torch.Tensor,
        labels: torch.Tensor,
        negatives: torch.Tensor,
    ) -> torch.Tensor:
        """Compute (B, L, 1+N) logits where index 0 = positive."""
        emb_pos = self._get_item_embs(labels)
        logits_pos = (hidden * emb_pos).sum(dim=-1)

        emb_neg = self._get_item_embs(negatives)
        # (B, L, 1, D) @ (B, L, D, N) -> (B, L, N)
        logits_neg = torch.matmul(
            hidden.unsqueeze(2),
            emb_neg.transpose(2, 3),
        ).squeeze(2)

        return torch.cat([logits_pos.unsqueeze(-1), logits_neg], dim=-1)

    # ── losses ──

    def _calc_loss(
        self,
        hidden: torch.Tensor,
        batch: tp.Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Dispatch to the configured loss; raises if negatives are missing."""
        labels = batch["y"]
        has_neg = "negatives" in batch

        if self.loss_name == "softmax":
            return self._full_softmax_loss(hidden, labels)

        if not has_neg:
            raise ValueError(f"Loss '{self.loss_name}' requires negatives but batch has none")

        logits = self._get_pos_neg_logits(hidden, labels, batch["negatives"])
        mask = labels != 0  # padding positions contribute no loss

        if self.loss_name == "sampled_softmax":
            return self._sampled_softmax_loss(logits, mask)
        if self.loss_name == "BCE":
            return self._bce_loss(logits, mask)
        if self.loss_name == "gBCE":
            return self._gbce_loss(logits, mask)

        raise ValueError(f"Unknown loss: {self.loss_name}")

    def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Cross-entropy over the full catalog; padding item 0 is masked out."""
        all_emb = self._get_all_embs()
        logits = hidden @ all_emb.T
        logits[:, :, 0] = float("-inf")  # padding item can never be predicted

        # Padding targets (0) are remapped to ignore_index so they are skipped.
        targets = labels.clone()
        targets[targets == 0] = -100
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-100,
        )

    def _sampled_softmax_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Sampled softmax: positive at index 0, swap to index 1 so index 0 can be ignored.

        After the swap the target class for every valid position is 1 (the
        positive), and padding positions get target 0 which coincides with
        ``ignore_index=0`` — class 0 is never a legitimate target.
        """
        logits = logits.clone()
        logits[:, :, [0, 1]] = logits[:, :, [1, 0]]
        targets = mask.long()  # 1 where non-padding, 0 where padding
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=0,
        )

    def _bce_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Binary cross-entropy: index 0 labeled 1 (positive), the rest 0."""
        target = torch.zeros_like(logits)
        target[:, :, 0] = 1.0
        loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
        loss = loss.mean(-1) * mask  # zero out padding positions
        return loss.sum() / mask.sum().clamp(min=1)

    def _gbce_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Generalized BCE following the gBCE formulation (gSASRec).

        Transforms the positive logit so that BCE is computed against
        ``sigmoid(x)^beta``, where ``beta`` calibrates for the negative
        sampling rate ``alpha``. Done in float64 for numerical stability
        of the pow/reciprocal chain.
        """
        n_items = self.net.n_items
        n_neg = self.n_negatives or logits.size(-1) - 1
        alpha = n_neg / max(n_items - 1, 1)
        beta = alpha * (self.gbce_t * (1 - 1 / alpha) + 1 / alpha)

        dtype = torch.float64
        pos_logits = logits[:, :, 0:1].to(dtype)
        neg_logits = logits[:, :, 1:]

        eps = 1e-10
        # logit of sigmoid(x)^beta == log(1 / (p^-beta - 1)), clamped to
        # keep pow/reciprocal finite near p -> 0 or 1.
        pos_probs = torch.clamp(torch.sigmoid(pos_logits), eps, 1 - eps)
        pos_adjusted = torch.clamp(pos_probs.pow(-beta), 1 + eps, torch.finfo(dtype).max)
        pos_adjusted = torch.clamp(1.0 / (pos_adjusted - 1), eps, torch.finfo(dtype).max)
        pos_transformed = torch.log(pos_adjusted).to(logits.dtype)

        adjusted_logits = torch.cat([pos_transformed, neg_logits], dim=-1)
        return self._bce_loss(adjusted_logits, mask)

    # ── training / validation ──

    def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        hidden = self.net(batch["x"])
        loss = self._calc_loss(hidden, batch)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        hidden = self.net(batch["x"])
        # Validation batch has y of shape (B, 1) -- take last hidden position only
        hidden = hidden[:, -1:, :]
        loss = self._calc_loss(hidden, batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    # ── optimizer / scheduler ──

    def configure_optimizers(self) -> tp.Any:
        """Build the optimizer from the prepared param groups, plus optional scheduler."""
        if self.optimizer_name == "adamw":
            opt = torch.optim.AdamW(self._param_groups)
        elif self.optimizer_name == "adam":
            opt = torch.optim.Adam(self._param_groups)
        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer_name}")

        if self.scheduler_name is None:
            return opt

        if self.scheduler_name == "cosine_warmup":
            total = self.total_steps or 1
            warmup = int(total * self.warmup_ratio)
            scheduler = _cosine_warmup_scheduler(opt, warmup, total, self.min_lr_ratio)
            return {"optimizer": opt, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

        raise ValueError(f"Unknown scheduler: {self.scheduler_name}")


def _cosine_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.0,
) -> LambdaLR:
    """Linear warmup followed by cosine decay to ``min_lr_ratio`` * base LR."""

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Clamp progress at 1.0: if training runs past `total_steps`
        # (e.g. the step estimate was low), the unclamped cosine would
        # start *raising* the LR again instead of holding the floor.
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)
return batch + + +class _ProjectAllWrapper(torch.nn.Module): + def __init__(self, net: UniSRec) -> None: + super().__init__() + self.net = net + + def forward(self) -> torch.Tensor: + return self.net.project_all() + + +class UniSRecModel: + """ + UniSRec sequential recommender with pretrained text embeddings. + + Joint training of the adaptor and transformer encoder on + frozen pretrained embeddings (e.g. from a sentence-transformer). + + Parameters + ---------- + pretrained_item_embeddings : Tensor + Shape ``(max_external_item_id + 1, D_text)`` or + ``(max_external_item_id + 1, n_variants, D_text)``. + Index *i* holds the text embedding for the item whose **external** ID + equals *i*. Index 0 is padding (zeros). + """ + + def __init__( + self, + pretrained_item_embeddings: torch.Tensor, + # architecture + n_factors: int = 256, + projection_hidden: int = 512, + n_blocks: int = 2, + n_heads: int = 1, + session_max_len: int = 200, + dropout: float = 0.1, + adaptor_dropout: float = 0.2, + adaptor_type: str = "pca", + use_adaptor_ffn: bool = True, + ffn_type: str = "conv1d", + ffn_expansion: int = 1, + # training + epochs: int = 10, + lr: float = 1e-4, + lr_head: float = 0.3, + lr_wp: float = 0.1, + lr_transformer: float = 3.0, + # optimizer / scheduler + optimizer: str = "adamw", + scheduler: tp.Optional[str] = None, + warmup_ratio: float = 0.05, + min_lr_ratio: float = 0.1, + grad_clip: float = 1.0, + weight_decay: float = 0.01, + # loss + loss: str = "softmax", + gbce_t: float = 0.2, + n_negatives: tp.Optional[int] = None, + # early stopping + patience: tp.Optional[int] = None, + # data + batch_size: int = 128, + dataloader_num_workers: int = 0, + train_min_user_interactions: int = 2, + device: tp.Optional[str] = None, + verbose: int = 0, + ) -> None: + if loss not in SUPPORTED_LOSSES: + raise ValueError(f"Unsupported loss '{loss}'. 
Choose from {SUPPORTED_LOSSES}") + if loss in ("BCE", "gBCE", "sampled_softmax"): + if not isinstance(n_negatives, int) or n_negatives <= 0: + raise ValueError(f"Loss '{loss}' requires n_negatives to be a positive integer") + if optimizer not in SUPPORTED_OPTIMIZERS: + raise ValueError(f"Unsupported optimizer '{optimizer}'. Choose from {SUPPORTED_OPTIMIZERS}") + if scheduler not in SUPPORTED_SCHEDULERS: + raise ValueError(f"Unsupported scheduler '{scheduler}'. Choose from {SUPPORTED_SCHEDULERS}") + + self.pretrained_item_embeddings = pretrained_item_embeddings + self.n_factors = n_factors + self.projection_hidden = projection_hidden + self.n_blocks = n_blocks + self.n_heads = n_heads + self.session_max_len = session_max_len + self.dropout = dropout + self.adaptor_dropout = adaptor_dropout + self.adaptor_type = adaptor_type + self.use_adaptor_ffn = use_adaptor_ffn + self.ffn_type = ffn_type + self.ffn_expansion = ffn_expansion + self.epochs = epochs + self.lr = lr + self.lr_head = lr_head + self.lr_wp = lr_wp + self.lr_transformer = lr_transformer + self.optimizer = optimizer + self.scheduler = scheduler + self.warmup_ratio = warmup_ratio + self.min_lr_ratio = min_lr_ratio + self.grad_clip = grad_clip + self.weight_decay = weight_decay + self.loss = loss + self.gbce_t = gbce_t + self.n_negatives = n_negatives + self.patience = patience + self.batch_size = batch_size + self.dataloader_num_workers = dataloader_num_workers + self.train_min_user_interactions = train_min_user_interactions + self.device = device + self.verbose = verbose + + self._net: tp.Optional[UniSRec] = None + self._unique_items: tp.Optional[torch.Tensor] = None + self._unique_users: tp.Optional[torch.Tensor] = None + self.is_fitted: bool = False + + # ── helpers ── + + def _make_trainer(self, max_epochs: int, val_dl: tp.Any = None) -> pl.Trainer: + callbacks = [] + if self.patience is not None and val_dl is not None: + callbacks.append(EarlyStopping(monitor="val_loss", patience=self.patience, 
mode="min")) + + return pl.Trainer( + max_epochs=max_epochs, + gradient_clip_val=self.grad_clip, + callbacks=callbacks or None, + enable_checkpointing=False, + enable_model_summary=False, + logger=self.verbose > 0, + enable_progress_bar=self.verbose > 0, + ) + + def _make_lightning( + self, + net: UniSRec, + param_groups: tp.List[tp.Dict], + max_epochs: int, + train_dl: tp.Any, + ) -> UniSRecLightning: + total_steps = len(train_dl) * max_epochs if self.scheduler else None + return UniSRecLightning( + net=net, + param_groups=param_groups, + loss=self.loss, + n_negatives=self.n_negatives, + gbce_t=self.gbce_t, + optimizer=self.optimizer, + scheduler=self.scheduler, + warmup_ratio=self.warmup_ratio, + min_lr_ratio=self.min_lr_ratio, + total_steps=total_steps, + ) + + # ── param groups ── + + def _param_groups(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: + if self.adaptor_type == "pca": + adaptor: tp.List[tp.Dict[str, tp.Any]] = [ + {"params": [net.whitening_proj], "lr": self.lr * self.lr_wp, "weight_decay": 0.0}, + {"params": [net.whitening_bias], "lr": self.lr * 10.0, "weight_decay": 0.0}, + ] + else: + adaptor = [ + {"params": list(net.bn_input.parameters()), "lr": self.lr, "weight_decay": 0.0}, + {"params": list(net.bn_score.parameters()), "lr": self.lr, "weight_decay": 0.0}, + ] + head: tp.List[tp.Dict[str, tp.Any]] = [] + if net.head is not None: + head = [ + { + "params": list(net.head.parameters()), + "lr": self.lr * self.lr_head, + "weight_decay": self.weight_decay, + } + ] + transformer = [ + {"params": list(net.pos_emb.parameters()), "lr": self.lr * self.lr_transformer, "weight_decay": 0.0}, + { + "params": ( + [p for layer in net.attention_layers for p in layer.parameters()] + + [p for layer in net.forward_layers for p in layer.parameters()] + ), + "lr": self.lr * self.lr_transformer, + "weight_decay": self.weight_decay, + }, + { + "params": ( + [p for layer in net.attention_layernorms for p in layer.parameters()] + + [p for layer in 
net.forward_layernorms for p in layer.parameters()] + + list(net.last_layernorm.parameters()) + ), + "lr": self.lr, + "weight_decay": 0.0, + }, + ] + return adaptor + head + transformer + + # ── fit ── + + def fit( + self, + user_ids: torch.Tensor, + item_ids: torch.Tensor, + timestamps: torch.Tensor, + ) -> "UniSRecModel": + """ + Train the model on interaction data. + + Parameters + ---------- + user_ids : LongTensor (N,) + External user IDs for each interaction. + item_ids : LongTensor (N,) + External item IDs for each interaction. + timestamps : LongTensor (N,) + Timestamps (any monotonic int64 values). + + Returns + ------- + self + """ + x, y, unique_items, unique_users = build_sequences( + user_ids, + item_ids, + timestamps, + max_len=self.session_max_len, + min_interactions=self.train_min_user_interactions, + device=self.device, + ) + if len(x) == 0: + raise ValueError( + f"No users with >= {self.train_min_user_interactions} interactions. " "Cannot train on empty data." + ) + self._unique_items = unique_items.cpu() + self._unique_users = unique_users.cpu() + n_items = len(unique_items) + + aligned_emb = align_embeddings(self.pretrained_item_embeddings, unique_items, n_items) + + net = UniSRec( + n_items=n_items, + pretrained_embeddings=aligned_emb, + n_factors=self.n_factors, + projection_hidden=self.projection_hidden, + n_blocks=self.n_blocks, + n_heads=self.n_heads, + session_max_len=self.session_max_len, + dropout=self.dropout, + adaptor_dropout=self.adaptor_dropout, + adaptor_type=self.adaptor_type, + use_adaptor_ffn=self.use_adaptor_ffn, + ffn_type=self.ffn_type, + ffn_expansion=self.ffn_expansion, + ) + + # DataLoader with num_workers>0 requires CPU tensors + if x.is_cuda: + x, y = x.cpu(), y.cpu() + + neg_transform = None + if self.loss in ("BCE", "gBCE", "sampled_softmax"): + neg_transform = _NegativeSampler(n_items, self.n_negatives) + + train_dl = DataLoader( + SequenceBatchDataset(x, y, transform=neg_transform), + batch_size=self.batch_size, + 
shuffle=True, + num_workers=self.dataloader_num_workers, + ) + + val_dl = None + if self.patience is not None: + val_y_last = y[:, -1:] + val_dl = DataLoader( + SequenceBatchDataset(x, val_y_last, transform=neg_transform), + batch_size=self.batch_size, + shuffle=False, + num_workers=self.dataloader_num_workers, + ) + + lm = self._make_lightning(net, self._param_groups(net), self.epochs, train_dl) + trainer = self._make_trainer(self.epochs, val_dl) + trainer.fit(lm, train_dl, val_dl) + + self._net = net + self.is_fitted = True + return self + + # ── save / load ── + + def save_checkpoint(self, path: tp.Union[str, Path]) -> None: + assert self._net is not None + torch.save( + { + "net": self._net.state_dict(), + "unique_items": self._unique_items, + "unique_users": self._unique_users, + "n_items": len(self._unique_items), + }, + path, + ) + + def load_checkpoint(self, path: tp.Union[str, Path], device: tp.Optional[str] = None) -> None: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + ckpt = torch.load(path, map_location=device, weights_only=False) + self._unique_items = ckpt["unique_items"].cpu() + self._unique_users = ckpt["unique_users"].cpu() + n_items = ckpt["n_items"] + + aligned_emb = align_embeddings(self.pretrained_item_embeddings, self._unique_items, n_items) + + self._net = UniSRec( + n_items=n_items, + pretrained_embeddings=aligned_emb, + n_factors=self.n_factors, + projection_hidden=self.projection_hidden, + n_blocks=self.n_blocks, + n_heads=self.n_heads, + session_max_len=self.session_max_len, + dropout=self.dropout, + adaptor_dropout=self.adaptor_dropout, + adaptor_type=self.adaptor_type, + use_adaptor_ffn=self.use_adaptor_ffn, + ffn_type=self.ffn_type, + ffn_expansion=self.ffn_expansion, + ) + self._net.load_state_dict(ckpt["net"]) + self._net.to(device).eval() + self.is_fitted = True + + # ── ONNX export ── + + def export_to_onnx( + self, + encoder_path: tp.Union[str, Path], + items_path: tp.Optional[tp.Union[str, 
Path]] = None, + opset_version: int = 18, + ) -> None: + """Export the model to ONNX. + + Parameters + ---------- + encoder_path + Path for the encoder graph (input_ids -> hidden states). + items_path + If given, also exports project_all (-> item embeddings). + opset_version + ONNX opset version (default 18). + """ + assert self._net is not None, "Model not fitted or loaded" + net = self._net + was_training = net.training + net.eval() + + device = next(net.parameters()).device + dummy = torch.zeros(1, 5, dtype=torch.long, device=device) + + torch.onnx.export( + net, + (dummy,), + str(encoder_path), + input_names=["input_ids"], + output_names=["hidden"], + opset_version=opset_version, + ) + + if items_path is not None: + wrapper = _ProjectAllWrapper(net) + wrapper.eval() + torch.onnx.export( + wrapper, + (), + str(items_path), + input_names=[], + output_names=["item_embs"], + opset_version=opset_version, + ) + + if was_training: + net.train() + + def map_item_ids(self, external_ids: torch.Tensor) -> torch.Tensor: + """Map external item IDs to internal IDs used by the model. + + Parameters + ---------- + external_ids : LongTensor + External item IDs. + + Returns + ------- + LongTensor + Internal IDs in ``[0, n_items]``. 0 means unknown item. + """ + assert self._unique_items is not None, "Model not fitted or loaded" + input_device = external_ids.device + external_cpu = external_ids.cpu() + sorted_items, sort_idx = self._unique_items.sort() + pos = torch.searchsorted(sorted_items, external_cpu) + pos = pos.clamp(max=len(sorted_items) - 1) + found = sorted_items[pos] == external_cpu + result = torch.zeros_like(external_cpu, dtype=torch.long) + result[found] = sort_idx[pos[found]] + 1 + return result.to(input_device) + + def recommend(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: + """Not supported. Use :meth:`predict_topk` instead. + + ``UniSRecModel`` operates on raw tensor sequences, not on + ``Dataset`` / user IDs expected by ``ModelBase.recommend()``. 
    @torch.no_grad()
    def predict_topk(
        self,
        input_ids: torch.Tensor,
        k: int = 10,
    ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Encode user sequences and return top-k items in a single GPU pass.

        This is the inference entry point for ``UniSRecModel``. It fuses
        sequence encoding and dot-product ranking into one call, keeping
        everything on device without intermediate numpy / scipy conversions.

        Parameters
        ----------
        input_ids : LongTensor (B, L)
            Left-padded internal item ID sequences (0 = padding).
            Use :meth:`map_item_ids` to convert external IDs to internal.
        k : int
            Number of items to return per user.

        Returns
        -------
        scores : Tensor (B, k)
            Dot-product scores, descending.
        item_ids : LongTensor (B, k)
            Internal item IDs (1-based; row 0 of the catalog is never returned).
        """
        assert self._net is not None, "Model not fitted or loaded"
        net = self._net
        # Remember caller's mode so we can restore it after inference.
        was_training = net.training
        net.eval()  # disable dropout and random variant sampling — deterministic scoring
        device = next(net.parameters()).device
        h = net.encode_last(input_ids.to(device))  # (B, D) last-position hidden states
        item_embs = net.project_all()  # (n_items + 1, D); row 0 is the padding slot
        scores = h @ item_embs.T  # (B, n_items + 1)
        scores[:, 0] = float("-inf")  # never recommend the padding "item"
        top_scores, top_ids = scores.topk(k, dim=1)
        if was_training:
            net.train()  # restore the mode the caller left the net in
        return top_scores, top_ids
def make_ffn(n_factors: int, ffn_type: str, expansion: int, dropout: float) -> nn.Module:
    """Build the feed-forward sub-block used inside each transformer layer.

    Parameters
    ----------
    n_factors : int
        Model hidden dimension (input and output width of the block).
    ffn_type : ``"conv1d"`` | ``"linear_gelu"`` | ``"linear_relu"``
        Which FFN variant to construct. ``"conv1d"`` reproduces the
        reference UniSRec point-wise Conv1d block.
    expansion : int
        Hidden-dim multiplier for the linear variants (e.g. 1 or 4).
    dropout : float
        Dropout probability inside the block.

    Raises
    ------
    ValueError
        If ``ffn_type`` is not one of the supported variants.
    """
    if ffn_type == "conv1d":
        return FeedForwardConv1d(n_factors, dropout)

    inner_dim = n_factors * expansion
    if ffn_type == "linear_gelu":
        layers = [
            nn.Linear(n_factors, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, n_factors),
            nn.Dropout(dropout),
        ]
    elif ffn_type == "linear_relu":
        layers = [
            nn.Linear(n_factors, inner_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, n_factors),
        ]
    else:
        raise ValueError(f"Unknown ffn_type: {ffn_type}. Choose from: conv1d, linear_gelu, linear_relu")
    return nn.Sequential(*layers)
+ adaptor_type : ``"pca"`` | ``"bn"`` + Type of adaptor for projecting pretrained embeddings. + use_adaptor_ffn : bool + Whether to use a 2-layer MLP head after the linear projection. + initializer_range : float + Std for normal weight initialisation. + """ + + PADDING_IDX = 0 + + def __init__( + self, + n_items: int, + pretrained_embeddings: torch.Tensor, + n_factors: int = 256, + projection_hidden: int = 512, + n_blocks: int = 2, + n_heads: int = 1, + session_max_len: int = 200, + dropout: float = 0.1, + adaptor_dropout: float = 0.2, + adaptor_type: str = "pca", + use_adaptor_ffn: bool = True, + initializer_range: float = 0.02, + ffn_type: str = "conv1d", + ffn_expansion: int = 1, + ) -> None: + super().__init__() + self.n_items = n_items + self.n_factors = n_factors + self.session_max_len = session_max_len + self.n_blocks = n_blocks + self.adaptor_type = adaptor_type + self.use_adaptor_ffn = use_adaptor_ffn + self.initializer_range = initializer_range + + if not use_adaptor_ffn and adaptor_type != "pca": + raise ValueError("use_adaptor_ffn=False is only supported with adaptor_type='pca'") + + # ── Frozen pretrained embeddings ── + if pretrained_embeddings.ndim == 2: + pretrained_embeddings = pretrained_embeddings.unsqueeze(1) + self.register_buffer("frozen_emb", pretrained_embeddings) + self.n_variants = pretrained_embeddings.shape[1] + + qwen_dim = pretrained_embeddings.shape[2] + emb_for_init = pretrained_embeddings[1:, 0, :] # skip padding row + + # ── Adaptor ── + if adaptor_type == "pca": + self.whitening_bias = nn.Parameter(emb_for_init.mean(dim=0)) + if use_adaptor_ffn: + self.whitening_proj = nn.Parameter(self._pca_init(emb_for_init, projection_hidden)) + proj_dim = self.whitening_proj.shape[1] + self.head = _make_mlp(proj_dim, proj_dim, n_factors, adaptor_dropout) + else: + self.whitening_proj = nn.Parameter(self._pca_init(emb_for_init, n_factors)) + self.head = None + elif adaptor_type == "bn": + self.bn_input = nn.BatchNorm1d(qwen_dim) + self.bn_score 
    def _init_weights(self, module: nn.Module) -> None:
        """Initialise a single sub-module (applied recursively via ``self.apply``).

        Linear / Conv1d / Embedding weights get N(0, initializer_range);
        biases, the embedding padding row, and LayerNorm biases are zeroed,
        LayerNorm scales are set to 1.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding embedding exactly zero after init.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    def _sample_frozen(self, item_ids: torch.Tensor) -> torch.Tensor:
        """Look up pretrained embeddings, sampling a random variant during training.

        With a single stored variant, or in eval mode, variant 0 is used
        deterministically. During training a variant index is drawn uniformly
        per position as a light augmentation over the frozen text embeddings.
        """
        if self.n_variants == 1 or not self.training:
            return self.frozen_emb[item_ids, 0]
        vi = torch.randint(self.n_variants, item_ids.shape, device=item_ids.device)
        vi = vi * (item_ids != 0).long()  # padding always uses variant 0
        return self.frozen_emb[item_ids, vi]
+ """ + return self._adapt_score(self.frozen_emb[:, 0]) + + # ── Encoder ── + + def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor: + return torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1) + + def _encode(self, seqs: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + B, L = input_ids.shape + positions = torch.arange(L, device=input_ids.device).unsqueeze(0) + seqs = seqs + self.pos_emb(positions) + seqs = self.emb_dropout(seqs) + + pad_mask = input_ids == self.PADDING_IDX # (B, L) + pad_mask_3d = pad_mask.unsqueeze(-1) # (B, L, 1) + seqs = seqs.masked_fill(pad_mask_3d, 0.0) # zero out padding + + attn_mask = self._causal_mask(L, seqs.device) + key_padding_mask = pad_mask + + for i in range(self.n_blocks): + normed = self.attention_layernorms[i](seqs) + # Zero padding in Q/K/V so NaN can never appear in dot-products + normed = normed.masked_fill(pad_mask_3d, 0.0) + mha_out, _ = self.attention_layers[i]( + normed, + normed, + normed, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + ) + # masked_fill handles NaN*0 correctly (unlike multiplication) + seqs = (seqs + mha_out).masked_fill(pad_mask_3d, 0.0) + seqs = seqs + self.forward_layers[i](self.forward_layernorms[i](seqs)) + seqs = seqs.masked_fill(pad_mask_3d, 0.0) + + return self.last_layernorm(seqs) + + # ── Public forward / encode ── + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + """ + Encode a sequence of item IDs through the adaptor + transformer. + + Parameters + ---------- + input_ids : LongTensor (B, L) + Left-padded item ID sequences (0 = padding). 
    def encode_last(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Encode and return the last-position representation, shape (B, D).

        Relies on sequences being left-padded: the rightmost position is
        always the most recent real item (or padding for empty sequences).
        """
        h = self.forward(input_ids)  # (B, L, D)
        return h[:, -1, :]  # left-padded → last position is always the rightmost
+ ) + return reco, interactions + + +# --------------------------------------------------------------------------- +# HitRate +# --------------------------------------------------------------------------- + + +class TestHitRate: + def test_all_hits(self) -> None: + topk = torch.tensor([[5, 2, 3], [1, 7, 9]]) + targets = torch.tensor([5, 7]) + assert hitrate_at_k(topk, targets).item() == pytest.approx(1.0) + + def test_no_hits(self) -> None: + topk = torch.tensor([[5, 2, 3], [1, 7, 9]]) + targets = torch.tensor([99, 88]) + assert hitrate_at_k(topk, targets).item() == pytest.approx(0.0) + + def test_partial_hits(self) -> None: + topk = torch.tensor([[5, 2, 3], [1, 7, 9]]) + targets = torch.tensor([5, 88]) + assert hitrate_at_k(topk, targets).item() == pytest.approx(0.5) + + def test_hit_at_last_position(self) -> None: + topk = torch.tensor([[1, 2, 3]]) + targets = torch.tensor([3]) + assert hitrate_at_k(topk, targets).item() == pytest.approx(1.0) + + +# --------------------------------------------------------------------------- +# NDCG +# --------------------------------------------------------------------------- + + +class TestNDCG: + def test_perfect_ranking(self) -> None: + """Target at rank 1 => DCG = 1/log2(2) = 1.0, NDCG = 1/IDCG * 1.0.""" + topk = torch.tensor([[5]]) + targets = torch.tensor([5]) + # k=1: IDCG = 1/log2(2) = 1.0, DCG = 1.0, NDCG = 1.0 + assert ndcg_at_k(topk, targets).item() == pytest.approx(1.0) + + def test_no_hit(self) -> None: + topk = torch.tensor([[1, 2, 3]]) + targets = torch.tensor([99]) + assert ndcg_at_k(topk, targets).item() == pytest.approx(0.0) + + def test_hit_at_position_2(self) -> None: + """Target at rank 2 out of k=3.""" + topk = torch.tensor([[1, 5, 3]]) + targets = torch.tensor([5]) + # DCG = 1/log2(3), IDCG = 1/log2(2) + 1/log2(3) + 1/log2(4) + dcg = 1.0 / np.log2(3) + idcg = 1.0 / np.log2(2) + 1.0 / np.log2(3) + 1.0 / np.log2(4) + expected = dcg / idcg + assert ndcg_at_k(topk, targets).item() == pytest.approx(expected, 
class TestMRR:
    """MRR@k on small hand-crafted rankings."""

    def test_hit_at_rank_1(self) -> None:
        ranking = torch.tensor([[5, 2, 3]])
        target = torch.tensor([5])
        assert mrr_at_k(ranking, target).item() == pytest.approx(1.0)

    def test_hit_at_rank_3(self) -> None:
        ranking = torch.tensor([[1, 2, 5]])
        target = torch.tensor([5])
        assert mrr_at_k(ranking, target).item() == pytest.approx(1.0 / 3)

    def test_no_hit(self) -> None:
        ranking = torch.tensor([[1, 2, 3]])
        target = torch.tensor([99])
        assert mrr_at_k(ranking, target).item() == pytest.approx(0.0)

    def test_multiple_users(self) -> None:
        ranking = torch.tensor([[5, 2, 3], [1, 2, 7]])
        targets = torch.tensor([5, 7])
        # Mean of reciprocal ranks: user 0 hits at rank 1, user 1 at rank 3.
        expected = (1.0 + 1.0 / 3) / 2
        assert mrr_at_k(ranking, targets).item() == pytest.approx(expected)
torch.tensor([[5, 2]]) + targets = torch.tensor([5]) + with pytest.raises(ValueError, match="exceeds"): + compute_metrics(topk, targets, ks=[5]) + + +# --------------------------------------------------------------------------- +# Cross-validation with RecTools metrics +# --------------------------------------------------------------------------- + + +class TestMatchRecTools: + """Verify that our GPU metrics produce identical results to RecTools.""" + + @pytest.fixture() + def scenario_mixed(self) -> tuple[torch.Tensor, torch.Tensor]: + """4 users, k=5. Mix of hits at various ranks and misses.""" + topk = torch.tensor( + [ + [10, 20, 30, 40, 50], # target=30, hit at rank 3 + [11, 21, 31, 41, 51], # target=99, no hit + [12, 22, 32, 42, 52], # target=12, hit at rank 1 + [13, 23, 33, 43, 53], # target=53, hit at rank 5 + ] + ) + targets = torch.tensor([30, 99, 12, 53]) + return topk, targets + + @pytest.fixture() + def scenario_all_hit(self) -> tuple[torch.Tensor, torch.Tensor]: + topk = torch.tensor( + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ] + ) + targets = torch.tensor([2, 4, 9]) + return topk, targets + + @pytest.fixture() + def scenario_no_hit(self) -> tuple[torch.Tensor, torch.Tensor]: + topk = torch.tensor([[1, 2, 3], [4, 5, 6]]) + targets = torch.tensor([99, 88]) + return topk, targets + + @pytest.mark.parametrize("fixture_name", ["scenario_mixed", "scenario_all_hit", "scenario_no_hit"]) + def test_hitrate_matches_rectools(self, fixture_name: str, request: pytest.FixtureRequest) -> None: + topk, targets = request.getfixturevalue(fixture_name) + k = topk.shape[1] + reco, interactions = _build_rectools_inputs(topk, targets) + + ours = hitrate_at_k(topk, targets).item() + theirs = HitRate(k=k).calc(reco, interactions) + assert ours == pytest.approx(theirs, abs=1e-7), f"HR@{k}: ours={ours}, rectools={theirs}" + + @pytest.mark.parametrize("fixture_name", ["scenario_mixed", "scenario_all_hit", "scenario_no_hit"]) + def test_ndcg_matches_rectools(self, 
fixture_name: str, request: pytest.FixtureRequest) -> None: + topk, targets = request.getfixturevalue(fixture_name) + k = topk.shape[1] + reco, interactions = _build_rectools_inputs(topk, targets) + + ours = ndcg_at_k(topk, targets).item() + theirs = NDCG(k=k).calc(reco, interactions) + assert ours == pytest.approx(theirs, abs=1e-7), f"NDCG@{k}: ours={ours}, rectools={theirs}" + + @pytest.mark.parametrize("fixture_name", ["scenario_mixed", "scenario_all_hit", "scenario_no_hit"]) + def test_mrr_matches_rectools(self, fixture_name: str, request: pytest.FixtureRequest) -> None: + topk, targets = request.getfixturevalue(fixture_name) + k = topk.shape[1] + reco, interactions = _build_rectools_inputs(topk, targets) + + ours = mrr_at_k(topk, targets).item() + theirs = MRR(k=k).calc(reco, interactions) + assert ours == pytest.approx(theirs, abs=1e-7), f"MRR@{k}: ours={ours}, rectools={theirs}" + + @pytest.mark.parametrize("fixture_name", ["scenario_mixed", "scenario_all_hit", "scenario_no_hit"]) + def test_all_ks_match_rectools(self, fixture_name: str, request: pytest.FixtureRequest) -> None: + """Test at multiple K values to make sure slicing is correct.""" + topk, targets = request.getfixturevalue(fixture_name) + k_max = topk.shape[1] + ks = list(range(1, k_max + 1)) + + reco, interactions = _build_rectools_inputs(topk, targets) + + ours = compute_metrics(topk, targets, ks=ks) + for k in ks: + rt_hr = HitRate(k=k).calc(reco, interactions) + rt_ndcg = NDCG(k=k).calc(reco, interactions) + rt_mrr = MRR(k=k).calc(reco, interactions) + assert ours[f"HR@{k}"] == pytest.approx(rt_hr, abs=1e-7), f"HR@{k}" + assert ours[f"NDCG@{k}"] == pytest.approx(rt_ndcg, abs=1e-7), f"NDCG@{k}" + assert ours[f"MRR@{k}"] == pytest.approx(rt_mrr, abs=1e-7), f"MRR@{k}" + + def test_random_large_batch(self) -> None: + """Randomized test with 500 users, k=20.""" + torch.manual_seed(42) + B, K = 500, 20 + n_items = 1000 + topk = torch.randint(1, n_items, (B, K)) + targets = torch.randint(1, n_items, 
(B,)) + # Ensure some hits by placing target at random positions + for i in range(0, B, 3): + pos = torch.randint(0, K, (1,)).item() + topk[i, pos] = targets[i] + + reco, interactions = _build_rectools_inputs(topk, targets) + + for k in [1, 5, 10, 20]: + our_hr = hitrate_at_k(topk[:, :k], targets).item() + our_ndcg = ndcg_at_k(topk[:, :k], targets).item() + our_mrr = mrr_at_k(topk[:, :k], targets).item() + + rt_hr = HitRate(k=k).calc(reco, interactions) + rt_ndcg = NDCG(k=k).calc(reco, interactions) + rt_mrr = MRR(k=k).calc(reco, interactions) + + assert our_hr == pytest.approx(rt_hr, abs=1e-6), f"HR@{k}" + assert our_ndcg == pytest.approx(rt_ndcg, abs=1e-6), f"NDCG@{k}" + assert our_mrr == pytest.approx(rt_mrr, abs=1e-6), f"MRR@{k}" diff --git a/tests/fast_transformers/test_net.py b/tests/fast_transformers/test_net.py new file mode 100644 index 00000000..8a3e9c7d --- /dev/null +++ b/tests/fast_transformers/test_net.py @@ -0,0 +1,45 @@ +"""Tests for FlatSASRec network.""" + +import pytest +import torch + +from rectools.fast_transformers.net import FlatSASRec + + +@pytest.fixture() +def net() -> FlatSASRec: + return FlatSASRec(n_items=30, n_factors=16, n_blocks=1, n_heads=2, session_max_len=8, dropout=0.0) + + +class TestFlatSASRec: + def test_full_catalog_logits_shape(self, net: FlatSASRec) -> None: + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + logits = net(batch) + assert logits.shape == (2, 5, 30) # (B, L, n_items) + + def test_candidate_logits_shape(self, net: FlatSASRec) -> None: + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + "negatives": torch.randint(1, 30, (2, 5, 3)), + } + logits = net(batch) + assert logits.shape == (2, 5, 4) # (B, L, 1 + n_neg) + + def test_encode_last_shape(self, net: FlatSASRec) -> None: + x = torch.tensor([[0, 0, 1, 2, 3]]) + emb = net.encode_last(x) + assert emb.shape == 
(1, 16) + + def test_determinism(self, net: FlatSASRec) -> None: + """Same input produces identical output across two forward passes.""" + net.eval() + x_a = torch.tensor([[0, 0, 0, 5, 10]]) + x_b = torch.tensor([[0, 0, 0, 5, 10]]) + with torch.no_grad(): + e_a = net.encode_last(x_a) + e_b = net.encode_last(x_b) + torch.testing.assert_close(e_a, e_b) diff --git a/tests/fast_transformers/test_onnx_export.py b/tests/fast_transformers/test_onnx_export.py new file mode 100644 index 00000000..5eed9c72 --- /dev/null +++ b/tests/fast_transformers/test_onnx_export.py @@ -0,0 +1,251 @@ +"""Tests for ONNX export of UniSRec network and UniSRecModel.export_to_onnx.""" + +from pathlib import Path + +import numpy as np +import pytest +import torch + +onnx = pytest.importorskip("onnx") +ort = pytest.importorskip("onnxruntime") + +from rectools.fast_transformers.unisrec.model import UniSRecModel # noqa: E402 +from rectools.fast_transformers.unisrec.net import UniSRec # noqa: E402 + + +@pytest.fixture() +def net() -> UniSRec: + torch.manual_seed(0) + pretrained = torch.randn(11, 32) + pretrained[0] = 0.0 + model = UniSRec( + n_items=10, + pretrained_embeddings=pretrained, + n_factors=16, + projection_hidden=32, + n_blocks=1, + n_heads=2, + session_max_len=8, + dropout=0.0, + adaptor_dropout=0.0, + ) + model.eval() + return model + + +def _export_and_load(net: torch.nn.Module, args, tmp_path: Path, **kwargs): + path = str(tmp_path / "model.onnx") + torch.onnx.export(net, args, path, opset_version=18, **kwargs) + model = onnx.load(path) + onnx.checker.check_model(model) + return ort.InferenceSession(path) + + +class TestUniSRecOnnxExport: + def test_export_succeeds(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + path = str(tmp_path / "model.onnx") + torch.onnx.export( + net, + (dummy,), + path, + input_names=["input_ids"], + output_names=["hidden"], + opset_version=18, + ) + model = onnx.load(path) + 
onnx.checker.check_model(model) + + def test_forward_roundtrip(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + sess = _export_and_load( + net, + (dummy,), + tmp_path, + input_names=["input_ids"], + output_names=["hidden"], + ) + with torch.no_grad(): + expected = net(dummy).numpy() + result = sess.run(None, {"input_ids": dummy.numpy()})[0] + np.testing.assert_allclose(result, expected, atol=1e-5) + + @pytest.mark.xfail(reason="dynamic_shapes requires dynamo=True which is not used here") + def test_dynamic_batch(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + batch = torch.export.Dim("batch", min=1) + sess = _export_and_load( + net, + (dummy,), + tmp_path, + input_names=["input_ids"], + output_names=["hidden"], + dynamic_shapes=({0: batch},), + ) + batch_input = torch.tensor( + [[0, 0, 1, 2, 3], [0, 1, 4, 5, 6], [0, 0, 0, 7, 8]], + dtype=torch.long, + ) + with torch.no_grad(): + expected = net(batch_input).numpy() + result = sess.run(None, {"input_ids": batch_input.numpy()})[0] + assert result.shape[0] == 3 + np.testing.assert_allclose(result, expected, atol=1e-5) + + @pytest.mark.xfail(reason="dynamic_shapes requires dynamo=True which is not used here") + def test_different_sequence_lengths(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + batch = torch.export.Dim("batch", min=1) + seq_len = torch.export.Dim("seq_len", min=1, max=8) + sess = _export_and_load( + net, + (dummy,), + tmp_path, + input_names=["input_ids"], + output_names=["hidden"], + dynamic_shapes=({0: batch, 1: seq_len},), + ) + short = torch.tensor([[0, 1, 2]], dtype=torch.long) + with torch.no_grad(): + expected = net(short).numpy() + result = sess.run(None, {"input_ids": short.numpy()})[0] + assert result.shape == (1, 3, 16) + np.testing.assert_allclose(result, expected, atol=1e-5) + + def test_padding_only_input(self, net: 
UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + sess = _export_and_load( + net, + (dummy,), + tmp_path, + input_names=["input_ids"], + output_names=["hidden"], + ) + all_pad = torch.zeros(1, 5, dtype=torch.long) + with torch.no_grad(): + expected = net(all_pad).numpy() + result = sess.run(None, {"input_ids": all_pad.numpy()})[0] + np.testing.assert_allclose(result, expected, atol=1e-5) + + def test_output_shape(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + sess = _export_and_load( + net, + (dummy,), + tmp_path, + input_names=["input_ids"], + output_names=["hidden"], + ) + result = sess.run(None, {"input_ids": dummy.numpy()})[0] + assert result.shape == (1, 5, 16) + + def test_project_all_roundtrip(self, net: UniSRec, tmp_path: Path) -> None: + class _ProjectAll(torch.nn.Module): + def __init__(self, inner: UniSRec): + super().__init__() + self.inner = inner + + def forward(self) -> torch.Tensor: + return self.inner.project_all() + + wrapper = _ProjectAll(net) + wrapper.eval() + path = str(tmp_path / "project_all.onnx") + torch.onnx.export( + wrapper, + (), + path, + input_names=[], + output_names=["item_embs"], + opset_version=18, + ) + model = onnx.load(path) + onnx.checker.check_model(model) + sess = ort.InferenceSession(path) + with torch.no_grad(): + expected = net.project_all().numpy() + result = sess.run(None, {})[0] + assert result.shape == (11, 16) + np.testing.assert_allclose(result, expected, atol=1e-5) + + +class TestUniSRecModelExport: + """Tests for UniSRecModel.export_to_onnx.""" + + @pytest.fixture() + def model(self) -> UniSRecModel: + torch.manual_seed(0) + pretrained = torch.randn(11, 32) + pretrained[0] = 0.0 + m = UniSRecModel( + pretrained_item_embeddings=pretrained, + n_factors=16, + projection_hidden=32, + n_blocks=1, + n_heads=2, + session_max_len=8, + epochs=0, + ) + from rectools.fast_transformers.preprocessing.sequence_data import 
align_embeddings + + unique_items = torch.arange(1, 11) + aligned = align_embeddings(pretrained, unique_items, 10) + net = UniSRec( + n_items=10, + pretrained_embeddings=aligned, + n_factors=16, + projection_hidden=32, + n_blocks=1, + n_heads=2, + session_max_len=8, + dropout=0.0, + adaptor_dropout=0.0, + ) + net.eval() + m._net = net + m._unique_items = unique_items + m._unique_users = torch.arange(5) + m.is_fitted = True + return m + + def test_export_encoder(self, model: UniSRecModel, tmp_path: Path) -> None: + path = tmp_path / "encoder.onnx" + model.export_to_onnx(str(path)) + loaded = onnx.load(str(path)) + onnx.checker.check_model(loaded) + + def test_export_encoder_roundtrip(self, model: UniSRecModel, tmp_path: Path) -> None: + path = tmp_path / "encoder.onnx" + model.export_to_onnx(str(path)) + sess = ort.InferenceSession(str(path)) + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + with torch.no_grad(): + expected = model.net(dummy).numpy() + result = sess.run(None, {"input_ids": dummy.numpy()})[0] + np.testing.assert_allclose(result, expected, atol=1e-5) + + def test_export_encoder_and_items(self, model: UniSRecModel, tmp_path: Path) -> None: + enc_path = tmp_path / "encoder.onnx" + items_path = tmp_path / "items.onnx" + model.export_to_onnx(str(enc_path), items_path=str(items_path)) + + loaded_enc = onnx.load(str(enc_path)) + onnx.checker.check_model(loaded_enc) + loaded_items = onnx.load(str(items_path)) + onnx.checker.check_model(loaded_items) + + def test_items_roundtrip(self, model: UniSRecModel, tmp_path: Path) -> None: + items_path = tmp_path / "items.onnx" + model.export_to_onnx(str(tmp_path / "enc.onnx"), items_path=str(items_path)) + sess = ort.InferenceSession(str(items_path)) + with torch.no_grad(): + expected = model.net.project_all().numpy() + result = sess.run(None, {})[0] + assert result.shape == (11, 16) + np.testing.assert_allclose(result, expected, atol=1e-5) + + def test_unfitted_model_raises(self, tmp_path: Path) -> None: 
+ pretrained = torch.randn(5, 8) + m = UniSRecModel(pretrained_item_embeddings=pretrained, n_factors=8) + with pytest.raises(AssertionError): + m.export_to_onnx(str(tmp_path / "model.onnx")) diff --git a/tests/fast_transformers/test_sequence_data.py b/tests/fast_transformers/test_sequence_data.py new file mode 100644 index 00000000..1fdd261f --- /dev/null +++ b/tests/fast_transformers/test_sequence_data.py @@ -0,0 +1,402 @@ +"""Tests for vectorized sequence building and data utilities.""" + +import torch + +from rectools.fast_transformers.preprocessing.sequence_data import ( + SequenceBatchDataset, + align_embeddings, + build_sequences, +) + +DEVICE = "cpu" + + +class TestBuildSequences: + """Tests for the build_sequences function.""" + + def test_basic_two_users(self) -> None: + """Two users with 3 interactions each, max_len=4.""" + user_ids = torch.tensor([0, 0, 0, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + x, y, unique_items, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + assert x.shape == (2, 4) + assert y.shape == (2, 4) + + # Items are mapped to internal 1-based IDs; 0 = padding + # unique_items is sorted, so: [10, 20, 30, 40, 50, 60] + # internal IDs: 10->1, 20->2, 30->3, 40->4, 50->5, 60->6 + + # User 0: items [10, 20, 30] in order => internal [1, 2, 3] + # x = [0, 1, 2] left-padded to len 4 => [0, 0, 1, 2] + # y = [0, 2, 3] left-padded to len 4 => [0, 0, 2, 3] + assert x[0].tolist() == [0, 0, 1, 2] + assert y[0].tolist() == [0, 0, 2, 3] + + # User 1: items [40, 50, 60] in order => internal [4, 5, 6] + # x = [0, 4, 5] => [0, 0, 4, 5] + # y = [0, 5, 6] => [0, 0, 5, 6] + assert x[1].tolist() == [0, 0, 4, 5] + assert y[1].tolist() == [0, 0, 5, 6] + + assert result_users.tolist() == [0, 1] + + def test_unique_items_mapping(self) -> None: + """unique_items should map internal_id - 1 => external_id.""" + user_ids = 
torch.tensor([0, 0, 0]) + item_ids = torch.tensor([100, 50, 200]) + timestamps = torch.tensor([1, 2, 3]) + + _, _, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE + ) + + # torch.unique sorts, so unique_items = [50, 100, 200] + assert unique_items.tolist() == [50, 100, 200] + + def test_min_interactions_filtering(self) -> None: + """Users with fewer than min_interactions should be dropped.""" + user_ids = torch.tensor([0, 0, 0, 1, 2, 2]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + x, y, _, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + # User 1 has only 1 interaction => dropped + assert x.shape[0] == 2 + assert result_users.tolist() == [0, 2] + + def test_min_interactions_higher_threshold(self) -> None: + """Higher min_interactions threshold filters more aggressively.""" + user_ids = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80, 90]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) + + x, y, _, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=3, device=DEVICE + ) + + # User 0 has 3, User 1 has 2 (dropped), User 2 has 4 + assert x.shape[0] == 2 + assert result_users.tolist() == [0, 2] + + def test_all_users_filtered_out(self) -> None: + """When all users have fewer than min_interactions, return empty tensors.""" + user_ids = torch.tensor([0, 1, 2]) + item_ids = torch.tensor([10, 20, 30]) + timestamps = torch.tensor([1, 2, 3]) + + x, y, _, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + assert x.shape == (0, 4) + assert y.shape == (0, 4) + assert len(result_users) == 0 + + def test_max_len_truncation(self) -> None: + """Sequences longer than max_len should be truncated, keeping the most recent 
items.""" + user_ids = torch.tensor([0, 0, 0, 0, 0]) + item_ids = torch.tensor([10, 20, 30, 40, 50]) + timestamps = torch.tensor([1, 2, 3, 4, 5]) + + x, y, _, _ = build_sequences(user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE) + + # 5 items total. capped_lens = min(5, 3+1) = 4, effective = 3 + # Sorted items: 10->1, 20->2, 30->3, 40->4, 50->5 + # last 4 items for x/y windowing: items at positions [1..4] + # x takes [1,2,3] => internal [2,3,4]; y takes [2,3,4] => internal [3,4,5] + assert x.shape == (1, 3) + assert y.shape == (1, 3) + assert x[0].tolist() == [2, 3, 4] + assert y[0].tolist() == [3, 4, 5] + + def test_timestamp_ordering(self) -> None: + """Items should be ordered by timestamp regardless of input order.""" + user_ids = torch.tensor([0, 0, 0]) + item_ids = torch.tensor([30, 10, 20]) + timestamps = torch.tensor([3, 1, 2]) + + x, y, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + # unique_items (sorted by value): [10, 20, 30] => internal 1, 2, 3 + # By timestamp: 10(t=1), 20(t=2), 30(t=3) => internal [1, 2, 3] + # x = [0, 0, 1, 2] + # y = [0, 0, 2, 3] + assert unique_items.tolist() == [10, 20, 30] + assert x[0].tolist() == [0, 0, 1, 2] + assert y[0].tolist() == [0, 0, 2, 3] + + def test_left_padding(self) -> None: + """Sequences shorter than max_len should be left-padded with zeros.""" + user_ids = torch.tensor([0, 0]) + item_ids = torch.tensor([10, 20]) + timestamps = torch.tensor([1, 2]) + + x, y, _, _ = build_sequences(user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE) + + # 2 items => effective_len = 1 (capped_lens = 2, effective = 1) + # x = [0, 0, 0, 0, 1], y = [0, 0, 0, 0, 2] + assert x[0].tolist() == [0, 0, 0, 0, 1] + assert y[0].tolist() == [0, 0, 0, 0, 2] + + def test_result_users_preserves_external_ids(self) -> None: + """result_users should contain external user IDs, not internal indices.""" + user_ids = 
torch.tensor([100, 100, 100, 200, 200, 200]) + item_ids = torch.tensor([1, 2, 3, 4, 5, 6]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + _, _, _, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + assert result_users.tolist() == [100, 200] + + def test_shared_items_across_users(self) -> None: + """Same items used by different users should share internal IDs.""" + user_ids = torch.tensor([0, 0, 0, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 20, 30, 40]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + x, y, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + # unique_items: [10, 20, 30, 40] => internal 1, 2, 3, 4 + assert unique_items.tolist() == [10, 20, 30, 40] + + # User 0: 10(1), 20(2), 30(3) => x=[0, 1, 2], y=[0, 2, 3] + assert x[0].tolist() == [0, 0, 1, 2] + assert y[0].tolist() == [0, 0, 2, 3] + + # User 1: 20(2), 30(3), 40(4) => x=[0, 2, 3], y=[0, 3, 4] + assert x[1].tolist() == [0, 0, 2, 3] + assert y[1].tolist() == [0, 0, 3, 4] + + def test_output_device(self) -> None: + """All output tensors should be on the specified device.""" + user_ids = torch.tensor([0, 0]) + item_ids = torch.tensor([1, 2]) + timestamps = torch.tensor([1, 2]) + + x, y, unique_items, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE + ) + + assert x.device.type == DEVICE + assert y.device.type == DEVICE + assert unique_items.device.type == DEVICE + assert result_users.device.type == DEVICE + + def test_output_dtypes(self) -> None: + """x and y should be long tensors.""" + user_ids = torch.tensor([0, 0]) + item_ids = torch.tensor([1, 2]) + timestamps = torch.tensor([1, 2]) + + x, y, _, _ = build_sequences(user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE) + + assert x.dtype == torch.long + assert y.dtype == torch.long + + def 
test_exact_max_len_sequence(self) -> None: + """Sequence with exactly max_len + 1 items should fill entire x and y.""" + user_ids = torch.tensor([0, 0, 0, 0]) + item_ids = torch.tensor([10, 20, 30, 40]) + timestamps = torch.tensor([1, 2, 3, 4]) + + x, y, _, _ = build_sequences(user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE) + + # 4 items, max_len=3 => capped_lens = min(4, 4) = 4, effective = 3 + # No padding needed + assert 0 not in x[0].tolist() + assert 0 not in y[0].tolist() + + def test_multiple_users_different_lengths(self) -> None: + """Users with different sequence lengths should be properly handled.""" + user_ids = torch.tensor([0, 0, 1, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + x, y, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE + ) + + # unique_items: [10, 20, 30, 40, 50, 60] => internal 1..6 + # User 0: 2 items => effective=1 + # x[0] = [0, 0, 0, 0, 1], y[0] = [0, 0, 0, 0, 2] + assert x[0].tolist() == [0, 0, 0, 0, 1] + assert y[0].tolist() == [0, 0, 0, 0, 2] + + # User 1: 4 items => effective=3 + # x[1] = [0, 0, 3, 4, 5], y[1] = [0, 0, 4, 5, 6] + assert x[1].tolist() == [0, 0, 3, 4, 5] + assert y[1].tolist() == [0, 0, 4, 5, 6] + + +class TestAlignEmbeddings: + """Tests for the align_embeddings function.""" + + def test_2d_pretrained(self) -> None: + """Align 2D pretrained embeddings to internal ID order.""" + pretrained = torch.tensor( + [ + [1.0, 2.0], # external item 0 + [3.0, 4.0], # external item 1 + [5.0, 6.0], # external item 2 + [7.0, 8.0], # external item 3 + ] + ) + # unique_items: external IDs that map to internal IDs 1, 2, 3 + unique_items = torch.tensor([2, 0, 3]) + n_items = 3 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (4, 2) # n_items + 1 + # Row 0 (padding) should be zeros + assert aligned[0].tolist() == [0.0, 0.0] + # 
Internal ID 1 => external ID 2 => pretrained[2] = [5, 6] + assert aligned[1].tolist() == [5.0, 6.0] + # Internal ID 2 => external ID 0 => pretrained[0] = [1, 2] + assert aligned[2].tolist() == [1.0, 2.0] + # Internal ID 3 => external ID 3 => pretrained[3] = [7, 8] + assert aligned[3].tolist() == [7.0, 8.0] + + def test_3d_pretrained(self) -> None: + """Align 3D pretrained embeddings (multi-variant).""" + pretrained = torch.tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], # item 0, 2 variants + [[5.0, 6.0], [7.0, 8.0]], # item 1 + ] + ) + unique_items = torch.tensor([1, 0]) + n_items = 2 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (3, 2, 2) # (n_items+1, n_variants, dim) + # Row 0 (padding) should be zeros + torch.testing.assert_close(aligned[0], torch.zeros(2, 2)) + # Internal ID 1 => external ID 1 + torch.testing.assert_close(aligned[1], pretrained[1]) + # Internal ID 2 => external ID 0 + torch.testing.assert_close(aligned[2], pretrained[0]) + + def test_padding_row_is_zero(self) -> None: + """The first row (padding, internal ID 0) should always be zeros.""" + pretrained = torch.randn(10, 8) + unique_items = torch.tensor([0, 1, 2]) + n_items = 3 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + torch.testing.assert_close(aligned[0], torch.zeros(8)) + + def test_out_of_range_indices(self) -> None: + """Items with external IDs outside pretrained range should get zero embeddings.""" + pretrained = torch.tensor( + [ + [1.0, 2.0], # external 0 + [3.0, 4.0], # external 1 + ] + ) + # External ID 5 is out of range (pretrained has only 2 rows) + unique_items = torch.tensor([0, 5, 1]) + n_items = 3 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (4, 2) + # Internal 1 => external 0 => valid + assert aligned[1].tolist() == [1.0, 2.0] + # Internal 2 => external 5 => out of range => zeros + assert aligned[2].tolist() == [0.0, 0.0] + # Internal 3 => external 1 => valid + 
assert aligned[3].tolist() == [3.0, 4.0] + + def test_negative_indices_handled(self) -> None: + """Negative external IDs should be treated as invalid and get zeros.""" + pretrained = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + unique_items = torch.tensor([-1, 0]) + n_items = 2 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (3, 2) + # Internal 1 => external -1 => invalid => zeros + assert aligned[1].tolist() == [0.0, 0.0] + # Internal 2 => external 0 => valid + assert aligned[2].tolist() == [1.0, 2.0] + + def test_output_shape_matches_n_items_plus_one(self) -> None: + """Output shape should be (n_items + 1, D) regardless of unique_items length.""" + pretrained = torch.randn(20, 4) + unique_items = torch.tensor([3, 7, 15]) + n_items = 3 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (4, 4) + + +class TestSequenceBatchDataset: + """Tests for SequenceBatchDataset.""" + + def test_length(self) -> None: + x = torch.zeros(5, 3) + y = torch.zeros(5, 3) + ds = SequenceBatchDataset(x, y) + assert len(ds) == 5 + + def test_getitem_returns_dict(self) -> None: + x = torch.tensor([[1, 2, 3], [4, 5, 6]]) + y = torch.tensor([[7, 8, 9], [10, 11, 12]]) + ds = SequenceBatchDataset(x, y) + + batch = ds[0] + assert isinstance(batch, dict) + assert "x" in batch + assert "y" in batch + assert batch["x"].tolist() == [1, 2, 3] + assert batch["y"].tolist() == [7, 8, 9] + + def test_getitem_second_element(self) -> None: + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5, 6], [7, 8]]) + ds = SequenceBatchDataset(x, y) + + batch = ds[1] + assert batch["x"].tolist() == [3, 4] + assert batch["y"].tolist() == [7, 8] + + def test_transform_applied(self) -> None: + x = torch.tensor([[1, 2]]) + y = torch.tensor([[3, 4]]) + + def double_x(batch: dict) -> dict: + batch["x"] = batch["x"] * 2 + return batch + + ds = SequenceBatchDataset(x, y, transform=double_x) + batch = ds[0] + assert 
batch["x"].tolist() == [2, 4] + assert batch["y"].tolist() == [3, 4] + + def test_no_transform(self) -> None: + x = torch.tensor([[10, 20]]) + y = torch.tensor([[30, 40]]) + ds = SequenceBatchDataset(x, y, transform=None) + + batch = ds[0] + assert batch["x"].tolist() == [10, 20] + assert batch["y"].tolist() == [30, 40] diff --git a/tests/fast_transformers/test_unisrec_lightning.py b/tests/fast_transformers/test_unisrec_lightning.py new file mode 100644 index 00000000..3fc81237 --- /dev/null +++ b/tests/fast_transformers/test_unisrec_lightning.py @@ -0,0 +1,466 @@ +"""Tests for UniSRecLightning wrapper and _cosine_warmup_scheduler.""" + +import math +import typing as tp + +import pytest +import torch + +from rectools.fast_transformers.unisrec.lightning import ( + SUPPORTED_LOSSES, + SUPPORTED_OPTIMIZERS, + SUPPORTED_SCHEDULERS, + UniSRecLightning, + _cosine_warmup_scheduler, +) +from rectools.fast_transformers.unisrec.net import UniSRec + + +@pytest.fixture() +def pretrained_emb() -> torch.Tensor: + """Fake pretrained embeddings: (11, 32) -- 10 items + 1 padding.""" + torch.manual_seed(0) + emb = torch.randn(11, 32) + emb[0] = 0.0 # padding + return emb + + +@pytest.fixture() +def net(pretrained_emb: torch.Tensor) -> UniSRec: + return UniSRec( + n_items=10, + pretrained_embeddings=pretrained_emb, + n_factors=8, + projection_hidden=16, + n_blocks=1, + n_heads=1, + session_max_len=5, + dropout=0.0, + adaptor_dropout=0.0, + ) + + +def _make_module( + net: UniSRec, + loss: str = "softmax", + n_negatives: tp.Optional[int] = None, + optimizer: str = "adamw", + scheduler: tp.Optional[str] = None, + total_steps: tp.Optional[int] = None, + lr: float = 1e-3, + warmup_ratio: float = 0.05, + min_lr_ratio: float = 0.1, + gbce_t: float = 0.2, +) -> UniSRecLightning: + """Build a UniSRecLightning with a single param group.""" + param_groups = [{"params": list(net.parameters()), "lr": lr}] + return UniSRecLightning( + net=net, + param_groups=param_groups, + loss=loss, + 
n_negatives=n_negatives, + gbce_t=gbce_t, + optimizer=optimizer, + scheduler=scheduler, + warmup_ratio=warmup_ratio, + min_lr_ratio=min_lr_ratio, + total_steps=total_steps, + ) + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + + +class TestConstants: + def test_supported_losses(self) -> None: + assert SUPPORTED_LOSSES == ("softmax", "BCE", "gBCE", "sampled_softmax") + + def test_supported_optimizers(self) -> None: + assert SUPPORTED_OPTIMIZERS == ("adam", "adamw") + + def test_supported_schedulers(self) -> None: + assert SUPPORTED_SCHEDULERS == (None, "cosine_warmup") + + +# --------------------------------------------------------------------------- +# configure_optimizers +# --------------------------------------------------------------------------- + + +class TestConfigureOptimizers: + def test_adam_returns_adam(self, net: UniSRec) -> None: + module = _make_module(net, optimizer="adam") + result = module.configure_optimizers() + assert isinstance(result, torch.optim.Adam) + + def test_adamw_returns_adamw(self, net: UniSRec) -> None: + module = _make_module(net, optimizer="adamw") + result = module.configure_optimizers() + assert isinstance(result, torch.optim.AdamW) + + def test_no_scheduler_returns_optimizer_only(self, net: UniSRec) -> None: + module = _make_module(net, scheduler=None) + result = module.configure_optimizers() + # When scheduler is None, returns just the optimizer (not a dict) + assert isinstance(result, torch.optim.Optimizer) + + def test_cosine_warmup_returns_dict(self, net: UniSRec) -> None: + module = _make_module(net, scheduler="cosine_warmup", total_steps=100) + result = module.configure_optimizers() + assert isinstance(result, dict) + assert "optimizer" in result + assert "lr_scheduler" in result + assert result["lr_scheduler"]["interval"] == "step" + + def test_unknown_optimizer_raises(self, net: UniSRec) -> None: 
+ module = _make_module(net, optimizer="sgd") + with pytest.raises(ValueError, match="Unknown optimizer"): + module.configure_optimizers() + + def test_unknown_scheduler_raises(self, net: UniSRec) -> None: + module = _make_module(net, scheduler="step_lr") + with pytest.raises(ValueError, match="Unknown scheduler"): + module.configure_optimizers() + + def test_cosine_warmup_total_steps_default(self, net: UniSRec) -> None: + """When total_steps is None, it defaults to 1.""" + module = _make_module(net, scheduler="cosine_warmup", total_steps=None) + result = module.configure_optimizers() + assert isinstance(result, dict) + + def test_optimizer_lr(self, net: UniSRec) -> None: + lr = 5e-4 + module = _make_module(net, optimizer="adam", lr=lr) + opt = module.configure_optimizers() + assert opt.param_groups[0]["lr"] == lr + + +# --------------------------------------------------------------------------- +# _cosine_warmup_scheduler +# --------------------------------------------------------------------------- + + +class TestCosineWarmupScheduler: + def test_lr_at_step_zero_is_zero(self) -> None: + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler(opt, warmup_steps=10, total_steps=100, min_lr_ratio=0.0) + # LambdaLR stores the lambda; get factor for step 0 + lr_factor = scheduler.lr_lambdas[0](0) + assert lr_factor == 0.0 + + def test_lr_during_warmup_is_linear(self) -> None: + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + warmup_steps = 10 + scheduler = _cosine_warmup_scheduler(opt, warmup_steps=warmup_steps, total_steps=100) + lr_fn = scheduler.lr_lambdas[0] + for step in range(1, warmup_steps): + assert lr_fn(step) == pytest.approx(step / warmup_steps) + + def test_lr_at_warmup_end_is_one(self) -> None: + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler(opt, warmup_steps=10, total_steps=100) + lr_fn = scheduler.lr_lambdas[0] + # At 
warmup_steps, progress = 0, cos(0) = 1 => factor = 1.0 + assert lr_fn(10) == pytest.approx(1.0) + + def test_lr_at_end_equals_min_lr_ratio(self) -> None: + min_lr_ratio = 0.1 + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler( + opt, + warmup_steps=10, + total_steps=100, + min_lr_ratio=min_lr_ratio, + ) + lr_fn = scheduler.lr_lambdas[0] + # At total_steps, progress = 1, cos(pi) = -1 => factor = min_lr_ratio + assert lr_fn(100) == pytest.approx(min_lr_ratio) + + def test_lr_at_cosine_midpoint(self) -> None: + """At the midpoint of the cosine phase, factor should be (1 + min_lr_ratio) / 2.""" + warmup_steps = 10 + total_steps = 110 + min_lr_ratio = 0.0 + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler( + opt, + warmup_steps=warmup_steps, + total_steps=total_steps, + min_lr_ratio=min_lr_ratio, + ) + lr_fn = scheduler.lr_lambdas[0] + midpoint = warmup_steps + (total_steps - warmup_steps) // 2 # 60 + # progress = 0.5 => cos(pi/2) = 0 => factor = 0.5 + expected = min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * 0.5)) + assert lr_fn(midpoint) == pytest.approx(expected, abs=1e-6) + + def test_lr_with_nonzero_min_lr_ratio(self) -> None: + min_lr_ratio = 0.3 + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler( + opt, + warmup_steps=0, + total_steps=100, + min_lr_ratio=min_lr_ratio, + ) + lr_fn = scheduler.lr_lambdas[0] + # At step 0 (warmup_steps=0, so cosine phase), progress=0, cos(0)=1 => factor=1.0 + assert lr_fn(0) == pytest.approx(1.0) + # At total_steps => factor = min_lr_ratio + assert lr_fn(100) == pytest.approx(min_lr_ratio) + + def test_returns_lambda_lr(self) -> None: + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler(opt, warmup_steps=5, total_steps=50) + assert isinstance(scheduler, 
torch.optim.lr_scheduler.LambdaLR) + + +# --------------------------------------------------------------------------- +# training_step +# --------------------------------------------------------------------------- + + +class TestTrainingStep: + def test_softmax_returns_scalar(self, net: UniSRec) -> None: + module = _make_module(net, loss="softmax") + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0, "Loss should be a scalar" + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isinf(loss), "Loss should not be Inf" + + def test_softmax_positive_loss(self, net: UniSRec) -> None: + module = _make_module(net, loss="softmax") + batch = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.item() > 0, "Cross-entropy loss should be positive" + + def test_bce_loss_returns_scalar(self, net: UniSRec) -> None: + n_negatives = 3 + module = _make_module(net, loss="BCE", n_negatives=n_negatives) + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + "negatives": torch.randint(1, 10, (2, 5, n_negatives)), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def test_gbce_loss_returns_scalar(self, net: UniSRec) -> None: + n_negatives = 3 + module = _make_module(net, loss="gBCE", n_negatives=n_negatives) + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + "negatives": torch.randint(1, 10, (2, 5, n_negatives)), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def 
test_sampled_softmax_loss_returns_scalar(self, net: UniSRec) -> None: + n_negatives = 3 + module = _make_module(net, loss="sampled_softmax", n_negatives=n_negatives) + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + "negatives": torch.randint(1, 10, (2, 5, n_negatives)), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def test_softmax_ignores_negatives_when_present(self, net: UniSRec) -> None: + """Softmax loss uses full softmax even when negatives are provided.""" + module_no_neg = _make_module(net, loss="softmax") + module_with_neg = _make_module(net, loss="softmax") + net.eval() + + batch_no_neg = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + } + batch_with_neg = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + "negatives": torch.randint(1, 10, (1, 5, 3)), + } + with torch.no_grad(): + loss_no_neg = module_no_neg.training_step(batch_no_neg, batch_idx=0) + loss_with_neg = module_with_neg.training_step(batch_with_neg, batch_idx=0) + torch.testing.assert_close(loss_no_neg, loss_with_neg) + + def test_all_padding_softmax(self, net: UniSRec) -> None: + """When all targets are padding, cross_entropy with ignore_index returns NaN.""" + module = _make_module(net, loss="softmax") + batch = { + "x": torch.tensor([[0, 0, 0, 0, 0]]), + "y": torch.tensor([[0, 0, 0, 0, 0]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert torch.isnan(loss) + + +# --------------------------------------------------------------------------- +# validation_step +# --------------------------------------------------------------------------- + + +class TestValidationStep: + def test_validation_returns_scalar(self, net: UniSRec) -> None: + module = _make_module(net, loss="softmax") + module.eval() + batch = { + "x": 
torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[4], [8]]), # (B, 1) + } + with torch.no_grad(): + loss = module.validation_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def test_validation_uses_last_hidden(self, net: UniSRec) -> None: + """Validation slices hidden to [:, -1:, :], so y shape (B, 1) works.""" + module = _make_module(net, loss="softmax") + module.eval() + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3]]), + "y": torch.tensor([[4]]), # single target per sequence + } + with torch.no_grad(): + loss = module.validation_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + + def test_validation_with_negatives(self, net: UniSRec) -> None: + n_negatives = 3 + module = _make_module(net, loss="BCE", n_negatives=n_negatives) + module.eval() + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[4], [8]]), + "negatives": torch.randint(1, 10, (2, 1, n_negatives)), + } + with torch.no_grad(): + loss = module.validation_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + + +# --------------------------------------------------------------------------- +# _calc_loss dispatch +# --------------------------------------------------------------------------- + + +class TestCalcLossDispatch: + def test_softmax_without_negatives_uses_full_softmax(self, net: UniSRec) -> None: + module = _make_module(net, loss="softmax") + hidden = torch.randn(2, 5, 8) + batch = { + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + loss = module._calc_loss(hidden, batch) + assert loss.dim() == 0 + assert not torch.isnan(loss) + + def test_bce_without_negatives_raises(self, net: UniSRec) -> None: + module = _make_module(net, loss="BCE") + hidden = torch.randn(2, 5, 8) + batch = { + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + with pytest.raises(ValueError, match="requires 
negatives"): + module._calc_loss(hidden, batch) + + def test_gbce_without_negatives_raises(self, net: UniSRec) -> None: + module = _make_module(net, loss="gBCE") + hidden = torch.randn(2, 5, 8) + batch = {"y": torch.tensor([[1, 2, 3, 4, 5]])} + with pytest.raises(ValueError, match="requires negatives"): + module._calc_loss(hidden, batch) + + def test_sampled_softmax_without_negatives_raises(self, net: UniSRec) -> None: + module = _make_module(net, loss="sampled_softmax") + hidden = torch.randn(1, 5, 8) + batch = {"y": torch.tensor([[1, 2, 3, 4, 5]])} + with pytest.raises(ValueError, match="requires negatives"): + module._calc_loss(hidden, batch) + + def test_unknown_loss_raises(self, net: UniSRec) -> None: + module = _make_module(net, loss="mse") + hidden = torch.randn(1, 5, 8) + batch = { + "y": torch.tensor([[1, 2, 3, 4, 5]]), + "negatives": torch.randint(1, 10, (1, 5, 3)), + } + with pytest.raises(ValueError, match="Unknown loss"): + module._calc_loss(hidden, batch) + + +# --------------------------------------------------------------------------- +# _get_item_embs / _get_all_embs +# --------------------------------------------------------------------------- + + +class TestEmbeddingHelpers: + def test_get_item_embs(self, net: UniSRec) -> None: + module = _make_module(net) + item_ids = torch.tensor([[1, 2, 3]]) + embs = module._get_item_embs(item_ids) + assert embs.shape == (1, 3, 8) # (B, L, n_factors) + + def test_get_all_embs(self, net: UniSRec) -> None: + module = _make_module(net) + all_embs = module._get_all_embs() + assert all_embs.shape == (11, 8) # n_items + 1 + + def test_get_pos_neg_logits_shape(self, net: UniSRec) -> None: + module = _make_module(net) + hidden = torch.randn(2, 5, 8) + labels = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + negatives = torch.randint(1, 10, (2, 5, 3)) + logits = module._get_pos_neg_logits(hidden, labels, negatives) + assert logits.shape == (2, 5, 4) # 1 positive + 3 negatives + + +# 
--------------------------------------------------------------------------- +# Init stores params +# --------------------------------------------------------------------------- + + +class TestInit: + def test_stores_all_attributes(self, net: UniSRec) -> None: + module = _make_module( + net, + loss="BCE", + n_negatives=5, + optimizer="adam", + scheduler="cosine_warmup", + total_steps=200, + warmup_ratio=0.1, + min_lr_ratio=0.05, + gbce_t=0.3, + ) + assert module.loss_name == "BCE" + assert module.n_negatives == 5 + assert module.optimizer_name == "adam" + assert module.scheduler_name == "cosine_warmup" + assert module.total_steps == 200 + assert module.warmup_ratio == 0.1 + assert module.min_lr_ratio == 0.05 + assert module.gbce_t == 0.3 + assert module.net is net diff --git a/tests/fast_transformers/test_unisrec_model.py b/tests/fast_transformers/test_unisrec_model.py new file mode 100644 index 00000000..3335921e --- /dev/null +++ b/tests/fast_transformers/test_unisrec_model.py @@ -0,0 +1,231 @@ +"""Tests for UniSRecModel (standalone, tensor-based API).""" + +import pytest +import torch + +from rectools.fast_transformers import UniSRecModel + + +def _make_embeddings(n_items: int = 25, dim: int = 64) -> torch.Tensor: + torch.manual_seed(0) + emb = torch.randn(n_items, dim) + emb[0] = 0.0 + return emb + + +def _make_interactions(n_users: int = 20, n_items: int = 25, seed: int = 42): + """Generate synthetic (user_ids, item_ids, timestamps) tensors.""" + rng = torch.Generator().manual_seed(seed) + users, items, timestamps = [], [], [] + for u in range(n_users): + n_inter = torch.randint(3, 8, (1,), generator=rng).item() + item_pool = torch.randperm(n_items, generator=rng)[:n_inter] + 1 # 1-based + for rank, item in enumerate(item_pool): + users.append(u) + items.append(item.item()) + timestamps.append(rank) + return ( + torch.tensor(users, dtype=torch.long), + torch.tensor(items, dtype=torch.long), + torch.tensor(timestamps, dtype=torch.long), + ) + + +def 
_make_model(**kwargs) -> UniSRecModel: + defaults = dict( + pretrained_item_embeddings=_make_embeddings(), + n_factors=16, + projection_hidden=32, + n_blocks=1, + n_heads=2, + session_max_len=8, + epochs=1, + batch_size=16, + verbose=0, + ) + defaults.update(kwargs) + return UniSRecModel(**defaults) + + +class TestFit: + def test_fit_returns_self(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model() + result = model.fit(user_ids, item_ids, timestamps) + assert result is model + + def test_is_fitted_after_fit(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model() + assert not model.is_fitted + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_net_accessible_after_fit(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model() + model.fit(user_ids, item_ids, timestamps) + net = model.net + assert net is not None + + def test_item_id_mapping_has_original_ids(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model() + model.fit(user_ids, item_ids, timestamps) + mapping = model.item_id_mapping + original_unique = torch.unique(item_ids) + assert set(mapping.tolist()) == set(original_unique.tolist()) + + def test_net_not_accessible_before_fit(self) -> None: + model = _make_model() + with pytest.raises(AssertionError): + _ = model.net + + +class TestLosses: + def test_softmax_loss(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(loss="softmax", epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_bce_loss(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(loss="BCE", n_negatives=3, epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_gbce_loss(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = 
_make_model(loss="gBCE", n_negatives=3, epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_sampled_softmax_loss(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(loss="sampled_softmax", n_negatives=3, epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_bce_loss_with_patience(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(loss="BCE", n_negatives=3, patience=2, epochs=3) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_gbce_loss_with_patience(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(loss="gBCE", n_negatives=3, patience=2, epochs=3) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_sampled_softmax_loss_with_patience(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(loss="sampled_softmax", n_negatives=3, patience=2, epochs=3) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_invalid_loss_raises(self) -> None: + with pytest.raises(ValueError, match="Unsupported loss"): + _make_model(loss="invalid") + + def test_n_negatives_zero_raises(self) -> None: + with pytest.raises(ValueError, match="positive integer"): + _make_model(loss="BCE", n_negatives=0) + + def test_n_negatives_negative_raises(self) -> None: + with pytest.raises(ValueError, match="positive integer"): + _make_model(loss="BCE", n_negatives=-1) + + def test_n_negatives_none_for_bce_raises(self) -> None: + with pytest.raises(ValueError, match="positive integer"): + _make_model(loss="BCE", n_negatives=None) + + +class TestOptimizer: + def test_adam(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(optimizer="adam", epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def 
test_adamw(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(optimizer="adamw", epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_invalid_optimizer_raises(self) -> None: + with pytest.raises(ValueError, match="Unsupported optimizer"): + _make_model(optimizer="sgd") + + +class TestScheduler: + def test_cosine_warmup(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(scheduler="cosine_warmup", warmup_ratio=0.1, epochs=2) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_invalid_scheduler_raises(self) -> None: + with pytest.raises(ValueError, match="Unsupported scheduler"): + _make_model(scheduler="step") + + +class TestCheckpoint: + def test_save_load_roundtrip(self, tmp_path) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(epochs=1) + model.fit(user_ids, item_ids, timestamps) + + ckpt_path = tmp_path / "model.pt" + model.save_checkpoint(ckpt_path) + + model2 = _make_model(epochs=1) + model2.load_checkpoint(ckpt_path, device="cpu") + assert model2.is_fitted + + mapping1 = model.item_id_mapping + mapping2 = model2.item_id_mapping + assert torch.equal(mapping1, mapping2) + + +class TestFFNTypes: + @pytest.mark.parametrize("ffn_type", ["conv1d", "linear_gelu", "linear_relu"]) + def test_ffn_type(self, ffn_type: str) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(ffn_type=ffn_type, ffn_expansion=2, epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + +class TestEarlyStopping: + def test_patience(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(patience=2, epochs=5) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + +class TestMapItemIds: + def test_dense_known_items(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model 
= _make_model(epochs=1) + model.fit(user_ids, item_ids, timestamps) + unique = model.item_id_mapping + result = model.map_item_ids(unique) + expected = torch.arange(1, len(unique) + 1, dtype=torch.long) + assert result.tolist() == expected.tolist() + + def test_dense_unknown_items(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(epochs=1) + model.fit(user_ids, item_ids, timestamps) + unknown = torch.tensor([9999, 8888], dtype=torch.long) + result = model.map_item_ids(unknown) + assert result.tolist() == [0, 0] + + def test_unfitted_raises(self) -> None: + model = _make_model() + with pytest.raises(AssertionError): + model.map_item_ids(torch.tensor([1, 2])) diff --git a/tests/fast_transformers/test_unisrec_net.py b/tests/fast_transformers/test_unisrec_net.py new file mode 100644 index 00000000..1825bb0e --- /dev/null +++ b/tests/fast_transformers/test_unisrec_net.py @@ -0,0 +1,93 @@ +"""Tests for UniSRec network.""" + +import pytest +import torch + +from rectools.fast_transformers.unisrec.net import UniSRec + + +@pytest.fixture() +def pretrained_emb() -> torch.Tensor: + """Fake pretrained embeddings: (31, 64) — 30 items + 1 padding.""" + torch.manual_seed(0) + emb = torch.randn(31, 64) + emb[0] = 0.0 # padding + return emb + + +@pytest.fixture() +def net(pretrained_emb: torch.Tensor) -> UniSRec: + return UniSRec( + n_items=30, + pretrained_embeddings=pretrained_emb, + n_factors=16, + projection_hidden=32, + n_blocks=1, + n_heads=2, + session_max_len=8, + dropout=0.0, + adaptor_dropout=0.0, + ) + + +class TestUniSRecShapes: + def test_forward_shape(self, net: UniSRec) -> None: + x = torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]) + h = net(x) + assert h.shape == (2, 5, 16) + + def test_encode_last_shape(self, net: UniSRec) -> None: + x = torch.tensor([[0, 0, 1, 2, 3]]) + emb = net.encode_last(x) + assert emb.shape == (1, 16) + + def test_project_all_shape(self, net: UniSRec) -> None: + proj = net.project_all() + assert 
proj.shape == (31, 16) # n_items + 1 (with padding) + + +class TestUniSRecAdaptor: + def test_pca_no_ffn(self, pretrained_emb: torch.Tensor) -> None: + net = UniSRec( + n_items=30, + pretrained_embeddings=pretrained_emb, + n_factors=16, + n_blocks=1, + n_heads=2, + session_max_len=8, + adaptor_type="pca", + use_adaptor_ffn=False, + ) + proj = net.project_all() + assert proj.shape == (31, 16) + assert net.head is None + + def test_multi_variant(self) -> None: + torch.manual_seed(0) + emb = torch.randn(31, 3, 64) # 3 variants + emb[0] = 0.0 + net = UniSRec( + n_items=30, + pretrained_embeddings=emb, + n_factors=16, + projection_hidden=32, + n_blocks=1, + n_heads=2, + session_max_len=8, + ) + assert net.n_variants == 3 + x = torch.tensor([[0, 0, 1, 2, 3]]) + h = net(x) + assert h.shape == (1, 5, 16) + + +class TestPaddingInvariance: + def test_determinism_and_padding_masking(self, net: UniSRec) -> None: + """Same input produces identical output; padding positions are zeroed.""" + net.eval() + x_a = torch.tensor([[0, 0, 0, 5, 10]]) + x_b = torch.tensor([[0, 0, 0, 5, 10]]) + with torch.no_grad(): + e_a = net.encode_last(x_a) + e_b = net.encode_last(x_b) + torch.testing.assert_close(e_a, e_b)