Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NMS to CoreML model output, works with Vision #7263

Closed
wants to merge 4 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 3, 2022
commit 749d8eb7689eeaf591cf88792359e7bdd6025261
14 changes: 7 additions & 7 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):

class CoreMLExportModel(torch.nn.Module):
def __init__(self, base_model, img_size):
super(CoreMLExportModel, self).__init__()
super().__init__()
self.base_model = base_model
self.img_size = img_size

Expand All @@ -194,9 +194,10 @@ def forward(self, x):
h = self.img_size[1]
objectness = x[:, 4:5]
class_probs = x[:, 5:] * objectness
boxes = x[:, :4] * torch.tensor([1./w, 1./h, 1./w, 1./h])
boxes = x[:, :4] * torch.tensor([1. / w, 1. / h, 1. / w, 1. / h])
return class_probs, boxes


def export_coreml(model, im, file, num_boxes, num_classes, labels, conf_thres, iou_thres, prefix=colorstr('CoreML:')):
# YOLOv5 CoreML export
try:
Expand Down Expand Up @@ -288,8 +289,8 @@ def export_coreml(model, im, file, num_boxes, num_classes, labels, conf_thres, i
# Add descriptions to the inputs and outputs
pipeline.spec.description.input[1].shortDescription = "(optional) IOU Threshold override"
pipeline.spec.description.input[2].shortDescription = "(optional) Confidence Threshold override"
pipeline.spec.description.output[0].shortDescription = u"Boxes Class confidence"
pipeline.spec.description.output[1].shortDescription = u"Boxes [x, y, width, height] (normalized to [0...1])"
pipeline.spec.description.output[0].shortDescription = "Boxes Class confidence"
pipeline.spec.description.output[1].shortDescription = "Boxes [x, y, width, height] (normalized to [0...1])"

# Add metadata to the model
pipeline.spec.description.metadata.shortDescription = "YOLOv5 object detector"
Expand All @@ -299,15 +300,14 @@ def export_coreml(model, im, file, num_boxes, num_classes, labels, conf_thres, i
user_defined_metadata = {
"iou_threshold": str(iou_thres),
"confidence_threshold": str(conf_thres),
"classes": ", ".join(labels)
}
"classes": ", ".join(labels)}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hietalajulius @mshamash Also relevant here.

pipeline.spec.description.metadata.userDefined.update(user_defined_metadata)

# Don't forget this or Core ML might attempt to run the model on an unsupported operating system version!
pipeline.spec.specificationVersion = 3

ct_model = ct.models.MLModel(pipeline.spec)

f = str(file).replace('.pt', '.mlmodel')
ct_model.save(f)

Expand Down