综合介绍
Gaze-LLE是一款基于大规模学习编码器的注视目标预测工具。该项目由Fiona Ryan、Ajay Bati、Sangmin Lee、Daniel Bolya、Judy Hoffman和James M. Rehg开发,旨在通过预训练的视觉基础模型(如DINOv2)实现高效的注视目标预测。Gaze-LLE的架构简洁,仅在冻结的预训练视觉编码器上学习轻量级的注视解码器,相比之前的工作,参数量减少了1-2个数量级,并且不需要额外的输入模态如深度和姿态信息。
功能列表
- 注视目标预测:基于预训练的视觉编码器进行注视目标的高效预测。
- 多人注视预测:支持对单张图像中的多个人进行注视预测。
- 预训练模型:提供多种预训练模型,支持不同的骨干网络和训练数据。
- 轻量级架构:仅在冻结的预训练视觉编码器上学习轻量级的注视解码器。
- 无额外输入模态:不需要额外的深度和姿态信息输入。
使用帮助
安装流程
- 克隆仓库:
git clone https://github.com/fkryan/gazelle.git
cd gazelle
- 创建虚拟环境并安装依赖:
conda env create -f environment.yml
conda activate gazelle
pip install -e .
- 可选:安装xformers以加速注意力计算(如果系统支持):
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
使用预训练模型
Gaze-LLE提供了多种预训练模型,用户可以根据需求下载并使用:
- gazelledinov2vitb14:基于DINOv2 ViT-B的模型,训练数据为GazeFollow。
- gazelledinov2vitl14:基于DINOv2 ViT-L的模型,训练数据为GazeFollow。
- gazelledinov2vitb14_inout:基于DINOv2 ViT-B的模型,训练数据为GazeFollow和VideoAttentionTarget。
- gazellelargevitl14_inout:基于DINOv2 ViT-L的模型,训练数据为GazeFollow和VideoAttentionTarget。
使用示例
- 在PyTorch Hub中加载模型:
import torch
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitb14')
- 在Google Colab中查看演示笔记本,了解如何检测图像中所有人的注视目标。
注视预测
Gaze-LLE支持多人的注视预测,即对单张图像进行一次编码,然后使用特征预测图像中多个人的注视目标。模型输出一个空间热图,表示场景中注视目标的位置概率,值范围为[0,1],其中1表示注视目标位置的最高概率。