From 3fa4ef001ae06de06344bd3c72fad86adc9d08f2 Mon Sep 17 00:00:00 2001
From: PAB
Date: Tue, 10 Oct 2023 23:56:39 +0200
Subject: [PATCH] fix: debug forward pass (#14)

---
 convert.py             |   7 +-
 encodec.cpp            | 296 ++++++++++++++++++++++++++---------------
 encodec.h              |  29 ++--
 examples/main/main.cpp |  13 +-
 ggml                   |   2 +-
 5 files changed, 220 insertions(+), 127 deletions(-)

diff --git a/convert.py b/convert.py
index f3860d273..f45cdeb9e 100644
--- a/convert.py
+++ b/convert.py
@@ -80,7 +80,12 @@ def parse_codec_model(checkpoint, outfile, use_f16):
         print(f"Processing variable: {name} with shape: {var_data.shape}")
 
         if use_f16:
-            if "weight" in name or "embed" in name:
+            if "embed" in name:
+                print("  Converting to float32")
+                var_data = var_data.astype(np.float32)
+                ftype_cur = 0
+                n_f32 += 1
+            elif "weight" in name:
                 print("  Converting to float16")
                 var_data = var_data.astype(np.float16)
                 ftype_cur = 1
diff --git a/encodec.cpp b/encodec.cpp
index acbca37a9..0acfbaf72 100644
--- a/encodec.cpp
+++ b/encodec.cpp
@@ -8,12 +8,37 @@
 #include <string>
 #include <vector>
 
-#include "encodec.h"
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "encodec.h"
 
-static const size_t TENSOR_ALIGNMENT = 32;
+void print_tensor(struct ggml_tensor * a) {
+    if (a) {
+        for (int i = 0; i < a->ne[3]; i++) {
+            for (int j = 0; j < a->ne[2]; j++) {
+                for (int k = 0; k < a->ne[1]; k++) {
+                    for (int l = 0; l < a->ne[0]; l++) {
+                        if (a->type == GGML_TYPE_F32) {
+                            float * aval = (float *) (
+                                (char *) a->data + i*a->nb[3] + j*a->nb[2] + k*a->nb[1] + l*a->nb[0]);
+                            printf("%.4f ", *aval);
+                        } else if (a->type == GGML_TYPE_I32) {
+                            int32_t * aval = (int32_t *) (
+                                (char *) a->data + i*a->nb[3] + j*a->nb[2] + k*a->nb[1] + l*a->nb[0]);
+                            printf("%d ", *aval);
+                        } else {
+                            throw;
+                        }
+                    }
+                    printf("\n");
+                }
+                printf("\n\n");
+            }
+        }
+    }
+}
 
 template <typename T>
 static void read_safe(std::ifstream& infile, T& dest) {
@@ -140,8 +165,7 @@ static struct ggml_tensor * forward_pass_lstm_unilayer(
         struct ggml_tensor * weight_ih,
         struct ggml_tensor * weight_hh,
         struct ggml_tensor * bias_ih,
-        struct ggml_tensor * bias_hh,
-        bool is_measure) {
+        struct ggml_tensor * bias_hh) {
 
     const int input_dim = inp->ne[1];
     const int hidden_dim = weight_ih->ne[1]/4;
@@ -152,11 +176,6 @@ static struct ggml_tensor * forward_pass_lstm_unilayer(
     struct ggml_tensor * c_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim);
     struct ggml_tensor * h_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim);
 
-    // if (!is_measure) {
-    //     h_t = ggml_set_zero(h_t);
-    //     c_t = ggml_set_zero(c_t);
-    // }
-
     struct ggml_tensor * current = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
 
     for (int t = 0; t < seq_length; t++) {
@@ -176,6 +195,7 @@ static struct ggml_tensor * forward_pass_lstm_unilayer(
         struct ggml_tensor * o_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 3*sizeof(float)*hidden_dim));
 
         c_t = ggml_add(ctx0, ggml_mul(ctx0, f_t, c_t), ggml_mul(ctx0, i_t, g_t));
+
         h_t = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_t));
 
         hs = ggml_set_1d(ctx0, hs, h_t, t*hs->nb[1]);
@@ -270,7 +290,9 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model)
     }
 
     auto & ctx = model.ctx;
-    size_t ctx_size = 0;
+
+    size_t buffer_size = 0;
+    size_t n_tensors = 0;
 
     // Evaluating context size
     {
@@ -291,54 +313,60 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model)
            int mult = 1; // scaling factor for hidden size
 
            // initial conv1d layer
-           ctx_size += in_channels * n_filters * kernel_size * ggml_type_size(wtype); // weight
-           ctx_size += n_filters * ggml_type_size(GGML_TYPE_F32); // bias
+           buffer_size += in_channels * n_filters * kernel_size * ggml_type_size(wtype); // weight
+           buffer_size += n_filters * ggml_type_size(GGML_TYPE_F32); // bias
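+           // weights may be stored as f16 (wtype) while biases always stay f32,
+           // hence the different element sizes in the estimates below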
 
            // resnet blocks
            for (int i = 0; i < 4; i++) {
                // conv1
-               ctx_size += res_kernel_sz * (mult*n_filters) * (mult*n_filters/2) * ggml_type_size(wtype); // weight
-               ctx_size += (mult*n_filters/2) * ggml_type_size(GGML_TYPE_F32); // bias
+               buffer_size += res_kernel_sz * (mult*n_filters) * (mult*n_filters/2) * ggml_type_size(wtype); // weight
+               buffer_size += (mult*n_filters/2) * ggml_type_size(GGML_TYPE_F32); // bias
 
                // conv2
-               ctx_size += (mult*n_filters/2) * (mult*n_filters) * ggml_type_size(wtype); // weight
-               ctx_size += (mult*n_filters) * ggml_type_size(GGML_TYPE_F32); // bias
+               buffer_size += (mult*n_filters/2) * (mult*n_filters) * ggml_type_size(wtype); // weight
+               buffer_size += (mult*n_filters) * ggml_type_size(GGML_TYPE_F32); // bias
 
                // shortcut
-               ctx_size += (mult*n_filters) * (mult*n_filters) * ggml_type_size(wtype); // weight
-               ctx_size += (mult*n_filters) * ggml_type_size(GGML_TYPE_F32); // bias
+               buffer_size += (mult*n_filters) * (mult*n_filters) * ggml_type_size(wtype); // weight
+               buffer_size += (mult*n_filters) * ggml_type_size(GGML_TYPE_F32); // bias
 
                // downsampling layers
-               ctx_size += (2*ratios[3-i]) * (mult*n_filters) * (mult*n_filters*2) * ggml_type_size(wtype); // weight
-               ctx_size += (2*mult*n_filters) * ggml_type_size(GGML_TYPE_F32); // bias
+               buffer_size += (2*ratios[3-i]) * (mult*n_filters) * (mult*n_filters*2) * ggml_type_size(wtype); // weight
+               buffer_size += (2*mult*n_filters) * ggml_type_size(GGML_TYPE_F32); // bias
 
                mult *= 2;
            }
 
            // lstm
-           ctx_size += 2 * n_lstm_layers * (mult*n_filters) * (4*mult*n_filters) * ggml_type_size(wtype); // weight_ih and weight_hh
-           ctx_size += 2 * n_lstm_layers * (4*mult*n_filters) * ggml_type_size(GGML_TYPE_F32); // bias_ih and bias_hh
+           buffer_size += 2 * n_lstm_layers * (mult*n_filters) * (4*mult*n_filters) * ggml_type_size(wtype); // weight_ih and weight_hh
+           buffer_size += 2 * n_lstm_layers * (4*mult*n_filters) * ggml_type_size(GGML_TYPE_F32); // bias_ih and bias_hh
 
            // final conv
-           ctx_size += kernel_size * (mult*n_filters) * hidden_dim * ggml_type_size(wtype); // weight
-           ctx_size += hidden_dim * ggml_type_size(GGML_TYPE_F32); // bias
+           buffer_size += kernel_size * (mult*n_filters) * hidden_dim * ggml_type_size(wtype); // weight
+           buffer_size += hidden_dim * ggml_type_size(GGML_TYPE_F32); // bias
        }
 
        // decoder mirrors the encoder (same number of parameters), just double context size
-       ctx_size *= 2;
+       buffer_size *= 2;
 
        // quantizer
-       ctx_size += hidden_dim * n_bins * ggml_type_size(wtype); // embed
+       buffer_size += n_q * hidden_dim * n_bins * ggml_type_size(GGML_TYPE_F32); // embed
+
+       buffer_size += 10ull*MB; // object overhead
 
-       ctx_size += 10ull*MB; // object overhead
+       n_tensors = ((4 * 2) * 4 + 2 + 4 * n_lstm_layers + 2) * 2; // encoder and decoder
+       n_tensors += n_q * 1; // quantizer
+
+       printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+       printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size/(1024.0*1024.0));
    }
 
    // create the ggml context
    {
        struct ggml_init_params params = {
-           /* .mem_size   = */ ctx_size,
+           /* .mem_size   = */ ggml_tensor_overhead() * n_tensors,
            /* .mem_buffer = */ NULL,
-           /* .no_alloc   = */ false,
+           /* .no_alloc   = */ true,
        };
 
        model.ctx = ggml_init(params);
@@ -348,6 +376,20 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model)
        }
    }
 
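+   // the ggml context now only holds tensor metadata (no_alloc = true above);
+   // the actual weight data goes into a backend buffer allocated below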
+   if (!model.backend) {
+       // fallback to CPU backend
+       fprintf(stderr, "%s: using CPU backend\n", __func__);
+       model.backend = ggml_backend_cpu_init();
+   }
+
+   if (!model.backend) {
+       fprintf(stderr, "%s: ggml_backend_cpu_init() failed\n", __func__);
+       return false;
+   }
+
+   // allocate weights buffer
+   model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);
+
    // prepare memory for the weights
    {
        const auto & hparams = model.hparams;
@@ -519,7 +561,7 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model)
        model.quantizer.blocks.resize(n_q);
 
        for (int i = 0; i < n_q; i++) {
-           model.quantizer.blocks[i].embed = ggml_new_tensor_2d(ctx, wtype, hidden_dim, n_bins);
+           model.quantizer.blocks[i].embed = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_dim, n_bins);
 
            model.tensors["quantizer.vq.layers." + std::to_string(i) + "._codebook.embed"] = model.quantizer.blocks[i].embed;
        }
@@ -529,9 +571,13 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model)
 
    // load weights
    {
+       ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_w);
+
        size_t total_size = 0;
        model.n_loaded    = 0;
 
+       std::vector<char> read_buf;
+
        while(true) {
            int32_t n_dims;
            int32_t length;
@@ -563,6 +609,7 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model)
            }
 
            auto tensor = model.tensors[name.data()];
+           ggml_set_name(tensor, name.c_str());
            if (ggml_nelements(tensor) != nelements) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
                return false;
@@ -575,13 +622,23 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model)
            }
 
            const size_t bpe = ggml_type_size(ggml_type(ftype));
+
            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                return false;
            }
 
-           infile.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+           ggml_allocr_alloc(alloc, tensor);
+
+           if (ggml_backend_is_cpu(model.backend)) {
+               infile.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+           } else {
+               // read into a temporary buffer first, then copy to device memory
+               read_buf.resize(ggml_nbytes(tensor));
+               infile.read(read_buf.data(), ggml_nbytes(tensor));
+               ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+           }
"float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); @@ -589,7 +646,8 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model) model.n_loaded++; } - fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + ggml_allocr_free(alloc); + printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0); } infile.close(); @@ -597,26 +655,36 @@ bool encodec_load_model_weights(const std::string& fname, encodec_model& model) return true; } -static struct ggml_cgraph * encodec_build_graph( - encodec_context & ectx, +struct ggml_cgraph * encodec_graph( + struct encodec_context * ectx, const std::vector & inp_audio) { - const int32_t audio_length = inp_audio.size(); + const int N = inp_audio.size(); + + const auto & model = ectx->model; - const auto & model = *ectx.model; + auto & allocr = ectx->allocr; + + // since we are using ggml-alloc, this buffer only needs enough space to hold the + // ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead(); + static std::vector buf(buf_size); struct ggml_init_params ggml_params = { - /*.mem_size =*/ ectx.buf_compute.size(), - /*.mem_buffer =*/ ectx.buf_compute.data(), + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements }; struct ggml_context * ctx0 = ggml_init(ggml_params); - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - struct ggml_tensor * inp = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, audio_length); - ggml_allocr_alloc(ectx.allocr, inp); - if (!ggml_allocr_is_measure(ectx.allocr)) { - memcpy(inp->data, inp_audio.data(), audio_length*ggml_element_size(inp)); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * inp = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, N); + ggml_allocr_alloc(allocr, inp); + + // avoid writing to tensors if we are only measuring the memory usage + if (!ggml_allocr_is_measure(allocr)) { + ggml_backend_tensor_set(inp, inp_audio.data(), 0, N*ggml_element_size(inp)); } // encoder @@ -671,13 +739,11 @@ static struct ggml_cgraph * encodec_build_graph( // first lstm layer struct ggml_tensor * hs1 = forward_pass_lstm_unilayer( - ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b, - ggml_allocr_is_measure(ectx.allocr)); + ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b); // second lstm layer struct ggml_tensor * out = forward_pass_lstm_unilayer( - ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b, - ggml_allocr_is_measure(ectx.allocr)); + ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b); inpL = ggml_add(ctx0, inpL, out); } @@ -792,13 +858,11 @@ static struct ggml_cgraph * encodec_build_graph( // first lstm layer struct ggml_tensor * hs1 = forward_pass_lstm_unilayer( - ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b, - ggml_allocr_is_measure(ectx.allocr)); + ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b); // second lstm layer struct ggml_tensor * out = forward_pass_lstm_unilayer( - ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b, - ggml_allocr_is_measure(ectx.allocr)); + ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b); inpL = ggml_add(ctx0, inpL, out); } @@ -845,95 +909,117 @@ static struct ggml_cgraph * encodec_build_graph( out = decoded_inp; } - out = ggml_cpy(ectx.ctx_audio, out, ectx.reconstructed_audio); - 
-   out = ggml_cpy(ectx.ctx_audio, out, ectx.reconstructed_audio);
-
    ggml_build_forward_expand(gf, out);
-   ggml_disconnect_node_from_graph(ectx.reconstructed_audio);
 
    ggml_free(ctx0);
 
    return gf;
}
 
-bool encodec_reconstruct_audio(
-   encodec_context & ectx,
-   std::vector<float> & raw_audio,
-   int n_threads) {
-   const int64_t t_start_ms = ggml_time_ms();
+bool encodec_eval(
+   struct encodec_context * ectx,
+   std::vector<float> & raw_audio,
+   const int n_threads) {
+   auto & model = ectx->model;
+   auto & allocr = ectx->allocr;
+
+   // reset the allocator to free all the memory allocated during the previous inference
+   ggml_allocr_reset(allocr);
+
+   struct ggml_cgraph * gf = encodec_graph(ectx, raw_audio);
 
-   static const size_t buf_size = 256u*1024*1024;
+   // allocate tensors
+   ggml_allocr_alloc_graph(allocr, gf);
 
-   if (ectx.ctx_audio) {
-       ggml_free(ectx.ctx_audio);
-       ectx.ctx_audio = {};
+   // run the computation
+   if (ggml_backend_is_cpu(model.backend)) {
+       ggml_backend_cpu_set_n_threads(model.backend, n_threads);
    }
+   ggml_backend_graph_compute(model.backend, gf);
 
-   struct ggml_init_params ggml_params = {
-       /*.mem_size   =*/ buf_size,
-       /*.mem_buffer =*/ NULL,
-       /*.no_alloc   =*/ false,
-   };
+   // reconstructed audio is the last one in the graph
+   struct ggml_tensor * out = gf->nodes[gf->n_nodes - 1];
 
-   ectx.ctx_audio = ggml_init(ggml_params);
+   auto & out_audio = ectx->out_audio;
 
-   ectx.reconstructed_audio = ggml_new_tensor_1d(ectx.ctx_audio, GGML_TYPE_F32, 100160);
+   int out_length = out->ne[0];
+   out_audio.resize(out_length);
 
-   // reconstruct the audio
-   ectx.buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-   ectx.allocr = ggml_allocr_new_measure(TENSOR_ALIGNMENT);
-   struct ggml_cgraph * gf_measure = encodec_build_graph(ectx, raw_audio);
-   if (!gf_measure) {
-       fprintf(stderr, "%s: failed to build graph\n", __func__);
-       return false;
-   }
+   ggml_backend_tensor_get(out, out_audio.data(), 0, out_length*ggml_element_size(out));
 
-   size_t alloc_size = ggml_allocr_alloc_graph(ectx.allocr, gf_measure) + TENSOR_ALIGNMENT;
-   ggml_allocr_free(ectx.allocr);
+   return true;
+}
 
-   // recreate allocator with exact memory requirements
-   ectx.buf_alloc.resize(alloc_size);
-   ectx.allocr = ggml_allocr_new(ectx.buf_alloc.data(), ectx.buf_alloc.size(), TENSOR_ALIGNMENT);
+bool encodec_reconstruct_audio(
+   struct encodec_context * ectx,
+   std::vector<float> & raw_audio,
+   int n_threads) {
+   const int64_t t_start_ms = ggml_time_ms();
 
-   // compute the graph with the measured exact memory requirements from above
-   ggml_allocr_reset(ectx.allocr);
+   // allocate the compute buffer
+   {
+       // alignment required by the backend
+       size_t align = ggml_backend_get_alignment(ectx->model.backend);
+       ectx->allocr = ggml_allocr_new_measure(align);
 
-   struct ggml_cgraph * gf = encodec_build_graph(ectx, raw_audio);
-   if (!gf) {
-       fprintf(stderr, "%s: failed to build graph\n", __func__);
-       return false;
-   }
+       // create the graph for memory usage estimation
+       struct ggml_cgraph * gf = encodec_graph(ectx, raw_audio);
 
-   ggml_allocr_alloc_graph(ectx.allocr, gf);
+       // compute the required memory
+       size_t mem_size = ggml_allocr_alloc_graph(ectx->allocr, gf);
 
-   ggml_graph_compute_helper(ectx.work_buffer, gf, n_threads);
+       // recreate the allocator with the required memory
+       ggml_allocr_free(ectx->allocr);
+       ectx->buf_compute = ggml_backend_alloc_buffer(ectx->model.backend, mem_size);
+       ectx->allocr = ggml_allocr_new_from_buffer(ectx->buf_compute);
 
-   ggml_allocr_free(ectx.allocr);
-   ectx.allocr = NULL;
-   ectx.work_buffer.clear();
+       fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
+   }
+
+   printf("\n\n");
+
+   // encodec eval
+   if (!encodec_eval(ectx, raw_audio, n_threads)) {
+       fprintf(stderr, "%s: failed to run encodec eval\n", __func__);
+       return false;
+   }
 
-   ectx.t_compute_ms = ggml_time_ms() - t_start_ms;
+   ectx->t_compute_ms = ggml_time_ms() - t_start_ms;
 
    return true;
}
 
-std::shared_ptr<encodec_context> encodec_load_model(const std::string & model_path) {
+struct encodec_context * encodec_load_model(const std::string & model_path) {
    int64_t t_start_load_us = ggml_time_us();
 
-   encodec_context ectx;
+   struct encodec_context * ectx = new encodec_context();
 
-   ectx.model = std::make_unique<encodec_model>();
-   if (!encodec_load_model_weights(model_path, *ectx.model)) {
+   ectx->model = encodec_model();
+   if (!encodec_load_model_weights(model_path, ectx->model)) {
        fprintf(stderr, "%s: failed to load model weights from '%s'\n", __func__, model_path.c_str());
        return {};
    }
 
-   ectx.t_load_us = ggml_time_us() - t_start_load_us;
+   ectx->t_load_us = ggml_time_us() - t_start_load_us;
 
-   return std::make_unique<encodec_context>(std::move(ectx));
+   return ectx;
}
 
-void encodec_free(encodec_context & ectx) {
-   if (ectx.ctx_audio) {
-       ggml_free(ectx.ctx_audio);
+void encodec_free(struct encodec_context * ectx) {
+   if (!ectx) {
+       return;
+   }
+
+   if (ectx->model.ctx) {
+       ggml_free(ectx->model.ctx);
    }
+
+   if (ectx->buf_compute) {
+       ggml_backend_buffer_free(ectx->buf_compute);
+   }
+
+   ggml_backend_buffer_free(ectx->model.buffer_w);
+   ggml_backend_free(ectx->model.backend);
+
+   delete ectx;
}
diff --git a/encodec.h b/encodec.h
index 00dcb1cf0..7e62eb206 100644
--- a/encodec.h
+++ b/encodec.h
@@ -9,9 +9,9 @@
 #include <vector>
 
 #include "ggml.h"
+#include "ggml-backend.h"
 
 #define ENCODEC_FILE_MAGIC 'ggml'
-#define ENCODEC_FILE_VERSION 1
 
 static const size_t MB = 1024*1024;
@@ -139,34 +139,35 @@ struct encodec_model {
    struct ggml_context * ctx;
    int n_loaded;
 
+   ggml_backend_t backend = NULL;
+
+   ggml_backend_buffer_t buffer_w;
+
    std::map<std::string, struct ggml_tensor *> tensors;
};
 
struct encodec_context {
-   std::unique_ptr<encodec_model> model;
-
-   struct ggml_context * ctx_audio;
-   struct ggml_tensor * reconstructed_audio;
+   encodec_model model;
 
-   // buffer for `ggml_graph_plan.work_data`
-   std::vector<uint8_t> work_buffer;
+   // buffer for model evaluation
+   ggml_backend_buffer_t buf_compute;
 
-   // buffers to evaluate the model
-   std::vector<uint8_t> buf_alloc;
-   std::vector<uint8_t> buf_compute;
+   // custom allocator
+   struct ggml_allocr * allocr = NULL;
 
-   struct ggml_allocr * allocr = {};
+   // output audio
+   std::vector<float> out_audio;
 
    // statistics
    int64_t t_load_us = 0;
    int64_t t_compute_ms = 0;
};
 
-std::shared_ptr<encodec_context> encodec_load_model(const std::string & model_path);
+struct encodec_context * encodec_load_model(const std::string & model_path);
 
bool encodec_reconstruct_audio(
-   encodec_context & ectx,
+   struct encodec_context * ectx,
    std::vector<float> & raw_audio,
    int n_threads);
 
-void encodec_free(encodec_context & ectx);
\ No newline at end of file
+void encodec_free(struct encodec_context * ectx);
\ No newline at end of file
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 720c26ebf..9a20a364e 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -77,7 +77,7 @@ bool read_wav_from_disk(std::string in_path, std::vector<float>& audio_arr) {
        return false;
    }
 
-   fprintf(stderr, "%s: Number of frames read = %lld.\n", __func__, total_frame_count);
+   fprintf(stderr, "\n%s: Number of frames read = %lld.\n", __func__, total_frame_count);
 
    audio_arr.resize(total_frame_count);
    memcpy(audio_arr.data(), raw_audio, total_frame_count * sizeof(float));
@@ -115,7 +115,7 @@ int main(int argc, char **argv) {
    }
 
    // initialize encodec context
-   std::shared_ptr<encodec_context> ectx = encodec_load_model(params.model_path);
+   struct encodec_context * ectx = encodec_load_model(params.model_path);
    if (!ectx) {
        printf("%s: error during loading model\n", __func__);
        return 1;
    }
@@ -128,15 +128,16 @@ int main(int argc, char **argv) {
        return 1;
    }
 
+   printf("\n");
+
    // reconstruct audio
-   if (!encodec_reconstruct_audio(*ectx, original_audio_arr, params.n_threads)) {
+   if (!encodec_reconstruct_audio(ectx, original_audio_arr, params.n_threads)) {
        printf("%s: error during inference\n", __func__);
        return 1;
    }
 
    // write reconstructed audio on disk
-   std::vector<float> audio_arr(ectx->reconstructed_audio->ne[0]);
-   memcpy(ectx->reconstructed_audio->data, audio_arr.data(), audio_arr.size() * sizeof(float));
+   auto & audio_arr = ectx->out_audio;
    write_wav_on_disk(audio_arr, params.dest_wav_path);
 
    // report timing
@@ -149,7 +150,7 @@ int main(int argc, char **argv) {
        printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
    }
 
-   encodec_free(*ectx);
+   encodec_free(ectx);
 
    return 0;
}
\ No newline at end of file
diff --git a/ggml b/ggml
index a16b01d68..a7e0350b3 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit a16b01d6891fd885800988003d53755c9574c6e4
+Subproject commit a7e0350b3e74f42d2aaf12202cc9dfe47467aa39