"""
OpenAI CLIP wrapper for the FiftyOne Model Zoo.

| Copyright 2017-2024, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import os
import warnings

from packaging.version import Version
import torch

import fiftyone.core.models as fom
import fiftyone.utils.torch as fout

from .model import build_model
from .tokenizer import SimpleTokenizer


class TorchCLIPModelConfig(fout.TorchImageModelConfig):
    """Configuration for running a :class:`TorchCLIPModel`.

    See :class:`fiftyone.utils.torch.TorchImageModelConfig` for additional
    arguments.

    Args:
        model_path: the path to the model's weights on disk
        tokenizer_path: the path to the model's tokenizer on disk
        context_length: the model's context length
        text_prompt: the text prompt to use, e.g., ``"A photo of"``
        classes: the list of classes to use for zero-shot prediction
    """

    def __init__(self, d):
        super().__init__(d)

        self.model_path = self.parse_string(d, "model_path")
        self.tokenizer_path = self.parse_string(d, "tokenizer_path")
        self.context_length = self.parse_int(d, "context_length")
        self.text_prompt = self.parse_string(d, "text_prompt")
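

# Illustrative sketch of the config dict this wrapper expects. The paths and
# class list below are placeholders, not values shipped with the zoo model,
# and any additional fields are governed by ``TorchImageModelConfig``:
#
#     config = TorchCLIPModelConfig(
#         {
#             "model_path": "/path/to/clip/model.pt",
#             "tokenizer_path": "/path/to/clip/bpe_vocab.txt.gz",
#             "context_length": 77,
#             "text_prompt": "A photo of",
#             "classes": ["cat", "dog", "bird"],
#         }
#     )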


class TorchCLIPModel(fout.TorchImageModel, fom.PromptMixin):
    """Wrapper for CLIP from https://github.com/openai/CLIP.

    Args:
        config: a :class:`TorchCLIPModelConfig`
    """

    def __init__(self, config):
        super().__init__(config)

        self._tokenizer = SimpleTokenizer(config.tokenizer_path)
        self._text_features = None

    @property
    def can_embed_prompts(self):
        return True

    def embed_prompt(self, prompt):
        return self.embed_prompts([prompt])[0]

    def embed_prompts(self, prompts):
        return self._embed_prompts(prompts).detach().cpu().numpy()
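
    # Minimal sketch of the prompt-embedding API above, assuming ``model`` is
    # an already-constructed `TorchCLIPModel` (embedding sizes depend on the
    # CLIP variant, e.g. 512 for ViT-B/32):
    #
    #     embedding = model.embed_prompt("a photo of a dog")       # (512,)
    #     embeddings = model.embed_prompts(["a cat", "a dog"])     # (2, 512)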

    def _load_model(self, config):
        # The weights on disk are a TorchScript export of the CLIP model; it
        # is loaded only to recover its state dict, from which an eager model
        # is rebuilt via `build_model()`
        with open(config.model_path, "rb") as f:
            model = torch.jit.load(f, map_location=self.device).eval()

        return build_model(model.state_dict()).to(self.device).float()

    def _embed_prompts(self, prompts):
        # source: https://github.com/openai/CLIP/blob/main/clip/clip.py
        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [
            [sot_token] + self._tokenizer.encode(p) + [eot_token]
            for p in prompts
        ]

        # Match the reference CLIP implementation, which uses int64 token IDs
        # for torch < 1.8.0 and int32 otherwise
        if Version(torch.__version__) < Version("1.8.0"):
            dtype = torch.long
        else:
            dtype = torch.int

        text_features = torch.zeros(
            len(all_tokens),
            self.config.context_length,
            dtype=dtype,
            device=self.device,
        )

        # Prompts longer than the context length are truncated, keeping the
        # end-of-text token as the final entry
        for i, (prompt, tokens) in enumerate(zip(prompts, all_tokens)):
            if len(tokens) > self.config.context_length:
                tokens = tokens[: self.config.context_length]
                tokens[-1] = eot_token
                msg = (
                    "Truncating prompt '%s'; too long for context length '%d'"
                    % (prompt, self.config.context_length)
                )
                warnings.warn(msg)

            text_features[i, : len(tokens)] = torch.tensor(tokens)

        with torch.no_grad():
            return self._model.encode_text(text_features)

    def _get_text_features(self):
        if self._text_features is None:
            prompts = [
                "%s %s" % (self.config.text_prompt, c) for c in self.classes
            ]
            self._text_features = self._embed_prompts(prompts)

        return self._text_features

    def _get_class_logits(self, text_features, image_features):
        # source: https://github.com/openai/CLIP/blob/main/README.md
        image_features = image_features / image_features.norm(
            dim=1, keepdim=True
        )
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        logit_scale = self._model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text
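
    # As in the CLIP README, the logits above are temperature-scaled cosine
    # similarities; per-image class probabilities can be recovered with a
    # softmax, e.g. `probs = logits_per_image.softmax(dim=-1)`. In this
    # wrapper, the configured output processor converts the logits returned
    # by `_predict_all()` below into FiftyOne labels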

    def _predict_all(self, imgs):
        if self._preprocess:
            imgs = [self._transforms(img) for img in imgs]

        if isinstance(imgs, (list, tuple)):
            imgs = torch.stack(imgs)

        height, width = imgs.size()[-2:]
        frame_size = (width, height)

        if self._using_gpu:
            imgs = imgs.cuda()

        text_features = self._get_text_features()
        image_features = self._model.encode_image(imgs)
        output, _ = self._get_class_logits(text_features, image_features)

        if self.has_logits:
            self._output_processor.store_logits = self.store_logits

        return self._output_processor(
            output, frame_size, confidence_thresh=self.config.confidence_thresh
        )
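

# Example usage (a sketch, not part of this module). In practice this wrapper
# is constructed for you when a CLIP model is loaded from the FiftyOne Model
# Zoo; the zoo model name, prompt, classes, and label field below are
# illustrative assumptions:
#
#     import fiftyone.zoo as foz
#
#     dataset = foz.load_zoo_dataset("quickstart")
#     model = foz.load_zoo_model(
#         "clip-vit-base32-torch",
#         text_prompt="A photo of a",
#         classes=["cat", "dog", "bird", "car"],
#     )
#
#     # Zero-shot classification
#     dataset.apply_model(model, label_field="clip_predictions")
#
#     # Prompt embeddings for text-image similarity workflows
#     embedding = model.embed_prompt("a photo of a dog")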