Skip to content

Commit

Permalink
feat: 适配transformers==4.10.2及网络下载
Browse files Browse the repository at this point in the history
  • Loading branch information
TingsongYu committed Apr 23, 2024
1 parent b1fe6c7 commit f2001f8
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 10 deletions.
3 changes: 3 additions & 0 deletions code/chapter-8/07_image_captioning/clip_cap_base/01_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tqdm
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
# from transformers import AdamW, WarmupLinearSchedule
from my_models.models import *
from my_datasets.cocodataset import *

Expand All @@ -36,6 +37,8 @@ def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
)
# scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=epochs * len(train_dataloader))

# save_config(args)
for epoch in range(epochs):
print(f">>> Training epoch {epoch}")
Expand Down
23 changes: 13 additions & 10 deletions code/chapter-8/07_image_captioning/clip_cap_base/02_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@
@brief : 模型推理
"""

import os
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# os.environ['HF_ENDPOINT'] = "https://ai.gitee.com/huggingface"
import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
import sys
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # project root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH

# debug: windows下会报错:OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
import platform

if platform.system() == 'Windows':
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import clip
Expand Down Expand Up @@ -55,7 +63,7 @@ def __init__(self, path_ckpt):
model = ClipCaptionPrefix(args.prefix_length, clip_length=args.prefix_length, prefix_size=512,
num_layers=args.num_layers, mapping_type=args.mapping_type)

model.load_state_dict(torch.load(path_ckpt, map_location=torch.device("cpu")))
model.load_state_dict(torch.load(path_ckpt, map_location=torch.device("cpu")), strict=False)
model = model.eval()
model = model.to(self.device)
self.model = model
Expand All @@ -78,7 +86,8 @@ def main():
# download from: [提取码:mqri](https://pan.baidu.com/s/1CuTDtCeT2-nIvRG7N4iKtw)
ckpt_path = r'coco_prefix-009-2023-0411.pt'
path_img = r'G:\deep_learning_data\coco_2014\images\val2014'
out_dir = './inference_output'
# path_img = r'G:\deep_learning_data\coco_2017\images\train2017\train2017'
out_dir = './inference_output2'

# 获取路径
img_paths = []
Expand All @@ -94,7 +103,7 @@ def main():
for idx, path_img in tqdm(enumerate(img_paths)):
caps, pil_image = predictor.predict(path_img, False)
img_bgr = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
cv2.putText(img_bgr, caps, (0, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 1)
cv2.putText(img_bgr, caps, (0, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)

# 保存
path_out = os.path.join(out_dir, os.path.basename(path_img))
Expand All @@ -105,9 +114,3 @@ def main():

if __name__ == '__main__':
main()






Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix
super(ClipCaptionModel, self).__init__()
self.prefix_length = prefix_length
self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
# self.gpt = GPT2LMHeadModel.from_pretrained('gpt2', force_download=True)
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
if mapping_type == MappingType.MLP:
self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def generate2(

outputs = model.gpt(inputs_embeds=generated)
logits = outputs.logits
# logits = outputs[0]
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
Expand Down

0 comments on commit f2001f8

Please sign in to comment.