Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml : add RPC backend #6829

Merged
merged 19 commits into from
May 14, 2024
Merged
Prev Previous commit
fix compile warnings on macos
  • Loading branch information
rgerganov committed May 14, 2024
commit 1519cb4582db5966656b889dda419baead501c31
13 changes: 7 additions & 6 deletions ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "ggml.h"
#include "ggml-backend-impl.h"

#include <cinttypes>
#include <string>
#include <vector>
#include <memory>
Expand Down Expand Up @@ -729,7 +730,7 @@ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t>
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
uint64_t remote_ptr = reinterpret_cast<uint64_t>(buffer);
uint64_t remote_size = buffer->size;
GGML_PRINT_DEBUG("[%s] size: %lu -> remote_ptr: %lx, remote_size: %lu\n", __func__, size, remote_ptr, remote_size);
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
Expand Down Expand Up @@ -758,7 +759,7 @@ static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<
// input serialization format: | remote_ptr (8 bytes) |
uint64_t remote_ptr;
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
GGML_PRINT_DEBUG("[%s] remote_ptr: %lx\n", __func__, remote_ptr);
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
void * base = ggml_backend_buffer_get_base(buffer);
// output serialization format: | base_ptr (8 bytes) |
Expand All @@ -771,7 +772,7 @@ static void rpc_free_buffer(const std::vector<uint8_t> & input) {
// input serialization format: | remote_ptr (8 bytes) |
uint64_t remote_ptr;
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
GGML_PRINT_DEBUG("[%s] remote_ptr: %lx\n", __func__, remote_ptr);
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
ggml_backend_buffer_free(buffer);
}
Expand All @@ -782,7 +783,7 @@ static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
uint8_t value;
memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
GGML_PRINT_DEBUG("[%s] remote_ptr: %lx, value: %u\n", __func__, remote_ptr, value);
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
ggml_backend_buffer_clear(buffer, value);
}
Expand All @@ -801,7 +802,7 @@ static void rpc_set_tensor(const std::vector<uint8_t> & input) {
};
struct ggml_context * ctx = ggml_init(params);
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %lu, size: %lu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
ggml_backend_tensor_set(tensor, data, offset, size);
ggml_free(ctx);
Expand All @@ -822,7 +823,7 @@ static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8
};
struct ggml_context * ctx = ggml_init(params);
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %lu, size: %lu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
// output serialization format: | data (size bytes) |
output.resize(size, 0);
ggml_backend_tensor_get(tensor, output.data(), offset, size);
Expand Down
Loading