examples : add sample SAM inference (ggerganov#74)
* sam : image + prompt encoder, store embeddings

* sam : add the dense img pe in SAM state (ggerganov#401)

* Add SAM decoder & output masks as png (ggerganov#418)

* Add loading of decoder layers in Model

* Multiply by hypernet_layer_cnt for ctx_size on model load

* Add decoder layers to py conversion script

* Fix wrong and reversed tensor sizes for decoder

* Add decoder transformer implementation

* Add decoder hypernet and iou prediction mlps

* Add transpose convolution operation and unit test

* Finish mask decoder and write the decoder output in the model state

* Output masks to png after removing padding and upsampling to original size

- Also filter based on the iou threshold
- Filtering based on the stability score and crop boxes still needs to be done

* Add stb image write in order to output masks from SAM

* Add transpose convolution 2d name and symbol to ggml ops static arrays

* Comment out debug print in transpose convolution test to fix compilation

ggml-ci

* Multithread GGML_OP_ADD_REL_POS operation

* ggml : fix GGML_OP_NAME array

* Disable and comment out debug prints in SAM example

* Add README for the SAM example

* Calculate & filter based on stability score and calculate bounding box

ggml-ci

---------

Co-authored-by: Yavor Ivanov <yivanov@viewray.com>
ggerganov and Yavor Ivanov authored Aug 18, 2023
1 parent 7cf109e commit 8da5be2
Showing 13 changed files with 12,891 additions and 15 deletions.
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -27,3 +27,4 @@ add_subdirectory(dolly-v2)
add_subdirectory(replit)
add_subdirectory(mpt)
add_subdirectory(starcoder)
add_subdirectory(sam)
43 changes: 43 additions & 0 deletions examples/common.cpp
@@ -766,3 +766,46 @@ float similarity(const std::string & s0, const std::string & s1) {

return 1.0f - (dist / std::max(s0.size(), s1.size()));
}

bool sam_params_parse(int argc, char ** argv, sam_params & params) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];

        if (arg == "-s" || arg == "--seed") {
            params.seed = std::stoi(argv[++i]);
        } else if (arg == "-t" || arg == "--threads") {
            params.n_threads = std::stoi(argv[++i]);
        } else if (arg == "-m" || arg == "--model") {
            params.model = argv[++i];
        } else if (arg == "-i" || arg == "--inp") {
            params.fname_inp = argv[++i];
        } else if (arg == "-o" || arg == "--out") {
            params.fname_out = argv[++i];
        } else if (arg == "-h" || arg == "--help") {
            sam_print_usage(argc, argv, params);
            exit(0);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            sam_print_usage(argc, argv, params);
            exit(0);
        }
    }

    return true;
}

void sam_print_usage(int argc, char ** argv, const sam_params & params) {
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h, --help            show this help message and exit\n");
    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
    fprintf(stderr, "  -m FNAME, --model FNAME\n");
    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
    fprintf(stderr, "  -i FNAME, --inp FNAME\n");
    fprintf(stderr, "                        input file (default: %s)\n", params.fname_inp.c_str());
    fprintf(stderr, "  -o FNAME, --out FNAME\n");
    fprintf(stderr, "                        output file (default: %s)\n", params.fname_out.c_str());
    fprintf(stderr, "\n");
}
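
For reference, the options above combine into an invocation along these lines (a sketch only; the binary location and file paths are placeholders based on the defaults in `sam_params`):

```bash
# hypothetical example run; adjust the paths to your build and model locations
./bin/sam -t 8 -s 42 -m models/sam-vit-b/ggml-model-f16.bin -i img.jpg -o img.out
```
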
19 changes: 18 additions & 1 deletion examples/common.h
@@ -11,7 +11,7 @@
#define COMMON_SAMPLE_RATE 16000

//
// CLI argument parsing
// GPT CLI argument parsing
//

struct gpt_params {
@@ -157,3 +157,20 @@ bool vad_simple(

// compute similarity between two strings using Levenshtein distance
float similarity(const std::string & s0, const std::string & s1);

//
// SAM argument parsing
//

struct sam_params {
    int32_t seed      = -1; // RNG seed
    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());

    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
    std::string fname_inp = "img.jpg";
    std::string fname_out = "img.out";
};

bool sam_params_parse(int argc, char ** argv, sam_params & params);

void sam_print_usage(int argc, char ** argv, const sam_params & params);
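
To show how these declarations fit together, here is a minimal driver sketch; it assumes only what is declared above and is not the example's actual `examples/sam/main.cpp`:

```cpp
// minimal sketch, assuming only common.h as shown above (not the real examples/sam/main.cpp)
#include <cstdio>

#include "common.h"

int main(int argc, char ** argv) {
    sam_params params;
    if (!sam_params_parse(argc, argv, params)) {
        return 1;
    }

    // the real example loads params.model, reads params.fname_inp,
    // runs inference and writes the predicted mask(s) based on params.fname_out
    printf("model = %s, input = %s, threads = %d\n",
            params.model.c_str(), params.fname_inp.c_str(), params.n_threads);

    return 0;
}
```
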
13 changes: 13 additions & 0 deletions examples/sam/CMakeLists.txt
@@ -0,0 +1,13 @@
#
# sam

set(TEST_TARGET sam)
add_executable(${TEST_TARGET} main.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common)

#
# sam-quantize

#set(TEST_TARGET sam-quantize)
#add_executable(${TEST_TARGET} quantize.cpp)
#target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
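
With the target above in place, the example can be rebuilt on its own from an existing build directory using standard CMake (not specific to this commit):

```bash
# from the build directory created in the README's quick start
cmake --build . --target sam
```
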
89 changes: 89 additions & 0 deletions examples/sam/README.md
@@ -0,0 +1,89 @@
# SAM.cpp

Inference of Meta's [Segment Anything Model](https://github.com/facebookresearch/segment-anything/) in pure C/C++

## Description

The example currently supports only the [ViT-B SAM model checkpoint](https://huggingface.co/facebook/sam-vit-base).

## Next steps

- [ ] Reduce memory usage by utilizing the new ggml-alloc
- [ ] Remove redundant graph nodes
- [ ] Make inference faster
- [ ] Fix the difference in output masks compared to the PyTorch implementation
- [ ] Filter masks based on the stability score and on boxes that touch crop boundaries
- [ ] Add support for user input
- [ ] Support F16 for heavy F32 ops
- [ ] Test quantization
- [ ] Support bigger model checkpoints
- [ ] GPU support

## Quick start
```bash
git clone https://github.com/ggerganov/ggml
cd ggml

# Install Python dependencies
python3 -m pip install -r requirements.txt

# Convert PTH model to ggml
python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 1

# Build ggml + examples
mkdir build && cd build
cmake .. && make -j4

# run inference
./bin/sam -t 16 -i ../img.jpg -m ../examples/sam/ggml-model-f16.bin
```

## Downloading and converting the model checkpoints

You can download a [model checkpoint](https://github.com/facebookresearch/segment-anything/tree/main#model-checkpoints) and convert it to `ggml` format using the script `convert-pth-to-ggml.py`:

```
# Convert PTH model to ggml
python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 1
```
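
The trailing argument selects the output precision (`0` -> f32, `1` -> f16, per the conversion script), so an f32 conversion would look like:

```
# Convert PTH model to ggml (f32)
python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 0
```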

## Example output
```
$ ./bin/sam -t 16 -i ../img.jpg -m ../examples/sam/ggml-model-f16.bin
main: seed = 1692347524
main: loaded image '../img.jpg' (680 x 453)
sam_image_preprocess: scale = 0.664062
main: preprocessed image (1024 x 1024)
sam_model_load: loading model from '../examples/sam/ggml-model-f16.bin' - please wait ...
sam_model_load: n_enc_state = 768
sam_model_load: n_enc_layer = 12
sam_model_load: n_enc_head = 12
sam_model_load: n_enc_out_chans = 256
sam_model_load: n_pt_embd = 4
sam_model_load: ftype = 1
sam_model_load: qntvr = 0
operator(): ggml ctx size = 202.32 MB
sam_model_load: ...................................... done
sam_model_load: model size = 185.05 MB / num tensors = 304
point: 624.500000 245.593750
main: load time = 88.36 ms
main: total time = 5697.57 ms
```

Input point is (414.375, 162.796875) (currently hardcoded)

Input image:

![llamas](https://user-images.githubusercontent.com/8558655/261301565-37b7bf4b-bf91-40cf-8ec1-1532316e1612.jpg)

Output mask:

![mask_glasses](https://user-images.githubusercontent.com/8558655/261301844-9fc2dbbc-5fd6-42ce-af69-643df9e6fad1.png)

## References

- [ggml](https://github.com/ggerganov/ggml)
- [SAM](https://segment-anything.com/)
- [SAM demo](https://segment-anything.com/demo)
134 changes: 134 additions & 0 deletions examples/sam/convert-pth-to-ggml.py
@@ -0,0 +1,134 @@
# Convert a SAM model checkpoint to a ggml compatible file
#

import os
import sys
import code
import json
import torch
import struct
import numpy as np

if len(sys.argv) < 3:
    print("Usage: convert-pth-to-ggml.py file-model ftype\n")
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)

# output in the same directory as the model
fname_model = sys.argv[1]
fname_out = os.path.dirname(fname_model) + "/ggml-model.bin"

# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])

if ftype < 0 or ftype > 1:
    print("Invalid ftype: " + str(ftype))
    sys.exit(1)

fname_out = fname_out.replace(".bin", "-" + ftype_str[ftype] + ".bin")

model = torch.load(fname_model, map_location="cpu")

# TODO: determine based on model data
# TODO: add decoder / prompt encoder if needed
hparams = {
    "n_enc_state":     768,
    "n_enc_layers":     12,
    "n_enc_heads":      12,
    "n_enc_out_chans": 256,

    "n_pt_embd": 4,
}

print(hparams)

for k, v in model.items():
    print(k, v.shape)

#exit()
#code.interact(local=locals())

fout = open(fname_out, "wb")

fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["n_enc_state"]))
fout.write(struct.pack("i", hparams["n_enc_layers"]))
fout.write(struct.pack("i", hparams["n_enc_heads"]))
fout.write(struct.pack("i", hparams["n_enc_out_chans"]))
fout.write(struct.pack("i", hparams["n_pt_embd"]))
fout.write(struct.pack("i", ftype))

for k, v in model.items():
    name = k
    shape = v.shape

    if name[:19] == "prompt_encoder.mask":
        continue

    print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)

    #data = tf.train.load_variable(dir_model, name).squeeze()
    #data = v.numpy().squeeze()
    data = v.numpy()
    n_dims = len(data.shape);

    # for efficiency - transpose some matrices
    # "model/h.*/attn/c_attn/w"
    # "model/h.*/attn/c_proj/w"
    # "model/h.*/mlp/c_fc/w"
    # "model/h.*/mlp/c_proj/w"
    #if name[-14:] == "/attn/c_attn/w" or \
    #   name[-14:] == "/attn/c_proj/w" or \
    #   name[-11:] == "/mlp/c_fc/w" or \
    #   name[-13:] == "/mlp/c_proj/w":
    #    print("  Transposing")
    #    data = data.transpose()

    dshape = data.shape

    # default type is fp16
    ftype_cur = 1
    if ftype == 0 or n_dims == 1 or \
            name == "image_encoder.pos_embed" or \
            name.startswith("prompt_encoder") or \
            name.startswith("mask_decoder.iou_token") or \
            name.startswith("mask_decoder.mask_tokens"):
        print("  Converting to float32")
        data = data.astype(np.float32)
        ftype_cur = 0
    else:
        print("  Converting to float16")
        data = data.astype(np.float16)

    # reshape the 1D bias into a 4D tensor so we can use ggml_repeat
    # keep it in F32 since the data is small
    if name == "image_encoder.patch_embed.proj.bias":
        data = data.reshape(1, data.shape[0], 1, 1)
        n_dims = len(data.shape);
        dshape = data.shape

    print("  New shape: ", dshape)

    # header
    str = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
    fout.write(str);

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")