Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml : add RPC backend #6829

Merged
merged 19 commits into the base branch on May 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add get_device_memory
  • Loading branch information
rgerganov committed May 14, 2024
commit 0b5e8a71839c77d74be0e7ba124dfaf1a782cace
20 changes: 15 additions & 5 deletions examples/rpc/rpc-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ static ggml_backend_t create_backend() {
if (!backend) {
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
}
#endif

#ifdef GGML_USE_METAL
#elif GGML_USE_METAL
fprintf(stderr, "%s: using Metal backend\n", __func__);
backend = ggml_backend_metal_init();
if (!backend) {
Expand All @@ -44,6 +42,16 @@ static ggml_backend_t create_backend() {
return backend;
}

// Report the free and total device memory (in bytes) of the backend that
// serves this RPC server. Only the CUDA backend can be queried for now;
// all other backends report a 1-byte placeholder.
static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
#ifdef GGML_USE_CUDA
    // query device 0 — the same device create_backend() initializes
    ggml_backend_cuda_get_device_memory(0, free_mem, total_mem);
#else
    // TODO: implement for other backends
    // non-zero placeholder so clients do not see a zero-capacity device
    const size_t placeholder = 1;
    *free_mem  = placeholder;
    *total_mem = placeholder;
#endif
}

static int create_server_socket(const char * host, int port) {
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
Expand Down Expand Up @@ -101,8 +109,10 @@ int main(int argc, char * argv[])
close(client_socket);
continue;
}
printf("Accepted client connection\n");
rpc_serve_client(backend, client_socket);
size_t free_mem, total_mem;
get_backend_memory(&free_mem, &total_mem);
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
rpc_serve_client(backend, client_socket, free_mem, total_mem);
printf("Client connection closed\n");
close(client_socket);
}
Expand Down
37 changes: 32 additions & 5 deletions ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,31 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}

// Ask the RPC server for its device memory over the given connection.
// input serialization format:  | 0 bytes |
// output serialization format: | free (8 bytes) | total (8 bytes) |
static void get_device_memory(const std::shared_ptr<sockfd> & sock, size_t * free, size_t * total) {
    std::vector<uint8_t> request;  // GET_DEVICE_MEMORY takes no payload
    std::vector<uint8_t> response;
    const bool ok = send_rpc_cmd(sock, GET_DEVICE_MEMORY, request, response);
    GGML_ASSERT(ok);
    GGML_ASSERT(response.size() == 2*sizeof(uint64_t));
    // decode the two little-endian-as-sent 64-bit counters via memcpy to
    // avoid unaligned access on the raw byte buffer
    uint64_t mem[2];
    memcpy(mem, response.data(), sizeof(mem));
    *free  = mem[0];
    *total = mem[1];
}

// Public API: connect to `endpoint` and report the remote device's free and
// total memory in bytes. On connection failure both outputs are set to 0.
// NOTE(review): the diff render left the old stub body ("*free = 1; ...")
// fused above the real implementation — removed here, the RPC query below is
// the intended code path.
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total) {
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    if (backend == nullptr) {
        // could not reach the server — report an empty device
        *free  = 0;
        *total = 0;
        return;
    }
    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
    get_device_memory(ctx->sock, free, total);
    // NOTE(review): `backend` is not released here — presumably
    // ggml_backend_rpc_init caches instances per endpoint; confirm, otherwise
    // this leaks a connection per call.
}

// RPC server-side implementation
Expand Down Expand Up @@ -759,7 +779,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
ggml_free(ctx);
}

void rpc_serve_client(ggml_backend_t backend, int sockfd) {
void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem) {
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
Expand Down Expand Up @@ -816,6 +836,13 @@ void rpc_serve_client(ggml_backend_t backend, int sockfd) {
rpc_graph_compute(backend, input, output);
break;
}
case GET_DEVICE_MEMORY: {
// output serialization format: | free (8 bytes) | total (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &free_mem, sizeof(free_mem));
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
break;
}
default: {
fprintf(stderr, "Unknown command: %d\n", cmd);
rgerganov marked this conversation as resolved.
Show resolved Hide resolved
break;
Expand Down
3 changes: 2 additions & 1 deletion ggml-rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ enum rpc_cmd {
GET_TENSOR,
COPY_TENSOR,
GRAPH_COMPUTE,
GET_DEVICE_MEMORY,
};

#define GGML_RPC_MAX_SERVERS 16
Expand All @@ -49,7 +50,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const

GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total);

GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd);
GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem);

#ifdef __cplusplus
}
Expand Down