From 1937532d60b7f61f1e102ed535561ab252a02755 Mon Sep 17 00:00:00 2001
From: In-Ho Yi
Date: Fri, 18 Feb 2022 11:27:04 -0500
Subject: [PATCH] Patch clip model for ONNX compatibility

Changes to use INT32 for tokenization, since ONNX doesn't yet support
ArgMax(INT64)

Use explicit dimension for norm
---
 clip/clip.py  | 4 ++--
 clip/model.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/clip/clip.py b/clip/clip.py
index 2c911d060..6c8035e9a 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -192,7 +192,7 @@ def patch_float(module):
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.IntTensor:
     """
     Returns the tokenized representation of given input string(s)
 
@@ -217,7 +217,7 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
     sot_token = _tokenizer.encoder["<|startoftext|>"]
     eot_token = _tokenizer.encoder["<|endoftext|>"]
     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
-    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
 
     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
diff --git a/clip/model.py b/clip/model.py
index f7958f171..e743d2c78 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -356,8 +356,8 @@ def forward(self, image, text):
         text_features = self.encode_text(text)
 
         # normalized features
-        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+        image_features = image_features / image_features.norm(dim=1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=1, keepdim=True)
 
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
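
Not part of the patch itself: a minimal export sketch showing how the patched code might be exercised. It assumes the patched clip package is importable locally; the model name, output path, and opset version are illustrative choices, not anything this patch prescribes.

import torch
import clip

# A minimal sketch, assuming the patched clip package is installed locally.
model, _ = clip.load("ViT-B/32", device="cpu", jit=False)
model.eval()

# With this patch, tokenize() returns an int32 tensor, so the ArgMax that
# encode_text() applies to the token tensor (to find the EOT position)
# traces to ArgMax over INT32, which ONNX supports.
text = clip.tokenize(["a diagram", "a dog"])
image = torch.randn(2, 3, 224, 224)

# Hypothetical output path and assumed opset version; adjust as needed.
torch.onnx.export(
    model,
    (image, text),
    "clip.onnx",
    input_names=["image", "text"],
    output_names=["logits_per_image", "logits_per_text"],
    opset_version=14,
)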