feat: add apply_lora_from_file method to Context
yoshoku committed Apr 20, 2023
1 parent 4ad38f7 commit 3bbae28
Showing 3 changed files with 46 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ext/llama_cpp/dummy.rb
@@ -118,6 +118,13 @@ def free; end
# @param model_path [String] The path to the model file.
# @param params [ContextParams] The parameters for context.
def load(model_path:, params:); end

# Applies LoRA from file.
#
# @param lora_path [String] The path to the LoRA file.
# @param base_model_path [String] The path to the base model file.
# @param n_threads [Integer] The number of threads.
def apply_lora_from_file(lora_path:, base_model_path: nil, n_threads: 1); end
end

# Class for parameters
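Not part of the diff: a minimal usage sketch of the new method, assuming LLaMACpp::Context can be constructed without arguments and populated via #load as documented above; the model and adapter paths are placeholders.

require 'llama_cpp'

params = LLaMACpp::ContextParams.new
context = LLaMACpp::Context.new
context.load(model_path: 'path/to/ggml-model-q4_0.bin', params: params)

# Merge a LoRA adapter into the loaded model; base_model_path is optional
# and n_threads defaults to 1.
context.apply_lora_from_file(lora_path: 'path/to/ggml-lora-adapter.bin', n_threads: 4)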
38 changes: 38 additions & 0 deletions ext/llama_cpp/llama_cpp.cpp
@@ -228,6 +228,7 @@ class RbLLaMAContext {
rb_define_method(rb_cLLaMAContext, "reset_timings", RUBY_METHOD_FUNC(_llama_context_reset_timings), 0);
rb_define_method(rb_cLLaMAContext, "free", RUBY_METHOD_FUNC(_llama_context_free), 0);
rb_define_method(rb_cLLaMAContext, "load", RUBY_METHOD_FUNC(_llama_context_load), -1);
rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
};

private:
@@ -560,6 +561,43 @@ class RbLLaMAContext {
RB_GC_GUARD(filename);
return Qnil;
};

static VALUE _llama_context_apply_lora_from_file(int argc, VALUE* argv, VALUE self) {
VALUE kw_args = Qnil;
ID kw_table[3] = { rb_intern("lora_path"), rb_intern("base_model_path"), rb_intern("n_threads") };
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
rb_scan_args(argc, argv, ":", &kw_args);
rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);

if (!RB_TYPE_P(kw_values[0], T_STRING)) {
rb_raise(rb_eArgError, "lora_path must be a string");
return Qnil;
}
if (kw_values[1] != Qundef && !RB_TYPE_P(kw_values[1], T_STRING)) {
rb_raise(rb_eArgError, "base_model_path must be a string");
return Qnil;
}
if (kw_values[2] != Qundef && !RB_INTEGER_TYPE_P(kw_values[2])) {
rb_raise(rb_eArgError, "n_threads must be an integer");
return Qnil;
}

const char* lora_path = StringValueCStr(kw_values[0]);
const char* base_model_path = kw_values[1] == Qundef ? NULL : StringValueCStr(kw_values[1]);
const int n_threads = kw_values[2] == Qundef ? 1 : NUM2INT(kw_values[2]);

LLaMAContextWrapper* ptr = get_llama_context(self);
if (ptr->ctx == NULL) {
rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
return Qnil;
}

if (llama_apply_lora_from_file(ptr->ctx, lora_path, base_model_path, n_threads) != 0) {
rb_raise(rb_eRuntimeError, "Failed to apply LoRA");
return Qnil;
}
return Qnil;
};
};

const rb_data_type_t RbLLaMAContext::llama_context_type = {
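The rb_get_kwargs call above declares one required keyword (lora_path) and two optional ones (base_model_path, n_threads), backed by the explicit type checks. A hedged sketch of the resulting Ruby-side behavior, with placeholder file names:

context.apply_lora_from_file(lora_path: 'adapter.bin')
# => nil (defaults: base_model_path = nil, n_threads = 1)
context.apply_lora_from_file(lora_path: 'adapter.bin', base_model_path: 'ggml-model-f16.bin', n_threads: 8)
context.apply_lora_from_file(lora_path: 42)
# raises ArgumentError, "lora_path must be a string"
context.apply_lora_from_file(base_model_path: 'ggml-model-f16.bin')
# raises ArgumentError (missing required keyword lora_path)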
1 change: 1 addition & 0 deletions sig/llama_cpp.rbs
@@ -30,6 +30,7 @@ module LLaMACpp
def sample_top_p_top_k: (top_k: Integer, top_p: Float, temp: Float, penalty: Float) -> Integer
def token_to_str: (Integer) -> String
def tokenize: (text: String, ?n_max_tokens: Integer, ?add_bos: bool) -> Array[Integer]
def apply_lora_from_file: (lora_path: String, ?base_model_path: String, ?n_threads: Integer) -> void
end

class ContextParams