Skip to content

Commit

Permalink
sam from bounding box
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Cohen committed Jan 18, 2024
1 parent fd23b16 commit 01e9a96
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 1 deletion.
3 changes: 3 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
'TrimapToMask (segment anything)': TrimapToMask,
'InvertMask (segment anything)': InvertMask,
"IsMaskEmpty": IsMaskEmptyNode,
"BoundingBox (segment anything)":BoundingBox,
"MaskToBoundingBox (segment anything)":MaskToBoundingBox,
"BoundingBoxSAMSegment (segment anything)":BoundingBoxSAMSegment
}

__all__ = ['NODE_CLASS_MAPPINGS']
Expand Down
80 changes: 79 additions & 1 deletion node.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,5 +497,83 @@ def INPUT_TYPES(cls):

def to_mask(self, trimap: torch.Tensor):
return (trimap,)

class MaskToBoundingBox:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK", {}),
}
}
CATEGORY = "mask"
FUNCTION = "main"
RETURN_TYPES = ("BOUNDING_BOX")


def main(self, mask):
mask_np = np.clip(255. * mask.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
axis = np.where(mask_np != 0)
rmin = np.min(axis[0])
rmax = np.max(axis[0])
cmin = np.min(axis[1])
cmax = np.max(axis[1])
return (torch.FloatTensor([cmin,rmin,cmax,rmax]),)

class BoundingBox:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"left": ("INT", {"default":0,"step":1}),
"top": ("INT", {"default":0,"step":1}),
"right": ("INT", {"default":0,"step":1}),
"bottom": ("INT", {"default":0,"step":1}),

}
}
CATEGORY = "mask"
FUNCTION = "util"
RETURN_TYPES = ("BOUNDING_BOX")



def main(self, left,top,right,bottom,):
return (torch.FloatTensor([left,top,right,bottom]),)


class BoundingBoxSAMSegment:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"sam_model": ('SAM_MODEL', {}),
"image": ('IMAGE', {}),
"bounding_box": ("BOUNDING_BOX", {}),

}
}
CATEGORY = "segment_anything"
FUNCTION = "main"
RETURN_TYPES = ("IMAGE", "MASK")

def main(self, sam_model, image, bounding_box):
res_images = []
res_masks = []

for item in image:
item = Image.fromarray(
np.clip(255. * item.cpu().numpy(), 0, 255).astype(np.uint8)).convert('RGBA')

(images, masks) = sam_segment(
sam_model,
item,
bounding_box
)
res_images.extend(images)
res_masks.extend(masks)

if len(res_images) == 0:
_, height, width, _ = image.size()
empty_mask = torch.zeros((1, height, width), dtype=torch.uint8, device="cpu")
return (empty_mask, empty_mask)
return (torch.cat(res_images, dim=0), torch.cat(res_masks, dim=0))

0 comments on commit 01e9a96

Please sign in to comment.