Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml : add RPC backend #6829

Merged
merged 19 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address review comments
  • Loading branch information
rgerganov committed May 14, 2024
commit dfadd1a82c471e67d06cfa0c1ff3056151bad894
35 changes: 23 additions & 12 deletions ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
// RPC data structures

// Returns a pointer to the GUID that identifies the RPC backend.
// The GUID is a function-local static, so every call returns the same address.
static ggml_guid_t ggml_backend_rpc_guid() {
    // Exactly one definition of `guid` may exist in this scope. The
    // 00 11 22 ... ff placeholder is dropped in favour of the unique value.
    static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
    return &guid;
}

Expand All @@ -45,6 +45,7 @@ struct ggml_backend_rpc_context {

// Context object attached to each buffer created by the RPC buffer type
// (it is passed as the context argument of ggml_backend_buffer_init).
// NOTE: member order matters - the struct is brace-initialized positionally
// where the buffer is allocated, so new members must be mirrored there.
struct ggml_backend_rpc_buffer_context {
// socket descriptor, copied from the buffer type context at allocation time
int sockfd;
// base addresses already obtained per buffer; consulted first by
// ggml_backend_rpc_buffer_get_base so later calls return the stored value
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
// value serialized as the 8-byte `remote_ptr` input field when the buffer base is requested
uint64_t remote_ptr;
// buffer name; set to "RPC" when the buffer is allocated
std::string name;
};
Expand All @@ -62,17 +63,20 @@ static int socket_connect(const char * host, int port) {
int flag = 1;
int ret = setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
if (ret < 0) {
close(sock);
return -1;
}
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
struct hostent * server = gethostbyname(host);
if (server == NULL) {
fprintf(stderr, "Cannot resolve host '%s'\n", host);
close(sock);
return -1;
}
bcopy((char *)server->h_addr, (char *)&addr.sin_addr.s_addr, server->h_length);
if (connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
close(sock);
return -1;
}
return sock;
Expand Down Expand Up @@ -152,11 +156,10 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
}

GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
static std::unordered_map<ggml_backend_buffer_t, void *> cache;
if (cache.find(buffer) != cache.end()) {
return cache[buffer];
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
return ctx->base_cache[buffer];
}
// input serialization format: | remote_ptr (8 bytes) |
std::vector<uint8_t> input(sizeof(uint64_t), 0);
uint64_t remote_ptr = ctx->remote_ptr;
Expand All @@ -169,7 +172,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
uint64_t base_ptr;
memcpy(&base_ptr, output.data(), sizeof(base_ptr));
void * base = reinterpret_cast<void *>(base_ptr);
cache[buffer] = base;
ctx->base_cache[buffer] = base;
return base;
}

Expand Down Expand Up @@ -331,7 +334,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer

ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, remote_ptr, "RPC"},
new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, {}, remote_ptr, "RPC"},
remote_size);

return buffer;
Expand All @@ -343,6 +346,12 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_
return 128;
}

// Reports the largest buffer size this buffer type allows.
// The `buft` argument is ignored and the result is always SIZE_MAX.
GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
    // TODO: this is hardcoded for now but it should come from the remote backend
    const size_t max_size = SIZE_MAX;
    UNUSED(buft);
    return max_size;
}

GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
UNUSED(buft);
return ggml_nbytes(tensor);
Expand All @@ -361,7 +370,7 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
/* .get_name = */ ggml_backend_rpc_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_max_size = */ ggml_backend_rpc_get_max_size,
/* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
/* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
/* .is_host = */ NULL,
Expand Down Expand Up @@ -475,7 +484,7 @@ static std::unordered_map<std::string, ggml_backend_t> instances;

// Returns the default buffer type of the RPC backend for `endpoint`,
// or nullptr when no backend could be obtained for that endpoint.
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const std::string & endpoint) {
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    // ggml_backend_rpc_init can yield nullptr (callers treat that as a failed
    // connection), so check before asking the backend for its buffer type.
    if (backend == nullptr) {
        return nullptr;
    }
    return ggml_backend_rpc_get_default_buffer_type(backend);
}

GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
Expand All @@ -488,7 +497,9 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
std::string host = endpoint.substr(0, pos);
int port = std::stoi(endpoint.substr(pos + 1));
int sockfd = socket_connect(host.c_str(), port);
GGML_ASSERT(sockfd >= 0 && "failed to connect to the server");
if (sockfd < 0) {
return nullptr;
}

ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .sockfd = */ sockfd,
Expand All @@ -502,7 +513,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {

ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .name = */ "RPC",
/* .name = */ "RPC" + std::to_string(sockfd),
/* .sockfd = */ sockfd,
/* .buft = */ buft
};
Expand All @@ -522,9 +533,9 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {

// Reports free and total device memory for an RPC endpoint.
// Not implemented yet: `endpoint` is ignored and both outputs receive the
// placeholder value 1.
//   endpoint - server address (currently unused)
//   free     - out: set to 1
//   total    - out: set to 1
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total) {
    UNUSED(endpoint);
    // TODO: implement
    *free = 1;
    *total = 1;
}

// RPC server-side implementation
Expand Down
2 changes: 1 addition & 1 deletion llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15729,7 +15729,7 @@ struct llama_context * llama_new_context_with_model(
for (auto & server : model->rpc_servers) {
ggml_backend_t backend = ggml_backend_rpc_init(server);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize RPC backend, endpoint: %s\n", __func__, server.c_str());
LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str());
llama_free(ctx);
return nullptr;
}
Expand Down