Реализация модели SASRec — Self-Attentive Sequential Recommendation — на PyTorch с использованием PyTorch Lightning, Hydra, DVC, ONNX и MLflow.
- Клонируйте репозиторий:
git clone git@github.com:Sapr7/PyTorch-Implementation-for-SASRec.git
cd PyTorch-Implementation-for-SASRec- Создайте виртуальное окружение и установите зависимости через uv:
uv venv
uv pip install
⚠️ Требуется Python ≥ 3.13 и установленныйuvУстановить:pip install uv
- Загрузите данные с помощью DVC:
dvc pullЗапуск обучения:
python -m scripts.train📁 Используется конфигурация configs/train.yaml, которая ссылается на:
- общие параметры в
configs/config.yaml - модель SASRec
- данные:
data/ml-1m.txt - директория чекпойнтов:
checkpoints/ml-1m_default/
🛠 Примеры параметров:
batch_size: 128
lr: 0.001
maxlen: 50
num_epochs: 20Все метрики и параметры обучения логируются в директорию mlruns/.
Для запуска веб-интерфейса MLflow:
mlflow ui --backend-store-uri mlruns --port 5000Затем откройте в браузере:
http://localhost:5000
💡 Можно изменить порт или использовать
--host 0.0.0.0для удалённого доступа.
Файл test_input.txt, где каждая строка — пара user_id item_id:
1 10
1 20
1 35
2 50
2 60
python -m scripts.inference📁 Конфигурация configs/infer.yaml указывает:
input_txt: data/test_input.txt
output_path: output/recommendations.json
save_dir: checkpoints/ml-1m_default📝 Пример результата (recommendations.json):
{
"1": [42, 13, 17, 55, 88],
"2": [61, 44, 32, 79, 11]
}python -m scripts.convert_and_export📦 Что создаётся:
checkpoints/ml-1m_default/sasrec.onnxcheckpoints/ml-1m_default/sasrec.trt(если доступен TensorRT)checkpoints/ml-1m_default/last.ckpt
Эти файлы можно использовать в продакшн-инференсе без PyTorch.
├── scripts/ # train.py, inference.py, convert_and_export.py
├── model/ # SASRec, Lightning-модуль, attention-блоки
├── utils/ # загрузка данных, валидация, экспорт
├── configs/ # Hydra-конфиги
├── data/ # данные под DVC
├── checkpoints/ # чекпойнты и onnx/tensorrt модели
├── output/ # json с результатами предсказания
├── mlruns/ # метрики обучения (MLflow)
├── pyproject.toml # зависимости (uv)
├── uv-pyproject.lock # фиксированные версии
python >= 3.13uvtorch >= 2.7pytorch-lightning >= 2.5hydra-core,omegaconfmlflowdvc,dvc-gdriveonnx,onnxruntimetensorrt(опционально)pycuda(опционально)
MIT License.